Merge remote-tracking branch 'upstream/master' into detection_postprocess

Advait Jain 2020-11-18 22:10:22 -08:00
commit 6714af7331
1377 changed files with 46971 additions and 18548 deletions

View File

@@ -602,6 +602,10 @@ build:release_windows_common --config=release_common
build:release_windows_common --define=no_tensorflow_py_deps=true
build:release_windows_common --announce_rc
# First available in VS 16.4. Speeds up Windows compile times by a lot. See
# https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
build:release_windows_common --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions
build:release_cpu_windows --config=release_windows_common
build:release_gpu_windows --config=release_windows_common

View File

@@ -5,7 +5,6 @@
[![Python](https://img.shields.io/pypi/pyversions/tensorflow.svg?style=plastic)](https://badge.fury.io/py/tensorflow)
[![PyPI](https://badge.fury.io/py/tensorflow.svg)](https://badge.fury.io/py/tensorflow)
**`Documentation`** |
------------------- |
[![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://www.tensorflow.org/api_docs/) |
@@ -61,6 +60,7 @@ commands.
*Nightly binaries are available for testing using the
[tf-nightly](https://pypi.python.org/pypi/tf-nightly) and
[tf-nightly-cpu](https://pypi.python.org/pypi/tf-nightly-cpu) packages on PyPI.*
#### *Try your first TensorFlow program*
```shell
@@ -159,8 +159,6 @@ Container Type | Status | Art
* [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190)
* [Machine Learning with TensorFlow on GCP](https://www.coursera.org/specializations/machine-learning-tensorflow-gcp)
* [TensorFlow Codelabs](https://codelabs.developers.google.com/?cat=TensorFlow)
* [TensorFlow Chat Room on StackOverflow (not actively monitored by the
TensorFlow team)](https://chat.stackoverflow.com/rooms/216694/tensorflow)
* [TensorFlow Blog](https://blog.tensorflow.org)
* [Learn ML with TensorFlow](https://www.tensorflow.org/resources/learn-ml)
* [TensorFlow Twitter](https://twitter.com/tensorflow)

View File

@@ -45,11 +45,26 @@
* Removed deprecated `Interpreter::UseNNAPI(bool)` C++ API.
* Use `NnApiDelegate()` and related delegate configuration methods
directly.
* 16-bit quantization
* Added int16x8 support for the ABS, REDUCE_MAX, and REDUCE_MIN operators.
* Added support for a saved model's session initializer through
`TFLiteConverter.from_saved_model`.
* TF Core:
* Corrected higher-order gradients of control flow constructs (`tf.cond`,
`tf.while_loop`, and compositions like `tf.foldl`) computed with
`tf.GradientTape` inside a `tf.function`.
* `tf.summary`:
* New `tf.summary.graph` allows manually writing a TensorFlow graph
(`tf.Graph` or `tf.compat.v1.GraphDef`) as a summary. This is not a
replacement for the trace-based API.
* Set `/d2ReducedOptimizeHugeFunctions` by default for Windows builds. This
provides a big compile-time speedup, and effectively raises the minimum
supported MSVC version to 16.4 (current: 16.8).
* See: https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
## Thanks to our Contributors
This release contains contributions from many people at Google, as well as:

View File

@@ -1168,41 +1168,13 @@ def set_system_libs_flag(environ_cp):
write_to_bazelrc('build --define=%s=%s' % (varname, environ_cp[varname]))
def is_reduced_optimize_huge_functions_available(environ_cp):
"""Check to see if the system supports /d2ReducedOptimizeHugeFunctions.
The above flag was introduced to the Visual Studio compiler in version 16.4
(available in Visual Studio 2019, Preview edition only, as of 2019-11-19).
TensorFlow needs this flag to massively reduce
compile times, but until 16.4 is officially released, we can't depend on it.
See also
https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
Because it's very annoying to check this manually (to check the MSVC installed
versions, you need to use the registry, and it's not clear if Bazel will be
using that install version anyway), we expect environments that know they may
use this flag to export TF_VC_VERSION=16.4
TODO(angerson, gunan): Remove this function when TensorFlow's minimum VS
version is upgraded to 16.4.
Arguments:
environ_cp: Environment of the current execution
Returns:
boolean, whether or not /d2ReducedOptimizeHugeFunctions is available on this
machine.
"""
return float(environ_cp.get('TF_VC_VERSION', '0')) >= 16.4
def set_windows_build_flags(environ_cp):
"""Set Windows specific build options."""
if is_reduced_optimize_huge_functions_available(environ_cp):
write_to_bazelrc(
'build --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions'
)
# First available in VS 16.4. Speeds up Windows compile times by a lot. See
# https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
# pylint: disable=line-too-long
write_to_bazelrc('build --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions')
if get_var(
environ_cp, 'TF_OVERRIDE_EIGEN_STRONG_INLINE', 'Eigen strong inline',

View File

@@ -588,9 +588,11 @@ config_setting(
# DO NOT ADD ANY NEW EXCEPTIONS TO THIS LIST!
# Instead, please use public APIs or public build rules TF provides.
# If you need functionality that is not exposed, we will work with you to expand our public APIs.
# TODO(b/173549186): Move Google-internal TF code out of learning/brain
package_group(
name = "internal",
packages = [
"//learning/brain/mlir/...",
"//learning/lib/ami/simple_ml/...",
"//tensorflow/...",
],

View File

@@ -199,6 +199,7 @@ tf_cuda_library(
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
":logging",
":tf_status",
":tf_tensor",
"@com_google_absl//absl/strings",

View File

@@ -51,6 +51,7 @@ tf_cuda_library(
":immediate_execution_context",
":immediate_execution_operation",
":immediate_execution_tensor_handle",
":immediate_execution_distributed_manager",
":abstract_tensor_handle",
":tfe_context_internal",
":tfe_cancellation_manager_internal",
@@ -70,6 +71,7 @@ tf_cuda_library(
"//tensorflow/core:core_cpu",
"//tensorflow/core/common_runtime/eager:attr_builder",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:context_distributed_manager",
"//tensorflow/core/common_runtime/eager:core",
"//tensorflow/core/common_runtime/eager:eager_executor",
"//tensorflow/core/common_runtime/eager:execute",
@@ -119,6 +121,7 @@ filegroup(
"gradients.h",
"gradients_internal.h",
"immediate_execution_context.h",
"immediate_execution_distributed_manager.h",
"immediate_execution_operation.h",
"immediate_execution_tensor_handle.h",
"tape.h",
@@ -176,6 +179,7 @@ cc_library(
"//tensorflow/c:c_api_internal",
"//tensorflow/c:conversion_macros",
"//tensorflow/c:tf_status",
"//tensorflow/core:framework",
"//tensorflow/core/platform:casts",
"//tensorflow/core/platform:types",
],
@@ -224,6 +228,34 @@ cc_library(
],
)
cc_library(
name = "unified_api_testutil",
testonly = 1,
srcs = [
"unified_api_testutil.cc",
],
hdrs = [
"unified_api_testutil.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":abstract_context",
":abstract_tensor_handle",
":c_api_experimental",
":c_api_test_util",
":c_api_unified_internal",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core:framework",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:status",
"@com_google_absl//absl/container:flat_hash_set",
],
)
tf_cuda_cc_test(
name = "gradients_test",
size = "small",
@@ -240,6 +272,7 @@ tf_cuda_cc_test(
":c_api_test_util",
":c_api_unified_internal",
":gradients_internal",
":unified_api_testutil",
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/c:tf_status_helper",
@@ -260,6 +293,29 @@ tf_cuda_cc_test(
],
)
tf_cuda_cc_test(
name = "unified_api_test",
size = "small",
srcs = [
"unified_api_test.cc",
],
args = ["--heap_check=local"],
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags(),
deps = [
":c_api_experimental",
":c_api_unified_internal",
":unified_api_testutil",
"//tensorflow/c:tf_status_helper",
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:errors",
],
)
cc_library(
name = "gradients_util",
srcs = [
@@ -449,8 +505,10 @@ cc_library(
"//tensorflow:internal",
],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:refcount",
"//tensorflow/core/platform:status",
],
)
@@ -529,6 +587,19 @@ cc_library(
],
)
cc_library(
name = "immediate_execution_distributed_manager",
hdrs = ["immediate_execution_distributed_manager.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "immediate_execution_context",
hdrs = ["immediate_execution_context.h"],
@@ -537,12 +608,14 @@ cc_library(
],
deps = [
":abstract_context",
":immediate_execution_distributed_manager",
":immediate_execution_operation",
":immediate_execution_tensor_handle",
"//tensorflow/c:tensor_interface",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],

View File

@@ -17,8 +17,10 @@ limitations under the License.
#include <memory>
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
// Abstract interface to a Tensor handle in either tracing or immediate
@@ -32,6 +34,9 @@ class AbstractTensorHandle : public core::RefCounted {
public:
// Returns tensor dtype.
virtual tensorflow::DataType DataType() const = 0;
// Returns the tensor shape. If the tensor has unknown rank, shape remains untouched.
virtual tensorflow::Status Shape(
tensorflow::PartialTensorShape* shape) const = 0;
AbstractTensorHandleKind getKind() const { return kind_; }
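A minimal caller-side sketch of the new Shape accessor (hypothetical usage, not part of this diff; `handle` stands for any AbstractTensorHandle*):

  tensorflow::PartialTensorShape shape;
  tensorflow::Status s = handle->Shape(&shape);
  if (s.ok() && !shape.unknown_rank()) {
    // shape.dims() is the rank; shape.dim_size(i) is -1 when that
    // dimension's size is unknown.
    for (int i = 0; i < shape.dims(); ++i) {
      LOG(INFO) << "dim " << i << ": " << shape.dim_size(i);
    }
  }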

View File

@@ -21,16 +21,11 @@ limitations under the License.
#include <string>
#include <vector>
#include "tensorflow/c/eager/abstract_tensor_handle.h"
// clang-format off
#include "tensorflow/core/platform/platform.h"
// clang-format on
#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
@@ -39,59 +34,39 @@ limitations under the License.
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_tensor_internal.h"
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
#endif
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/device_filters.pb.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
#include "tensorflow/core/distributed_runtime/remote_device.h"
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
#endif // !IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/public/version.h"
// "tensorflow/core/platform/platform.h" must be included first before using
// PLATFORM_GOOGLE, IS_MOBILE_PLATFORM, etc.
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
#endif // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/common_runtime/eager/context_distributed_manager.h"
#endif // !IS_MOBILE_PLATFORM
using tensorflow::string;
namespace {
@@ -100,611 +75,6 @@ string DeviceName(const tensorflow::Device* d) {
return (d == nullptr) ? "cpu:0" : d->name();
}
#if !defined(IS_MOBILE_PLATFORM)
bool AreLocalDevicesCompatible(const tensorflow::EagerContext* context,
const tensorflow::ServerDef& server_def) {
if (server_def.job_name() != context->HostCPU()->parsed_name().job) {
return false;
}
return server_def.default_session_config().SerializeAsString() ==
context->session_options().config.SerializeAsString();
}
tensorflow::Status AddRemoteDevicesToMgr(
const std::vector<string>& added_remote_workers,
tensorflow::WorkerCacheInterface* worker_cache,
tensorflow::DynamicDeviceMgr* remote_device_mgr) {
std::vector<std::unique_ptr<tensorflow::Device>> remote_devices;
tensorflow::mutex remote_devices_mu;
int num_added_workers = added_remote_workers.size();
tensorflow::BlockingCounter counter(num_added_workers);
std::vector<tensorflow::Status> statuses(num_added_workers);
for (int i = 0; i < num_added_workers; i++) {
tensorflow::NewRemoteDevices(
tensorflow::Env::Default(), worker_cache, added_remote_workers[i],
[i, &statuses, &counter, &remote_devices, &remote_devices_mu](
const tensorflow::Status& s,
std::vector<tensorflow::Device*>* devices) {
statuses[i] = s;
if (s.ok()) {
tensorflow::mutex_lock l(remote_devices_mu);
for (tensorflow::Device* d : *devices) {
remote_devices.emplace_back(d);
}
}
counter.DecrementCount();
});
}
counter.Wait();
for (int i = 0; i < num_added_workers; i++) {
TF_RETURN_IF_ERROR(statuses[i]);
}
TF_RETURN_IF_ERROR(remote_device_mgr->AddDevices(std::move(remote_devices)));
return tensorflow::Status::OK();
}
tensorflow::Status GetAllRemoteDevices(
const std::vector<string>& remote_workers,
tensorflow::WorkerCacheInterface* worker_cache,
std::unique_ptr<tensorflow::DynamicDeviceMgr>* device_mgr) {
auto remote_device_mgr = absl::make_unique<tensorflow::DynamicDeviceMgr>();
TF_RETURN_IF_ERROR(AddRemoteDevicesToMgr(remote_workers, worker_cache,
remote_device_mgr.get()));
*device_mgr = std::move(remote_device_mgr);
return tensorflow::Status::OK();
}
tensorflow::Status RemoveRemoteDevicesFromMgr(
const std::vector<string>& removed_remote_workers,
tensorflow::DynamicDeviceMgr* remote_device_mgr) {
const std::vector<tensorflow::Device*> remote_devices =
(remote_device_mgr->ListDevices());
std::vector<tensorflow::Device*> devices_to_remove;
for (tensorflow::Device* d : remote_devices) {
for (const string& remote_worker : removed_remote_workers) {
if (tensorflow::DeviceNameUtils::IsSameAddressSpace(remote_worker,
d->name())) {
devices_to_remove.emplace_back(d);
break;
}
}
}
TF_RETURN_IF_ERROR(remote_device_mgr->RemoveDevices(devices_to_remove));
return tensorflow::Status::OK();
}
tensorflow::Status ListRemoteWorkers(tensorflow::ServerInterface* server,
const string& local_worker,
std::vector<string>* remote_workers) {
tensorflow::GrpcServer* grpc_server =
dynamic_cast<tensorflow::GrpcServer*>(server);
if (grpc_server == nullptr) {
return tensorflow::errors::Internal(
"Currently, TFE_NewContext only supports tensorflow::GrpcServer.");
}
grpc_server->master_env()->worker_cache->ListWorkers(remote_workers);
remote_workers->erase(
std::remove(remote_workers->begin(), remote_workers->end(), local_worker),
remote_workers->end());
return tensorflow::Status::OK();
}
void DifferentiateWorkerLists(const std::vector<string>* current_list,
const std::vector<string>* new_list,
std::vector<string>* added,
std::vector<string>* removed,
std::vector<string>* existing) {
// Get STL set_difference and set_intersection with one list traversal.
// Similar to the set_difference library function, the input lists
// (`current_list` and `new_list`) must be sorted before calling the function.
added->resize(new_list->size());
removed->resize(current_list->size());
existing->resize(current_list->size());
std::vector<string>::const_iterator curr_it = current_list->begin();
std::vector<string>::const_iterator new_it = new_list->begin();
std::vector<string>::iterator added_it = added->begin();
std::vector<string>::iterator removed_it = removed->begin();
std::vector<string>::iterator existing_it = existing->begin();
while (curr_it != current_list->end() && new_it != new_list->end()) {
if (*curr_it < *new_it) {
*removed_it++ = *curr_it++;
} else if (*curr_it > *new_it) {
*added_it++ = *new_it++;
} else {
*existing_it++ = *curr_it++;
new_it++;
}
}
removed_it = std::copy(curr_it, current_list->end(), removed_it);
added_it = std::copy(new_it, new_list->end(), added_it);
added->resize(added_it - added->begin());
removed->resize(removed_it - removed->begin());
existing->resize(existing_it - existing->begin());
}
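// A worked example of DifferentiateWorkerLists (hypothetical worker names;
// both inputs sorted, as required):
//   current_list = {"w0", "w1", "w2"}, new_list = {"w1", "w3"}
//   => added = {"w3"}, removed = {"w0", "w2"}, existing = {"w1"}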
tensorflow::Status GetReplacedFromExistingWorkers(
const std::vector<string>* existing_workers, tensorflow::uint64 context_id,
tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
tensorflow::eager::EagerClientCache* client_cache,
std::vector<string>* replaced_workers) {
tensorflow::BlockingCounter counter(existing_workers->size());
std::vector<tensorflow::Status> statuses(existing_workers->size());
tensorflow::eager::KeepAliveRequest request;
request.set_context_id(context_id);
std::vector<tensorflow::eager::KeepAliveResponse> responses(
existing_workers->size());
for (int i = 0; i < existing_workers->size(); i++) {
tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
statuses[i] =
client_cache->GetClient(existing_workers->at(i), &eager_client);
if (!statuses[i].ok()) {
counter.DecrementCount();
continue;
}
eager_client->KeepAliveAsync(
&request, &responses[i],
[i, &statuses, &counter](const tensorflow::Status& s) {
statuses[i] = s;
counter.DecrementCount();
});
}
counter.Wait();
for (int i = 0; i < existing_workers->size(); i++) {
// If the RPC fails (indicating that the requested ID doesn't exist on the
// remote worker), or the returned view ID is not equal to the local one
// (indicating that the remote worker has a stale view of the cluster), treat
// the worker as replaced.
if (!statuses[i].ok() ||
responses[i].context_view_id() != context_view_id) {
replaced_workers->emplace_back(existing_workers->at(i));
}
}
return tensorflow::Status::OK();
}
tensorflow::Status CreateRemoteContexts(
TFE_Context* ctx, const std::vector<string>& remote_workers,
tensorflow::uint64 context_id, tensorflow::uint64 context_view_id,
int keep_alive_secs, const tensorflow::ServerDef& server_def,
tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
const bool lazy_copy_remote_function_inputs,
const tensorflow::eager::CreateContextRequest& base_request) {
int num_remote_workers = remote_workers.size();
tensorflow::BlockingCounter counter(num_remote_workers);
std::vector<tensorflow::Status> statuses(num_remote_workers);
for (int i = 0; i < num_remote_workers; i++) {
const string& remote_worker = remote_workers[i];
tensorflow::DeviceNameUtils::ParsedName parsed_name;
if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
&parsed_name)) {
statuses[i] = tensorflow::errors::InvalidArgument(
"Unable to parse ", remote_worker, " as a device name");
counter.DecrementCount();
continue;
}
tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
if (eager_client == nullptr) {
statuses[i] = tensorflow::errors::Internal(
"Cannot find a client for the given target:", remote_worker);
}
if (!statuses[i].ok()) {
counter.DecrementCount();
continue;
}
tensorflow::eager::CreateContextRequest request;
tensorflow::eager::CreateContextResponse* response =
new tensorflow::eager::CreateContextResponse();
request.set_context_id(context_id);
request.set_context_view_id(context_view_id);
*request.mutable_server_def() = server_def;
request.mutable_server_def()->set_job_name(parsed_name.job);
request.mutable_server_def()->set_task_index(parsed_name.task);
request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
server_def.default_session_config());
std::vector<bool> filtered_device_mask;
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->FilterDevicesForRemoteWorkers(
remote_worker, base_request.cluster_device_attributes(),
&filtered_device_mask);
DCHECK_EQ(filtered_device_mask.size(),
base_request.cluster_device_attributes_size());
for (int i = 0; i < filtered_device_mask.size(); i++) {
if (filtered_device_mask[i]) {
const auto& da = base_request.cluster_device_attributes(i);
*request.add_cluster_device_attributes() = da;
}
}
request.set_async(async);
request.set_keep_alive_secs(keep_alive_secs);
request.set_lazy_copy_remote_function_inputs(
lazy_copy_remote_function_inputs);
eager_client->CreateContextAsync(
&request, response,
[i, &statuses, &counter, response](const tensorflow::Status& s) {
statuses[i] = s;
delete response;
counter.DecrementCount();
});
}
counter.Wait();
tensorflow::StatusGroup sg;
for (int i = 0; i < num_remote_workers; i++) {
if (TF_PREDICT_FALSE(!statuses[i].ok())) {
sg.Update(statuses[i]);
}
}
return sg.as_summary_status();
}
tensorflow::Status UpdateRemoteContexts(
TFE_Context* ctx, const std::vector<string>& remote_workers,
const std::vector<string>& added_workers,
const std::vector<string>& removed_workers, tensorflow::uint64 context_id,
tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
tensorflow::eager::EagerClientCache* remote_eager_workers,
const tensorflow::eager::CreateContextRequest& base_request) {
int num_remote_workers = remote_workers.size();
tensorflow::BlockingCounter counter(num_remote_workers);
std::vector<tensorflow::Status> statuses(num_remote_workers);
int cluster_device_count = base_request.cluster_device_attributes_size();
std::unordered_set<string> added_or_removed(added_workers.begin(),
added_workers.end());
std::copy(removed_workers.begin(), removed_workers.end(),
std::inserter(added_or_removed, added_or_removed.end()));
// Whether each device is in the updated (added or removed) workers.
std::vector<bool> device_added_or_removed(cluster_device_count);
for (int i = 0; i < base_request.cluster_device_attributes_size(); i++) {
const auto& da = base_request.cluster_device_attributes().at(i);
tensorflow::DeviceNameUtils::ParsedName pn;
tensorflow::DeviceNameUtils::ParseFullName(da.name(), &pn);
string task_name;
tensorflow::DeviceNameUtils::GetTaskName(pn, &task_name);
if (added_or_removed.find(task_name) != added_or_removed.end()) {
device_added_or_removed[i] = true;
}
}
for (int i = 0; i < num_remote_workers; i++) {
const string& remote_worker = remote_workers[i];
tensorflow::DeviceNameUtils::ParsedName parsed_name;
if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
&parsed_name)) {
statuses[i] = tensorflow::errors::InvalidArgument(
"Unable to parse ", remote_worker, " as a device name");
counter.DecrementCount();
continue;
}
tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
if (eager_client == nullptr) {
statuses[i] = tensorflow::errors::Internal(
"Cannot find a client for the given target:", remote_worker);
}
if (!statuses[i].ok()) {
counter.DecrementCount();
continue;
}
std::vector<bool> filtered_device_mask;
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->FilterDevicesForRemoteWorkers(
remote_worker, base_request.cluster_device_attributes(),
&filtered_device_mask);
DCHECK_EQ(filtered_device_mask.size(), cluster_device_count);
// If any of the devices that match the device filters are in the set of
// added or removed workers, we must send a complete UpdateContextRequest.
// Otherwise, only send a simple request to increment context view ID.
std::vector<bool> added_or_removed_filtered_devices(cluster_device_count);
std::transform(device_added_or_removed.begin(),
device_added_or_removed.end(), filtered_device_mask.begin(),
added_or_removed_filtered_devices.begin(),
std::logical_and<bool>());
const bool full_update_request =
std::accumulate(added_or_removed_filtered_devices.begin(),
added_or_removed_filtered_devices.end(), false,
std::logical_or<bool>());
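// That is, full_update_request is true iff some device i has both
// device_added_or_removed[i] and filtered_device_mask[i] set.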
tensorflow::eager::UpdateContextRequest request;
auto* response = new tensorflow::eager::UpdateContextResponse();
request.set_context_id(context_id);
request.set_context_view_id(context_view_id);
if (full_update_request) {
*request.mutable_server_def() = server_def;
request.mutable_server_def()->set_job_name(parsed_name.job);
request.mutable_server_def()->set_task_index(parsed_name.task);
request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
server_def.default_session_config());
for (int i = 0; i < cluster_device_count; i++) {
if (filtered_device_mask[i]) {
const auto& da = base_request.cluster_device_attributes(i);
*request.add_cluster_device_attributes() = da;
}
}
}
eager_client->UpdateContextAsync(
&request, response,
[i, &statuses, &counter, response](const tensorflow::Status& s) {
statuses[i] = s;
delete response;
counter.DecrementCount();
});
}
counter.Wait();
for (int i = 0; i < num_remote_workers; i++) {
TF_RETURN_IF_ERROR(statuses[i]);
}
return tensorflow::Status::OK();
}
tensorflow::Status UpdateTFE_ContextWithServerDef(
int keep_alive_secs, const tensorflow::ServerDef& server_def,
TFE_Context* ctx, bool reset_context) {
// We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
// server object (which currently CHECK-fails) and we'd miss the error;
// instead, we log the error and then return to allow the user to see the
// error message.
#define LOG_AND_RETURN_IF_ERROR(...) \
do { \
const ::tensorflow::Status _status = (__VA_ARGS__); \
if (TF_PREDICT_FALSE(!_status.ok())) { \
LOG(ERROR) << _status.error_message(); \
return _status; \
} \
} while (0);
string worker_name =
tensorflow::strings::StrCat("/job:", server_def.job_name(),
"/replica:0/task:", server_def.task_index());
// List of current remote workers before updating server_def. Unused if
// resetting the server_def.
std::vector<string> curr_remote_workers;
// List of updated remote workers.
std::vector<string> remote_workers;
// New server created for new server_def. Unused if updating server_def.
std::unique_ptr<tensorflow::ServerInterface> new_server;
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
tensorflow::GrpcServer* grpc_server;
if (reset_context) {
const tensorflow::DeviceMgr* device_mgr =
AreLocalDevicesCompatible(context, server_def)
? context->local_device_mgr()
: nullptr;
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServerWithOptions(
server_def, {device_mgr}, &new_server));
grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
LOG_AND_RETURN_IF_ERROR(
ListRemoteWorkers(new_server.get(), worker_name, &remote_workers));
} else {
LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name,
&curr_remote_workers));
// No need to check the cast here, since `ListRemoteWorkers` already checks
// if the server is a GRPC server or not.
grpc_server = dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
LOG_AND_RETURN_IF_ERROR(
ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
}
tensorflow::uint64 context_id = context->GetContextId();
tensorflow::uint64 context_view_id = context->GetContextViewId();
if (reset_context) {
context_id = tensorflow::EagerContext::NewContextId();
context_view_id = 0;
// Make master eager context accessible by local eager service, which might
// receive send tensor requests from remote workers.
LOG_AND_RETURN_IF_ERROR(
grpc_server->AddMasterEagerContextToEagerService(context_id, context));
}
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
LOG_AND_RETURN_IF_ERROR(
grpc_server->master_env()->worker_cache->GetEagerClientCache(
&remote_eager_workers));
// For cluster update, use a status group to aggregate statuses from
// * adding and removing remote devices
// * creating remote contexts on newly added workers
// * updating remote contexts on existing workers
// * updating the master context
// Note that we should not return immediately on errors in the middle of these
// updates to prevent the cluster from having inconsistent context views.
//
// Unused if `reset_context` is True.
tensorflow::StatusGroup sg;
// When updating an existing context, populate the following lists with:
// * added_workers: set(remote_workers) - set(curr_remote_workers)
// * removed_workers: set(curr_remote_workers) - set(remote_workers)
// * existing_workers: set(curr_remote_workers) intersect set(remote_workers)
// * replaced_workers: workers with the same task names and potentially the
// same `hostname:port`s, but replaced by different processes
std::vector<string> added_workers;
std::vector<string> removed_workers;
std::vector<string> existing_workers;
std::vector<string> replaced_workers;
// New remote device manager created for new server_def. Unused if updating
// server_def.
std::unique_ptr<tensorflow::DynamicDeviceMgr> new_remote_device_mgr;
tensorflow::DynamicDeviceMgr* remote_device_mgr = nullptr;
if (reset_context) {
LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices(
remote_workers, grpc_server->master_env()->worker_cache,
&new_remote_device_mgr));
remote_device_mgr = new_remote_device_mgr.get();
} else {
context->ClearCachesAndDefaultExecutor();
// TODO(b/143914772): Potential memory leak if rendezvous has pending
// tensors for removed / replaced workers.
remote_device_mgr = context->GetOwnedRemoteDeviceMgr();
if (remote_device_mgr == nullptr) {
LOG_AND_RETURN_IF_ERROR(tensorflow::errors::InvalidArgument(
"Updating context with an invalid set of remote devices."));
}
std::sort(curr_remote_workers.begin(), curr_remote_workers.end());
std::sort(remote_workers.begin(), remote_workers.end());
DifferentiateWorkerLists(&curr_remote_workers, &remote_workers,
&added_workers, &removed_workers,
&existing_workers);
sg.Update(GetReplacedFromExistingWorkers(
&existing_workers, context_id, context->GetContextViewId(), server_def,
remote_eager_workers.get(), &replaced_workers));
if (VLOG_IS_ON(1)) {
VLOG(1) << "Updating cluster with following changes";
for (const string& w : added_workers) VLOG(1) << " Added worker " << w;
for (const string& w : removed_workers)
VLOG(1) << " Removed worker " << w;
for (const string& w : replaced_workers)
VLOG(1) << " Replaced worker " << w;
}
if (!replaced_workers.empty()) {
// Treat replaced workers as removed then added back, so that we recreate
// remote devices and contexts, and re-register functions on those workers.
removed_workers.insert(removed_workers.end(), replaced_workers.begin(),
replaced_workers.end());
added_workers.insert(added_workers.end(), replaced_workers.begin(),
replaced_workers.end());
for (const string& w : replaced_workers) {
existing_workers.erase(
std::remove(existing_workers.begin(), existing_workers.end(), w),
existing_workers.end());
}
}
sg.Update(RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr));
sg.Update(AddRemoteDevicesToMgr(added_workers,
grpc_server->master_env()->worker_cache,
remote_device_mgr));
}
std::vector<tensorflow::DeviceAttributes> cluster_device_attributes;
remote_device_mgr->ListDeviceAttributes(&cluster_device_attributes);
std::vector<tensorflow::DeviceAttributes> local_device_attributes;
grpc_server->worker_env()->device_mgr->ListDeviceAttributes(
&local_device_attributes);
// This request makes sure that we can create Rendezvous properly between
// local and remote contexts.
tensorflow::eager::CreateContextRequest base_request;
for (const auto& da : cluster_device_attributes) {
*base_request.add_cluster_device_attributes() = da;
}
for (const auto& da : local_device_attributes) {
*base_request.add_cluster_device_attributes() = da;
}
// Initialize remote eager workers.
if (reset_context) {
const tensorflow::Status s = CreateRemoteContexts(
ctx, remote_workers, context_id, context_view_id, keep_alive_secs,
server_def, remote_eager_workers.get(), context->Executor().Async(),
context->LazyCopyFunctionRemoteInputs(), base_request);
// NOTE: the remote tasks could fail after `GetAllRemoteDevices` and cause
// the CreateRemoteContexts to fail. We currently only log instead of
// directly returning the error, since returning here will cause the server
// object to be destroyed (which currently CHECK-fails). The client will
// see additional errors if ops are subsequently sent to the failed workers.
if (TF_PREDICT_FALSE(!s.ok())) {
LOG(ERROR) << "Error when creating contexts on remote targets: "
<< s.error_message()
<< "\nExecuting remote ops or functions on these remote "
"targets will fail.";
}
} else {
if (sg.ok()) {
// Create remote contexts on the newly added workers only if the master
// has collected all device information from them (i.e., the
// GetAllRemoteDevices call returns successfully). Note that in rare cases
// GetAllRemoteDevices can still fail even with RPCs configured to wait
// until the remote workers become alive. If the master creates remote
// contexts on the workers whose devices are still not collected, those
// workers will be treated as existing workers subsequently, so the master
// will never get devices from them even with retrying UpdateServerDef.
sg.Update(CreateRemoteContexts(
ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs,
server_def, remote_eager_workers.get(), context->Executor().Async(),
context->LazyCopyFunctionRemoteInputs(), base_request));
}
if (!existing_workers.empty()) {
if (VLOG_IS_ON(1)) {
for (const string& w : existing_workers) {
VLOG(1) << "Updating cluster with existing worker " << w;
}
}
// The master's context_view_id will be incremented by one in the
// UpdateRemoteMaster call later. We want existing workers to also have
// the updated context_view_id, so we must set their context_view_id to
// the master's current context_view_id + 1.
sg.Update(UpdateRemoteContexts(ctx, existing_workers, added_workers,
removed_workers, context_id,
context_view_id + 1, server_def,
remote_eager_workers.get(), base_request));
}
}
auto session_name = tensorflow::strings::StrCat("eager_", context_id);
if (reset_context) {
tensorflow::RemoteRendezvous* r =
grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
auto* device_mgr = grpc_server->worker_env()->device_mgr;
std::shared_ptr<tensorflow::WorkerSession> worker_session;
LOG_AND_RETURN_IF_ERROR(
grpc_server->worker_env()->session_mgr->CreateSession(
session_name, server_def, base_request.cluster_device_attributes(),
true));
LOG_AND_RETURN_IF_ERROR(
grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
session_name, &worker_session));
// Initialize remote tensor communication based on worker session.
LOG_AND_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
tensorflow::eager::CreateClusterFLR(context_id, context,
worker_session.get());
auto remote_mgr = absl::make_unique<tensorflow::eager::RemoteMgr>(
/*is_master=*/true, context);
LOG_AND_RETURN_IF_ERROR(context->InitializeRemoteMaster(
std::move(new_server), grpc_server->worker_env(), worker_session,
std::move(remote_eager_workers), std::move(new_remote_device_mgr),
remote_workers, context_id, r, device_mgr, keep_alive_secs, cluster_flr,
std::move(remote_mgr)));
// NOTE: We start the server after all other initialization, because the
// GrpcServer cannot be destroyed after it is started.
LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
} else {
sg.Update(grpc_server->worker_env()->session_mgr->UpdateSession(
session_name, server_def, base_request.cluster_device_attributes(),
/*isolate_session_state=*/true));
sg.Update(context->UpdateRemoteMaster(context_id,
std::move(remote_eager_workers),
added_workers, removed_workers));
LOG_AND_RETURN_IF_ERROR(sg.as_summary_status());
}
#undef LOG_AND_RETURN_IF_ERROR
return tensorflow::Status::OK();
}
#endif // !IS_MOBILE_PLATFORM
} // namespace
extern "C" {
@@ -735,7 +105,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
#else
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
return nullptr;
#endif
#endif // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
}
std::vector<std::unique_ptr<tensorflow::Device>> devices;
status->status = tensorflow::DeviceFactory::AddDevices(
@@ -747,13 +117,18 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr.get());
return tensorflow::wrap(new tensorflow::EagerContext(
tensorflow::EagerContext* eager_context = new tensorflow::EagerContext(
opts->session_options.options,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
opts->device_placement_policy),
opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
/*device_mgr_owned*/ true, r));
/*device_mgr_owned*/ true, r);
#if !defined(IS_MOBILE_PLATFORM)
eager_context->SetDistributedManager(
std::make_unique<tensorflow::EagerContextDistributedManager>(
eager_context));
#endif // !IS_MOBILE_PLATFORM
return tensorflow::wrap(eager_context);
}
void TFE_DeleteContext(TFE_Context* ctx) {
@@ -791,26 +166,9 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
"Invalid tensorflow.ServerDef protocol buffer");
return;
}
if (server_def.has_cluster_device_filters()) {
const auto& cdf = server_def.cluster_device_filters();
for (const auto& jdf : cdf.jobs()) {
const string remote_prefix = "/job:" + jdf.name() + "/task:";
for (const auto& tdf : jdf.tasks()) {
const int32_t task_index = tdf.first;
std::vector<string> device_filters(tdf.second.device_filters_size());
for (int i = 0; i < tdf.second.device_filters_size(); i++) {
device_filters[i] = tdf.second.device_filters(i);
}
const string remote_worker = remote_prefix + std::to_string(task_index);
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status =
context->SetRemoteDeviceFilters(remote_worker, device_filters);
}
}
}
status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
ctx, /*reset_context=*/true);
status->status =
tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef(
server_def, /*reset_context=*/true, keep_alive_secs);
#endif // !IS_MOBILE_PLATFORM
}
@@ -835,14 +193,9 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
status->status = tensorflow::errors::InvalidArgument(
"Trying to update a context with invalid context id.");
}
if (server_def.has_cluster_device_filters()) {
LOG(WARNING) << "Device filters can only be specified when initializing "
"the cluster. Any changes in device filters are ignored "
"when updating the server def.";
}
// TODO(haoyuzhang): Check server_def compatibility before the update
status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
ctx, /*reset_context=*/false);
status->status =
tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef(
server_def, /*reset_context=*/false, keep_alive_secs);
#endif // !IS_MOBILE_PLATFORM
}
@@ -854,44 +207,11 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
"TFE_ContextSetServerDef not supported on mobile");
return false;
#else // !defined(IS_MOBILE_PLATFORM)
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
tensorflow::GrpcServer* grpc_server =
dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
if (grpc_server == nullptr) {
status->status =
tensorflow::errors::Internal("Failed to get tensorflow::GrpcServer.");
return false;
}
tensorflow::WorkerInterface* wi =
grpc_server->master_env()->worker_cache->GetOrCreateWorker(worker_name);
if (wi == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"Unable to find worker interface corresponding to task ", worker_name);
return false;
}
tensorflow::GetStatusRequest request;
tensorflow::GetStatusResponse response;
tensorflow::Status remote_status;
tensorflow::Notification done;
wi->GetStatusAsync(/*opts_=*/nullptr, &request, &response, /*fail_fast=*/true,
[&remote_status, &done](const tensorflow::Status& s) {
remote_status = s;
done.Notify();
});
done.WaitForNotification();
// We set OK status so the call does not raise any exceptions. Instead, the
// caller uses the return value to tell if the remote worker is alive.
status->status = tensorflow::Status::OK();
if (remote_status.ok()) {
return true;
}
LOG(INFO) << "Remote worker " << worker_name
<< " is not alive: " << remote_status.error_message();
return false;
bool is_alive;
status->status =
tensorflow::unwrap(ctx)->GetDistributedManager()->CheckRemoteAlive(
worker_name, &is_alive);
return is_alive;
#endif // !IS_MOBILE_PLATFORM
}

View File

@@ -134,7 +134,9 @@ TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
}
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
TF_DataType dtype, TF_Status* s) {
TF_DataType dtype, TF_Shape shape,
TF_Status* s) {
DCHECK_GE(shape.num_dims, -1);
TracingTensorHandle* t;
TracingContext* tracing_ctx = dyn_cast<TracingContext>(unwrap(func));
if (!tracing_ctx) {
@@ -143,8 +145,20 @@ TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
"TF_AddFunctionParameter must be called on a TracingContext."));
return nullptr;
}
tensorflow::PartialTensorShape partial_shape;
if (shape.num_dims != -1) {
DCHECK(shape.dim_sizes != nullptr);
Status status = tensorflow::PartialTensorShape::MakePartialShape(
reinterpret_cast<tensorflow::int64*>(shape.dim_sizes), shape.num_dims,
&partial_shape);
if (!status.ok()) {
Set_TF_Status_from_Status(s, status);
return nullptr;
}
}
Set_TF_Status_from_Status(
s, tracing_ctx->AddParameter(static_cast<DataType>(dtype), &t));
s, tracing_ctx->AddParameter(static_cast<DataType>(dtype), partial_shape,
&t));
return wrap(t);
}

View File

@@ -64,10 +64,16 @@ TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions*,
TF_Status* s);
void TF_DeleteExecutionContext(TF_ExecutionContext*);
// Represents a (partially-defined) shape.
typedef struct TF_Shape {
int num_dims; // Must be >= -1; -1 represents unknown rank.
int64_t* dim_sizes;
} TF_Shape;
// Add a new parameter to a TensorFlow Function.
// TODO(aminim): what about shape?
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
TF_DataType dtype, TF_Status* s);
TF_DataType dtype, TF_Shape shape,
TF_Status* s);
// Create an operation suitable to use with the provided context. The operation
// requires its type (e.g. "AddV2") to be set independently.
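A sketch of a call under the new signature (hypothetical `func` and `s`; per the TF_Shape convention above, passing {-1, nullptr} declares an unknown rank, as the updated tests do):

  int64_t dims[] = {2, -1};  // Rank 2; -1 marks an unknown dimension size.
  TF_Shape shape = {/*num_dims=*/2, /*dim_sizes=*/dims};
  TF_AbstractTensor* param =
      TF_AddFunctionParameter(func, TF_FLOAT, shape, s);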

View File

@@ -25,6 +25,8 @@ limitations under the License.
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
@@ -43,22 +45,50 @@ class GraphContext;
class GraphOperation;
class GraphTensor;
auto& kUnknownDim = shape_inference::InferenceContext::kUnknownDim;
auto& kUnknownRank = shape_inference::InferenceContext::kUnknownRank;
// GraphTensor wraps a `TF_Output`, i.e. a pointer to TF_Operation and the index
// into the list of outputs for the operation.
class GraphTensor : public TracingTensorHandle {
public:
explicit GraphTensor(TF_Output output)
: TracingTensorHandle(kGraph), output_(output) {}
explicit GraphTensor(TF_Output output, TF_Graph* graph)
: TracingTensorHandle(kGraph), output_(output), graph_(graph) {}
tensorflow::DataType DataType() const override {
return static_cast<tensorflow::DataType>(TF_OperationOutputType(output_));
}
tensorflow::Status Shape(
tensorflow::PartialTensorShape* shape) const override {
DCHECK(shape != nullptr);
TF_Status status;
int num_dims = TF_GraphGetTensorNumDims(graph_, output_, &status);
DCHECK_GE(num_dims, -1);
TF_RETURN_IF_ERROR(StatusFromTF_Status(&status));
if (num_dims == kUnknownRank) {
return Status::OK();
}
std::vector<int64> dims(num_dims, kUnknownDim);
TF_GraphGetTensorShape(graph_, output_,
reinterpret_cast<int64_t*>(dims.data()), num_dims,
&status);
TF_RETURN_IF_ERROR(StatusFromTF_Status(&status));
TF_RETURN_IF_ERROR(tensorflow::TensorShapeUtils::MakeShape(dims, shape));
return Status::OK();
}
TF_Output output_;
// For LLVM style RTTI.
static bool classof(const AbstractTensorHandle* ptr) {
return ptr->getKind() == kGraph;
}
private:
TF_Graph* graph_; // For shape inference.
};
// GraphOperation wraps and populates a TF_OperationDescription.
@@ -135,7 +165,7 @@ class GraphOperation : public TracingOperation {
TF_DeleteStatus(s);
*num_retvals = TF_OperationNumOutputs(operation);
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = new GraphTensor({operation, i});
retvals[i] = new GraphTensor({operation, i}, g_);
}
return Status::OK();
}
@@ -326,12 +356,18 @@ class GraphContext : public TracingContext {
return new GraphOperation(graph_.get());
}
Status AddParameter(DataType dtype, TracingTensorHandle** output) override {
Status AddParameter(DataType dtype, const PartialTensorShape& shape,
TracingTensorHandle** output) override {
TracingOperationPtr operation(CreateOperation());
TF_RETURN_IF_ERROR(operation->Reset("Placeholder", nullptr));
TF_RETURN_IF_ERROR(
operation->SetOpName(absl::StrCat("_input_", inputs_.size()).c_str()));
TF_RETURN_IF_ERROR(operation->SetAttrType("dtype", dtype));
if (!shape.unknown_rank()) {
TF_RETURN_IF_ERROR(operation->SetAttrShape(
"shape", reinterpret_cast<int64_t*>(shape.dim_sizes().data()),
shape.dims()));
}
int num_outputs = 1;
std::vector<AbstractTensorHandle*> outputs(num_outputs);
TF_RETURN_IF_ERROR(operation->Execute(

View File

@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/types.h"
@@ -107,7 +108,8 @@ class TracingContext : public AbstractContext {
public:
// Add a function parameter and return the corresponding tensor.
virtual Status AddParameter(DataType dtype, TracingTensorHandle**) = 0;
virtual Status AddParameter(DataType dtype, const PartialTensorShape& shape,
TracingTensorHandle**) = 0;
// Finalize this context and make a function out of it. The context is in an
// invalid state after this call and must be destroyed.

View File

@@ -359,7 +359,7 @@ TEST_P(UnifiedCAPI, TestBasicGraph) {
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
auto* placeholder_t =
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract operation.
@@ -450,7 +450,7 @@ TEST_P(UnifiedCAPI, TestBasicGraphMatMul) {
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
auto* placeholder_t =
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract operation.
@@ -553,9 +553,9 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Create a first "Add" computing `arg0 + arg1`.
@@ -709,9 +709,9 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraphMatMul) {
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Create a first "Add" computing `arg0 + arg1`.
@@ -975,7 +975,7 @@ TEST_P(UnifiedCAPI, TF_AbstractTensorGetEagerTensorOnGraphTensorRaises) {
// Add a placeholder to the graph.
auto placeholder_t =
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, status.get());
TF_AbstractTensorGetEagerTensor(placeholder_t, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));

View File

@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/unified_api_testutil.h"
#include "tensorflow/c/experimental/gradients/array_grad.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
@@ -65,6 +66,8 @@ Status RegisterGradients(GradientRegistry* registry) {
TF_RETURN_IF_ERROR(registry->Register("Neg", NegRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Sub", SubRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Mul", MulRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Log1p", Log1pRegisterer));
TF_RETURN_IF_ERROR(registry->Register("DivNoNan", DivNoNanRegisterer));
return Status::OK();
}
@@ -73,8 +76,10 @@ Status RegisterGradients(GradientRegistry* registry) {
// return grad(y, {inputs[0], inputs[1]})
Status AddGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
TapeVSpace vspace(ctx);
auto tape = std::make_unique<Tape>(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
@@ -107,8 +112,10 @@ Status AddGradModel(AbstractContext* ctx,
// return grad(y, {inputs[0]})
Status ExpGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
TapeVSpace vspace(ctx);
auto tape = std::make_unique<Tape>(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
@@ -137,8 +144,10 @@ Status ExpGradModel(AbstractContext* ctx,
// return grad(y, {inputs[0]})
Status SqrtGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
TapeVSpace vspace(ctx);
auto tape = std::make_unique<Tape>(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
@@ -168,8 +177,10 @@ Status SqrtGradModel(AbstractContext* ctx,
// This should return [nullptr, 1].
Status IdentityNGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
TapeVSpace vspace(ctx);
auto tape = std::make_unique<Tape>(/*persistent=*/false);
tape->Watch(ToId(inputs[0]));
@@ -202,8 +213,10 @@ Status IdentityNGradModel(AbstractContext* ctx,
// return grad(y, {inputs[0]})
Status NegGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
TapeVSpace vspace(ctx);
auto tape = std::make_unique<Tape>(/*persistent=*/false);
tape->Watch(ToId(inputs[0]));
@@ -233,8 +246,10 @@ Status NegGradModel(AbstractContext* ctx,
// return grad(y, {inputs[0], inputs[1]})
Status SubGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
TapeVSpace vspace(ctx);
auto tape = std::make_unique<Tape>(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
@@ -267,8 +282,10 @@ Status SubGradModel(AbstractContext* ctx,
// return grad(y, {inputs[0], inputs[1]})
Status MulGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
@@ -297,122 +314,72 @@ Status MulGradModel(AbstractContext* ctx,
return Status::OK();
}
AbstractContext* BuildFunction(const char* fn_name) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get());
return unwrap(graph_ctx);
}
// Computes
// y = log(1 + inputs[0])
// return grad(y, {inputs[0]})
Status Log1pGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
std::vector<AbstractTensorHandle*> log1p_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::Log1p(tape_ctx.get(), inputs,
absl::MakeSpan(log1p_outputs),
"Log1p")); // Compute log(1 + x).
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
Status CreateParamsForInputs(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
std::vector<AbstractTensorHandle*>* params) {
tracing::TracingTensorHandle* handle = nullptr;
for (auto input : inputs) {
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
input->DataType(), &handle));
params->emplace_back(handle);
std::vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(log1p_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
for (auto log1p_output : log1p_outputs) {
log1p_output->Unref();
}
outputs[0] = out_grads[0];
delete tape;
return Status::OK();
}
using Model = std::function<Status(
AbstractContext*, absl::Span<AbstractTensorHandle* const>,
absl::Span<AbstractTensorHandle*>, const GradientRegistry&)>;
// Computes
// y = inputs[0] / inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status DivNoNanGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
tape->Watch(ToId(inputs[1])); // Watch y.
std::vector<AbstractTensorHandle*> div_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::DivNoNan(tape_ctx.get(), inputs,
absl::MakeSpan(div_outputs),
"DivNoNan")); // Compute x / y.
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
// Runs `model` maybe wrapped in a function.
Status RunModel(Model model, AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, bool use_function,
const GradientRegistry& registry) {
if (use_function) {
const char* fn_name = "test_fn";
std::unique_ptr<AbstractFunction> scoped_func;
// Returning null tensors from a tf.function is not supported, so we keep
// track of the indices in the model's outputs that are nullptr in this
// set. The FunctionDef only outputs the non-null tensors. We later pad the
// function op outputs to have nullptrs at the `null_indices`.
absl::flat_hash_set<int> null_indices;
{
AbstractContextPtr func_ctx(BuildFunction(fn_name));
std::vector<AbstractTensorHandle*> func_inputs;
func_inputs.reserve(inputs.size());
TF_RETURN_IF_ERROR(
CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs));
vector<AbstractTensorHandle*> model_outputs;
model_outputs.resize(outputs.size());
TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
absl::MakeSpan(model_outputs), registry));
for (auto func_input : func_inputs) {
func_input->Unref();
}
AbstractFunction* func = nullptr;
OutputList output_list;
output_list.expected_num_outputs = 0;
output_list.outputs.reserve(outputs.size());
for (int i = 0; i < model_outputs.size(); i++) {
if (model_outputs[i]) {
output_list.outputs.emplace_back(model_outputs[i]);
output_list.expected_num_outputs += 1;
} else {
null_indices.insert(i);
}
}
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
->Finalize(&output_list, &func));
scoped_func.reset(func);
for (auto output : output_list.outputs) {
output->Unref();
}
TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
}
AbstractOperationPtr fn_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr));
for (auto input : inputs) {
TF_RETURN_IF_ERROR(fn_op->AddInput(input));
}
int retvals = outputs.size() - null_indices.size();
vector<AbstractTensorHandle*> fn_outputs(retvals);
TF_RETURN_IF_ERROR(fn_op->Execute(
absl::Span<AbstractTensorHandle*>(fn_outputs.data(), fn_outputs.size()),
&retvals));
int skipped_indices = 0;
for (int i = 0; i < outputs.size(); i++) {
if (!null_indices.contains(i)) {
outputs[i] = fn_outputs[i - skipped_indices];
} else {
skipped_indices += 1;
}
}
TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name));
return Status::OK();
} else {
return model(ctx, inputs, outputs, registry);
std::vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(div_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
for (auto div_output : div_outputs) {
div_output->Unref();
}
}
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
*ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get()));
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_DeleteContextOptions(opts);
return Status::OK();
}
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value);
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
outputs[0] = out_grads[0];
outputs[1] = out_grads[1];
delete tape;
return Status::OK();
}
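// A minimal sketch of a numeric cross-check for the analytic DivNoNan
// gradients above, using central finite differences on f(x, y) = x / y.
// Plain C++ with no TF dependencies; the helper name is illustrative.
#include <cassert>
#include <cmath>
static void CheckDivGradientsNumerically() {
  auto f = [](double x, double y) { return x / y; };
  const double x = 1.0, y = 2.0, eps = 1e-6;
  // Central differences approximate the partial derivatives.
  const double dx = (f(x + eps, y) - f(x - eps, y)) / (2 * eps);  // ~ 1/y = 0.5
  const double dy = (f(x, y + eps) - f(x, y - eps)) / (2 * eps);  // ~ -x/y^2 = -0.25
  assert(std::fabs(dx - 0.5) < 1e-4);
  assert(std::fabs(dy + 0.25) < 1e-4);
}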
@ -467,7 +434,7 @@ TEST_P(CppGradients, TestAddGrad) {
std::vector<AbstractTensorHandle*> outputs(2);
s = RunModel(AddGradModel, ctx.get(), {x.get(), y.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
/*use_function=*/!std::get<2>(GetParam()));
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
@ -507,18 +474,15 @@ TEST_P(CppGradients, TestExpGrad) {
x.reset(x_raw);
}
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Pseudo-code:
//
// tape.watch(x)
// y = exp(x)
// outputs = tape.gradient(y, x)
std::vector<AbstractTensorHandle*> outputs(1);
s = RunModel(ExpGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
Status s =
RunModel(ExpGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()));
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
@ -551,18 +515,15 @@ TEST_P(CppGradients, TestSqrtGrad) {
x.reset(x_raw);
}
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Pseudo-code:
//
// tape.watch(x)
// y = sqrt(x)
// outputs = tape.gradient(y, x)
std::vector<AbstractTensorHandle*> outputs(1);
s = RunModel(SqrtGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
Status s =
RunModel(SqrtGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()));
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
@ -620,7 +581,7 @@ TEST_P(CppGradients, TestIdentityNGrad) {
std::vector<AbstractTensorHandle*> outputs(2);
s = RunModel(IdentityNGradModel, ctx.get(), {x1.get(), x2.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
/*use_function=*/!std::get<2>(GetParam()));
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
EXPECT_EQ(outputs[0], nullptr);
@ -665,7 +626,7 @@ TEST_P(CppGradients, TestNegGrad) {
// outputs = tape.gradient(y, x)
std::vector<AbstractTensorHandle*> outputs(1);
s = RunModel(NegGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
/*use_function=*/!std::get<2>(GetParam()));
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
@ -706,10 +667,6 @@ TEST_P(CppGradients, TestSubGrad) {
y.reset(y_raw);
}
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Pseudo-code:
//
// tape.watch(x)
@ -717,9 +674,9 @@ TEST_P(CppGradients, TestSubGrad) {
// y = x - y
// outputs = tape.gradient(y, [x, y])
std::vector<AbstractTensorHandle*> outputs(2);
s = RunModel(SubGradModel, ctx.get(), {x.get(), y.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
Status s = RunModel(SubGradModel, ctx.get(), {x.get(), y.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()));
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
@ -767,10 +724,6 @@ TEST_P(CppGradients, TestMulGrad) {
y.reset(y_raw);
}
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Pseudo-code:
//
// tape.watch(x)
@ -778,9 +731,9 @@ TEST_P(CppGradients, TestMulGrad) {
// y = x * y
// outputs = tape.gradient(y, [x, y])
std::vector<AbstractTensorHandle*> outputs(2);
s = RunModel(MulGradModel, ctx.get(), {x.get(), y.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
Status s = RunModel(MulGradModel, ctx.get(), {x.get(), y.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()));
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
@ -800,6 +753,104 @@ TEST_P(CppGradients, TestMulGrad) {
TF_DeleteTensor(result_tensor);
}
TEST_P(CppGradients, TestLog1pGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
// Pseudo-code:
//
// tape.watch(x)
// y = log(1 + x)
// outputs = tape.gradient(y, x)
std::vector<AbstractTensorHandle*> outputs(1);
Status s =
RunModel(Log1pGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()));
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
s = getValue(outputs[0], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_NEAR(*result_value, 0.5, 0.001);
outputs[0]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
}
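// The expected value follows from d/dx log(1 + x) = 1 / (1 + x), which is 0.5
// at x = 1. A standalone numeric sketch of the same check (plain C++, no TF
// dependencies; the helper name is illustrative):
#include <cassert>
#include <cmath>
static void CheckLog1pGradientNumerically() {
  const double x = 1.0, eps = 1e-6;
  const double dx = (std::log1p(x + eps) - std::log1p(x - eps)) / (2 * eps);
  assert(std::fabs(dx - 0.5) < 1e-4);  // Matches the EXPECT_NEAR above.
}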
TEST_P(CppGradients, TestDivNoNanGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
AbstractTensorHandlePtr y;
{
AbstractTensorHandle* y_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &y_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
y.reset(y_raw);
}
// Pseudo-code:
//
// tape.watch(x)
// tape.watch(y)
// y = x / y
// outputs = tape.gradient(y, [x, y])
std::vector<AbstractTensorHandle*> outputs(2);
Status s = RunModel(DivNoNanGradModel, ctx.get(), {x.get(), y.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()));
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
s = getValue(outputs[0], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_NEAR(*result_value, 0.5, 0.001);
outputs[0]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
s = getValue(outputs[1], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_NEAR(*result_value, -0.25, 0.001);
outputs[1]->Unref();
TF_DeleteTensor(result_tensor);
}
TEST_P(CppGradients, TestSetAttrString) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
@ -224,8 +225,10 @@ Status CreateParamsForInputs(AbstractContext* ctx,
vector<AbstractTensorHandle*>* params) {
tracing::TracingTensorHandle* handle = nullptr;
for (auto input : inputs) {
PartialTensorShape shape;
TF_RETURN_IF_ERROR(input->Shape(&shape));
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
input->DataType(), &handle));
input->DataType(), shape, &handle));
params->emplace_back(handle);
}
return Status::OK();
@ -314,4 +317,4 @@ Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
}
} // namespace gradients
} // namespace tensorflow
} // namespace tensorflow

View File

@ -21,12 +21,15 @@ limitations under the License.
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/immediate_execution_distributed_manager.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/protobuf/config.pb.h"
@ -138,8 +141,8 @@ class ImmediateExecutionContext : public AbstractContext {
}
//===--------------------------------------------------------------------===//
// Following are legacy features in TF Eager Runtime.
// TODO(tf-runtime): Figure out a way to deprecate following features after
// The following are features in the current TF Eager Runtime.
// TODO(tfrt-devs): Figure out a way to deprecate the following features after
// migrating to TFRT.
//===--------------------------------------------------------------------===//
// Clear pending nodes in thread executors and kernel caches.
@ -157,6 +160,34 @@ class ImmediateExecutionContext : public AbstractContext {
// Update the Eager Executor for current thread.
virtual void SetExecutorForThread(EagerExecutor* executor) = 0;
//===--------------------------------------------------------------------===//
// The following are helper functions to assist in integrating TFRT with the
// current TF eager runtime.
// TODO(b/172877902): These helper functions are currently used to support
// PyFuncOp on TFRT, and might be useful for ops that directly use low-level
// TF APIs. Remove/replace the following functions when TFRT native ops are
// implemented.
//===--------------------------------------------------------------------===//
// Create an abstract tensor handle from tensorflow::Tensor.
virtual ImmediateExecutionTensorHandle* CreateLocalHandleFromTFTensor(
tensorflow::Tensor& t, const char* d_name) = 0;
// Convert a TFRT TensorHandle to tensorflow::TensorHandle.
virtual ImmediateExecutionTensorHandle* TFTensorHandleFromInterface(
ImmediateExecutionTensorHandle* handle) = 0;
//===--------------------------------------------------------------------===//
// Distributed runtime related functions.
//===--------------------------------------------------------------------===//
#if !defined(IS_MOBILE_PLATFORM)
// Set a distributed manager that helps set up, update, and check liveness
// of member tasks in the cluster.
virtual void SetDistributedManager(
std::unique_ptr<ImmediateExecutionDistributedManager> distributed) = 0;
virtual ImmediateExecutionDistributedManager* GetDistributedManager() = 0;
#endif // !IS_MOBILE_PLATFORM
protected:
explicit ImmediateExecutionContext(AbstractContextKind kind)
: AbstractContext(kind) {}

View File

@ -0,0 +1,45 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_DISTRIBUTED_MANAGER_H_
#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_DISTRIBUTED_MANAGER_H_
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
class ImmediateExecutionContext;
class ServerDef;
class ImmediateExecutionDistributedManager {
public:
virtual ~ImmediateExecutionDistributedManager() {}
// Set up distributed execution environment on local and remote tasks.
// When `reset_context` is true, initialize new cluster context state based on
// cluster configurations provided in `server_def`; otherwise, update existing
// context state with the provided `server_def`.
// Contexts created on remote tasks will be considered stale and garbage
// collected after `keep_alive_secs` of inactivity.
virtual Status SetOrUpdateServerDef(const ServerDef& server_def,
bool reset_context,
int keep_alive_secs) = 0;
// Check if the remote task is alive.
virtual Status CheckRemoteAlive(const std::string& remote_task_name,
bool* is_alive) = 0;
};
} // namespace tensorflow
#endif  // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_DISTRIBUTED_MANAGER_H_
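// A minimal sketch of a conforming implementation, e.g. for single-task
// tests, assuming only this header; the class name `NoOpDistributedManager`
// is illustrative and not part of TensorFlow.
namespace tensorflow {
class NoOpDistributedManager : public ImmediateExecutionDistributedManager {
 public:
  Status SetOrUpdateServerDef(const ServerDef& server_def, bool reset_context,
                              int keep_alive_secs) override {
    return Status::OK();  // Nothing to configure without remote tasks.
  }
  Status CheckRemoteAlive(const std::string& remote_task_name,
                          bool* is_alive) override {
    *is_alive = true;  // No remote tasks, so report alive unconditionally.
    return Status::OK();
  }
};
}  // namespace tensorflow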

View File

@ -76,6 +76,7 @@ cc_library(
"//tensorflow/c:c_api",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",

View File

@ -328,6 +328,17 @@ ParallelDevice::Execute(TFE_Context* context,
const char* operation_name,
const TFE_OpAttrs* attributes, int expected_max_outputs,
TF_Status* status) const {
std::vector<PartialTensorShape> expected_output_shapes(expected_max_outputs);
return Execute(context, inputs, operation_name, attributes,
expected_output_shapes, status);
}
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::Execute(
TFE_Context* context, const std::vector<ParallelTensor*>& inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
const std::vector<PartialTensorShape>& expected_output_shapes,
TF_Status* status) const {
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
// Compute per-device per-output tensors
std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
@ -344,7 +355,7 @@ ParallelDevice::Execute(TFE_Context* context,
}
device_thread->StartExecute(context, operation_name,
std::move(device_inputs), attributes,
expected_max_outputs);
expected_output_shapes.size());
}
StatusPtr first_bad_status(nullptr);
for (int device_index = 0; device_index < underlying_devices_.size();
@ -386,8 +397,15 @@ ParallelDevice::Execute(TFE_Context* context,
for (int j = 0; j < underlying_devices_.size(); ++j) {
components.push_back(std::move(per_device_output_tensors[j][i]));
}
per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
*this, std::move(components), status));
if (expected_output_shapes[i].IsFullyDefined()) {
per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
*this, std::move(components),
absl::Span<const int64>(expected_output_shapes[i].dim_sizes()),
status));
} else {
per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
*this, std::move(components), status));
}
if (TF_GetCode(status) != TF_OK) return result;
}
result.emplace(std::move(per_device_outputs));
@ -396,9 +414,27 @@ ParallelDevice::Execute(TFE_Context* context,
std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status) {
std::vector<TensorHandlePtr> components, absl::Span<const int64> shape,
TF_Status* status) {
TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
std::vector<int64_t> shape(
// Verify that the TensorHandle's dtype matches all of the component dtypes;
// the provided shape is used without additional checks.
for (TensorHandlePtr& component : components) {
if (TFE_TensorHandleDataType(component.get()) != dtype) {
TF_SetStatus(status, TF_INTERNAL,
"Components of a ParallelTensor must all have "
"the same dtype");
return nullptr;
}
}
return std::unique_ptr<ParallelTensor>(
new ParallelTensor(parallel_device, std::move(components), shape, dtype));
}
std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status) {
std::vector<int64> shape(
TFE_TensorHandleNumDims(components[0].get(), status));
if (TF_GetCode(status) != TF_OK) return nullptr;
for (int i = 0; i < shape.size(); ++i) {
@ -406,11 +442,10 @@ std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
if (TF_GetCode(status) != TF_OK) return nullptr;
}
// Verify that the TensorHandle's shape and dtype match all of the component
// shapes and dtypes.
// Verify that the TensorHandle's shape matches all of the component shapes.
for (TensorHandlePtr& component : components) {
for (int i = 0; i < shape.size(); ++i) {
int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
int64 tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
if (tensor_dim != shape[i]) {
// TODO(allenl): Allow shapes to differ.
@ -419,17 +454,10 @@ std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
"the same shape");
return nullptr;
}
if (TFE_TensorHandleDataType(component.get()) != dtype) {
TF_SetStatus(status, TF_INTERNAL,
"Components of a ParallelTensor must all have "
"the same dtype");
return nullptr;
}
}
}
return std::unique_ptr<ParallelTensor>(new ParallelTensor(
parallel_device, std::move(components), std::move(shape), dtype));
return FromTensorHandles(parallel_device, std::move(components),
absl::Span<const int64>(shape), status);
}
} // namespace parallel_device

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow {
namespace parallel_device {
@ -93,6 +94,15 @@ class ParallelDevice {
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const;
// Accepts inferred shapes for outputs which, if fully defined, avoid
// querying the shapes of the underlying TensorHandles. This allows async
// computation to continue without blocking.
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Execute(
TFE_Context* context, const std::vector<ParallelTensor*>& inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
const std::vector<PartialTensorShape>& expected_output_shapes,
TF_Status* status) const;
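// Illustrative call shape assuming a single scalar output (see the
// TestExplicitOutputShape test elsewhere in this change for a complete
// example):
//   auto outputs = parallel_device.Execute(
//       context, inputs, "VarHandleOp", attributes,
//       {PartialTensorShape({})}, status);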
private:
// A sequence of device names, indicating which devices replicated operations
// are forwarded to.
@ -117,10 +127,15 @@ class ParallelDevice {
class ParallelTensor {
public:
// Construct a ParallelTensor from TensorHandles placed on the component
// devices of a ParallelDevice.
// devices of a ParallelDevice. Inspects `components` to determine a shape.
static std::unique_ptr<ParallelTensor> FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status);
// Uses the provided shape without additional checks, which avoids blocking.
static std::unique_ptr<ParallelTensor> FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, absl::Span<const int64> shape,
TF_Status* status);
size_t num_tensors() const { return tensors_.size(); }
TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }
@ -132,10 +147,10 @@ class ParallelTensor {
private:
ParallelTensor(const ParallelDevice& device,
std::vector<TensorHandlePtr> tensors,
std::vector<int64_t> shape, const TF_DataType dtype)
absl::Span<const int64> shape, const TF_DataType dtype)
: device_(device),
tensors_(std::move(tensors)),
shape_(std::move(shape)),
shape_(shape.begin(), shape.end()),
dtype_(dtype) {}
const ParallelDevice& device_;

View File

@ -80,5 +80,41 @@ TEST(PARALLEL_DEVICE_LIB, TestOpWithError) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
TEST(PARALLEL_DEVICE_LIB, TestExplicitOutputShape) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
TF_CreateConfig(
/*xla*/ false,
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
2),
TF_DeleteBuffer);
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
status.get());
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::vector<std::string> devices{
"/job:localhost/replica:0/task:0/device:CPU:0",
"/job:localhost/replica:0/task:0/device:CPU:1"};
ParallelDevice parallel_device(std::move(devices));
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> handle_op(
TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetAttrType(handle_op.get(), "dtype", TF_FLOAT);
TFE_OpSetAttrShape(handle_op.get(), "shape", /*dims=*/nullptr, /*num_dims=*/0,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
auto outputs = parallel_device.Execute(
context.get(), std::vector<ParallelTensor*>(), "VarHandleOp",
TFE_OpGetAttrs(handle_op.get()), {PartialTensorShape({})}, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const std::vector<std::unique_ptr<ParallelTensor>>& handles = *outputs;
EXPECT_EQ(0, handles[0]->shape().size());
}
} // namespace parallel_device
} // namespace tensorflow

View File

@ -0,0 +1,205 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/unified_api_testutil.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
class UnifiedAPI
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
protected:
void SetUp() override {
TF_StatusPtr status(TF_NewStatus());
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
Status s = StatusFromTF_Status(status.get());
CHECK_EQ(errors::OK, s.code()) << s.error_message();
}
public:
bool UseMlir() const { return strcmp(std::get<0>(GetParam()), "mlir") == 0; }
bool UseFunction() const { return std::get<2>(GetParam()); }
};
// Checks that inputs[0] is a scalar.
Status TestScalarShape(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
PartialTensorShape shape;
TF_RETURN_IF_ERROR(inputs[0]->Shape(&shape));
if (shape.dims() != 0) {
return errors::InvalidArgument(
"Tensor expected to have a scalar shape, found rank: ", shape.dims());
}
return Status::OK();
}
TEST_P(UnifiedAPI, TestTensorShapeScalar) {
if (UseFunction() && UseMlir()) {
// TODO(b/173074167): Remove this.
GTEST_SKIP() << "MlirTensor::Shape is not implemented yet.";
}
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
Status s = RunModel(TestScalarShape, ctx.get(),
/*inputs=*/{x.get()},
/*outputs=*/{},
/*use_function=*/UseFunction());
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
}
// Checks that inputs[0] is a matrix with shape 2x4.
Status TestTensorShape2x4(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
PartialTensorShape shape;
TF_RETURN_IF_ERROR(inputs[0]->Shape(&shape));
if (shape.dims() != 2) {
return errors::InvalidArgument(
"Tensor expected to have rank 2, found rank: ", shape.dims());
}
int64 dim_sizes[] = {2, 4};
for (int i = 0; i < shape.dims(); i++) {
if (shape.dim_size(i) != dim_sizes[i]) {
return errors::InvalidArgument("Dim ", i, " expected to be of size ",
dim_sizes[i],
" found: ", shape.dim_size(i));
}
}
return Status::OK();
}
TEST_P(UnifiedAPI, TestTensorShape2x4) {
if (UseFunction() && UseMlir()) {
// TODO(b/173074167): Remove this.
GTEST_SKIP() << "MlirTensor::Shape is not implemented yet.";
}
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
float data[] = {0., 0., 0., 0., 0., 0., 0., 0.};
int64 dim_sizes[] = {2, 4};
Status s =
TestTensorHandleWithDimsFloat(ctx.get(), data, dim_sizes, 2, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
Status s = RunModel(TestTensorShape2x4, ctx.get(),
/*inputs=*/{x.get()},
/*outputs=*/{},
/*use_function=*/UseFunction());
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
}
TEST_P(UnifiedAPI, TestUnknownShapeTracing) {
if (!UseFunction()) {
GTEST_SKIP() << "Tracing only test.";
}
if (UseMlir()) {
// TODO(b/173074167): Remove this.
GTEST_SKIP() << "MlirTensor::Shape is not implemented yet.";
}
AbstractContextPtr ctx(BuildFunction("test_fn"));
AbstractTensorHandlePtr x;
{
tracing::TracingTensorHandle* x_raw = nullptr;
PartialTensorShape shape;
Status s = dyn_cast<tracing::TracingContext>(ctx.get())->AddParameter(
DT_FLOAT, shape, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
PartialTensorShape shape;
Status s = x->Shape(&shape);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ASSERT_TRUE(shape.unknown_rank());
}
TEST_P(UnifiedAPI, TestPartialShapeTracing) {
if (!UseFunction()) {
GTEST_SKIP() << "Tracing only test.";
}
if (UseMlir()) {
GTEST_SKIP() << "MlirTensor::Shape is not implemented yet.";
}
AbstractContextPtr ctx(BuildFunction("test_fn"));
AbstractTensorHandlePtr x;
{
tracing::TracingTensorHandle* x_raw = nullptr;
PartialTensorShape shape;
int64 dim_sizes[] = {2, -1};
Status s = PartialTensorShape::MakePartialShape(dim_sizes, 2, &shape);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
s = dyn_cast<tracing::TracingContext>(ctx.get())->AddParameter(
DT_FLOAT, shape, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
PartialTensorShape shape;
Status s = x->Shape(&shape);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ASSERT_FALSE(shape.unknown_rank());
ASSERT_EQ(2, shape.dim_size(0));
ASSERT_EQ(-1, shape.dim_size(1));
}
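// In a PartialTensorShape, a dim_size of -1 denotes an unknown dimension,
// which is why the assertions above expect 2 for the first axis and -1 for
// the second.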
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
UnifiedCppAPI, UnifiedAPI,
::testing::Combine(::testing::Values("graphdef", "mlir"),
/*tfrt*/ ::testing::Values(true, false),
/*use_function*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
UnifiedCppAPI, UnifiedAPI,
::testing::Combine(::testing::Values("graphdef", "mlir"),
/*tfrt*/ ::testing::Values(false),
/*use_function*/ ::testing::Values(true, false)));
#endif
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,161 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/unified_api_testutil.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
AbstractContext* BuildFunction(const char* fn_name) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get());
return unwrap(graph_ctx);
}
Status CreateParamsForInputs(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
std::vector<AbstractTensorHandle*>* params) {
tracing::TracingTensorHandle* handle = nullptr;
for (auto input : inputs) {
PartialTensorShape shape;
TF_RETURN_IF_ERROR(input->Shape(&shape));
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
input->DataType(), shape, &handle));
params->emplace_back(handle);
}
return Status::OK();
}
// Runs `model` maybe wrapped in a function.
Status RunModel(Model model, AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, bool use_function) {
if (use_function) {
const char* fn_name = "test_fn";
std::unique_ptr<AbstractFunction> scoped_func;
// Returning null tensors from a tf.function is not supported, so we keep
// track of the indices in the model's outputs that are nullptr in this
// set. The FunctionDef only outputs the non-null tensors. We later pad the
// function op outputs to have nullptrs at the `null_indices`.
absl::flat_hash_set<int> null_indices;
{
AbstractContextPtr func_ctx(BuildFunction(fn_name));
std::vector<AbstractTensorHandle*> func_inputs;
func_inputs.reserve(inputs.size());
TF_RETURN_IF_ERROR(
CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs));
std::vector<AbstractTensorHandle*> model_outputs;
model_outputs.resize(outputs.size());
TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
absl::MakeSpan(model_outputs)));
for (auto func_input : func_inputs) {
func_input->Unref();
}
AbstractFunction* func = nullptr;
OutputList output_list;
output_list.expected_num_outputs = 0;
output_list.outputs.reserve(outputs.size());
for (int i = 0; i < model_outputs.size(); i++) {
if (model_outputs[i]) {
output_list.outputs.emplace_back(model_outputs[i]);
output_list.expected_num_outputs += 1;
} else {
null_indices.insert(i);
}
}
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
->Finalize(&output_list, &func));
scoped_func.reset(func);
for (auto output : output_list.outputs) {
output->Unref();
}
TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
}
AbstractOperationPtr fn_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr));
for (auto input : inputs) {
TF_RETURN_IF_ERROR(fn_op->AddInput(input));
}
int retvals = outputs.size() - null_indices.size();
std::vector<AbstractTensorHandle*> fn_outputs(retvals);
TF_RETURN_IF_ERROR(fn_op->Execute(
absl::Span<AbstractTensorHandle*>(fn_outputs.data(), fn_outputs.size()),
&retvals));
int skipped_indices = 0;
for (int i = 0; i < outputs.size(); i++) {
if (!null_indices.contains(i)) {
outputs[i] = fn_outputs[i - skipped_indices];
} else {
skipped_indices += 1;
}
}
TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name));
return Status::OK();
} else {
return model(ctx, inputs, outputs);
}
}
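// A concrete illustration of the null-output padding above, assuming a
// hypothetical model whose outputs are {t0, nullptr, t1}: the traced function
// returns only {t0, t1}, `null_indices` becomes {1}, and the final loop
// re-expands the results to {t0, nullptr, t1}.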
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
*ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get()));
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_DeleteContextOptions(opts);
return Status::OK();
}
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value);
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return Status::OK();
}
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float* data,
int64* dims, int num_dims,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager = TestTensorHandleWithDimsFloat(
eager_ctx, data, reinterpret_cast<int64_t*>(dims), num_dims);
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return Status::OK();
}
} // namespace tensorflow

View File

@ -0,0 +1,61 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_UNIFIED_API_TESTUTIL_H_
#define TENSORFLOW_C_EAGER_UNIFIED_API_TESTUTIL_H_
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
// Builds and returns a `TracingContext` using the default tracing impl.
AbstractContext* BuildFunction(const char* fn_name);
// Creates parameters (placeholders) in the tracing `ctx` using the shape and
// dtype of `inputs`.
Status CreateParamsForInputs(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
std::vector<AbstractTensorHandle*>* params);
// A callable that takes tensor inputs and returns zero or more tensor outputs.
using Model = std::function<Status(AbstractContext*,
absl::Span<AbstractTensorHandle* const>,
absl::Span<AbstractTensorHandle*>)>;
// Runs `model` maybe wrapped in a function call op. This can be thought of as
// being equivalent to the following Python code.
//
// if use_function:
// outputs = tf.function(model)(inputs)
// else:
// outputs = model(inputs)
Status RunModel(Model model, AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, bool use_function);
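// Illustrative usage, assuming a `Model` named `MyModel` with one output
// (hypothetical names, mirroring the Python pseudo-code above):
//   std::vector<AbstractTensorHandle*> outputs(1);
//   TF_RETURN_IF_ERROR(RunModel(MyModel, ctx, /*inputs=*/{x},
//                               absl::MakeSpan(outputs),
//                               /*use_function=*/true));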
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);
// Get a scalar TensorHandle with the given float value.
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
AbstractTensorHandle** tensor);
// Get a matrix TensorHandle with the given float values and dimensions.
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float* data,
int64* dims, int num_dims,
AbstractTensorHandle** tensor);
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_UNIFIED_API_TESTUTIL_H_

View File

@ -21,10 +21,14 @@ limitations under the License.
#include "tensorflow/c/experimental/ops/nn_ops.h"
using std::vector;
using tensorflow::ops::Add;
using tensorflow::ops::Conj;
using tensorflow::ops::Div;
using tensorflow::ops::DivNoNan;
using tensorflow::ops::MatMul;
using tensorflow::ops::Mul;
using tensorflow::ops::Neg;
using tensorflow::ops::OnesLike;
using tensorflow::ops::SqrtGrad;
namespace tensorflow {
@ -289,6 +293,117 @@ class MulGradientFunction : public GradientFunction {
vector<AbstractTensorHandle*> forward_inputs;
};
class Log1pGradientFunction : public GradientFunction {
public:
explicit Log1pGradientFunction(vector<AbstractTensorHandle*> f_inputs)
: forward_inputs(f_inputs) {}
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
// TODO(vnvo2409): Add control dependency
/* Given upstream grad U and a Log1p op: Y = log(1 + X), the gradients are:
*
* dX = U / (1 + X)
*
*/
AbstractTensorHandle* upstream_grad = grad_inputs[0];
AbstractTensorHandle* X = forward_inputs[0];
grad_outputs->resize(1);
vector<AbstractTensorHandle*> temp_outputs(1);
// Calculate conjugate of X
std::string name = "Conj_Log1p_Grad_X";
TF_RETURN_IF_ERROR(
Conj(ctx->ctx, {X}, absl::MakeSpan(temp_outputs), name.c_str()));
AbstractTensorHandle* Conj_X = temp_outputs[0];
// Create a tensor of ones shaped like Conj_X
name = "OnesLike_Log1p_Grad_X";
TF_RETURN_IF_ERROR(OnesLike(ctx->ctx, {Conj_X},
absl::MakeSpan(temp_outputs), name.c_str()));
AbstractTensorHandle* Ones_X = temp_outputs[0];
name = "Add_Log1p_Grad_X";
// Calculate 1 + Conj(X)
TF_RETURN_IF_ERROR(Add(ctx->ctx, {Ones_X, Conj_X},
absl::MakeSpan(temp_outputs), name.c_str()));
AbstractTensorHandle* Conj_XP1 = temp_outputs[0];
name = "Div_Log1p_Grad_X";
// Calculate U / (1 + Conj(X))
TF_RETURN_IF_ERROR(Div(ctx->ctx, {upstream_grad, Conj_XP1},
absl::MakeSpan(temp_outputs), name.c_str()));
(*grad_outputs)[0] = temp_outputs[0];
return Status::OK();
}
~Log1pGradientFunction() override {}
private:
vector<AbstractTensorHandle*> forward_inputs;
};
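// The op chain above computes U / (1 + conj(X)) one primitive at a time; for
// real inputs Conj reduces to Identity, so the result matches dX = U / (1 + X).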
class DivNoNanGradientFunction : public GradientFunction {
public:
explicit DivNoNanGradientFunction(vector<AbstractTensorHandle*> f_inputs,
vector<AbstractTensorHandle*> f_outputs)
: forward_inputs(f_inputs), forward_outputs(f_outputs) {}
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
// TODO(vnvo2409): Add shape broadcasting
/* Given upstream grad U and a Div op: Z = X/Y, the gradients are:
*
* dX = U / Y
* dY = -U*X / Y^2 = (X/Y) * -U / Y = -U*Z / Y
*
*/
AbstractTensorHandle* upstream_grad = grad_inputs[0];
AbstractTensorHandle* Y = forward_inputs[1];
AbstractTensorHandle* Z = forward_outputs[0];
grad_outputs->resize(2);
vector<AbstractTensorHandle*> temp_outputs(1);
// Calculate dX = U / Y
std::string name = "Div_Grad_X";
TF_RETURN_IF_ERROR(DivNoNan(ctx->ctx, {upstream_grad, Y},
absl::MakeSpan(temp_outputs), name.c_str()));
(*grad_outputs)[0] = temp_outputs[0];
// Calculate dY = -U*Z / Y
name = "Neg_Div_Grad_Y";
TF_RETURN_IF_ERROR(Neg(ctx->ctx, {upstream_grad},
absl::MakeSpan(temp_outputs), name.c_str())); // -U
AbstractTensorHandle* MinusU = temp_outputs[0];
name = "Mul_Div_Grad_Y";
TF_RETURN_IF_ERROR(Mul(ctx->ctx, {MinusU, Z}, absl::MakeSpan(temp_outputs),
name.c_str())); // -U*Z
AbstractTensorHandle* UZ = temp_outputs[0];
name = "Div_Grad_Y";
TF_RETURN_IF_ERROR(DivNoNan(ctx->ctx, {UZ, Y}, absl::MakeSpan(temp_outputs),
name.c_str())); // -U*Z / Y
(*grad_outputs)[1] = temp_outputs[0];
return Status::OK();
}
~DivNoNanGradientFunction() override {}
private:
vector<AbstractTensorHandle*> forward_inputs;
vector<AbstractTensorHandle*> forward_outputs;
};
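// Note on the dY form used above: substituting the cached forward output
// Z = X / Y into dY = -U * X / Y^2 gives dY = -U * Z / Y, so the backward
// pass reuses Z instead of recomputing X / Y.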
} // namespace
BackwardFunction* AddRegisterer(const ForwardOperation& op) {
@ -354,5 +469,23 @@ BackwardFunction* MulRegisterer(const ForwardOperation& op) {
return new BackwardFunction(gradient_function, default_gradients);
}
BackwardFunction* Log1pRegisterer(const ForwardOperation& op) {
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient, so we do not need to worry about creating zero
// grads in this case.
auto gradient_function = new Log1pGradientFunction(op.inputs);
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}
BackwardFunction* DivNoNanRegisterer(const ForwardOperation& op) {
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient, so we do not need to worry about creating zero
// grads in this case.
auto gradient_function = new DivNoNanGradientFunction(op.inputs, op.outputs);
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}
} // namespace gradients
} // namespace tensorflow

View File

@ -27,6 +27,8 @@ BackwardFunction* SqrtRegisterer(const ForwardOperation& op);
BackwardFunction* NegRegisterer(const ForwardOperation& op);
BackwardFunction* SubRegisterer(const ForwardOperation& op);
BackwardFunction* MulRegisterer(const ForwardOperation& op);
BackwardFunction* Log1pRegisterer(const ForwardOperation& op);
BackwardFunction* DivNoNanRegisterer(const ForwardOperation& op);
} // namespace gradients
} // namespace tensorflow

View File

@ -81,5 +81,17 @@ Status ExpandDims(AbstractContext* ctx,
return op->Execute(outputs, &num_retvals);
}
Status OnesLike(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(op->Reset("OnesLike", /*raw_device_name=*/nullptr));
TF_RETURN_IF_ERROR(MaybeSetOpName(op.get(), name));
TF_RETURN_IF_ERROR(op->AddInput(inputs[0]));
int num_retvals = 1;
return op->Execute(outputs, &num_retvals);
}
} // namespace ops
} // namespace tensorflow

View File

@ -42,6 +42,10 @@ Status ExpandDims(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status OnesLike(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
} // namespace ops
} // namespace tensorflow

View File

@ -44,8 +44,18 @@ Status Conj(AbstractContext* ctx,
if (DataTypeIsFloating(BaseType(dtype)) ||
DataTypeIsInteger(BaseType(dtype))) {
TF_RETURN_IF_ERROR(Identity(ctx, inputs, outputs, name));
} else if (DataTypeIsComplex(BaseType(dtype)) ||
BaseType(dtype) == DT_VARIANT) {
AbstractOperationPtr conj_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(conj_op->Reset("Conj", /*raw_device_name=*/nullptr));
TF_RETURN_IF_ERROR(MaybeSetOpName(conj_op.get(), name));
TF_RETURN_IF_ERROR(conj_op->AddInput(inputs[0]));
int num_retvals = 1;
TF_RETURN_IF_ERROR(conj_op->Execute(outputs, &num_retvals));
} else {
return errors::Unimplemented("Conj does not support complex types yet.");
return errors::InvalidArgument(
"Expected numeric or variant tensor, got dtype ", dtype);
}
return Status::OK();
}
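// Rationale for the Identity fast path above: conj(x) == x for real and
// integer inputs, so only complex (and variant-wrapped) tensors need the
// actual Conj kernel.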
@ -118,6 +128,19 @@ Status Sum(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
return Status::OK();
}
Status Div(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr div_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(div_op->Reset("Div", /*raw_device_name=*/nullptr));
TF_RETURN_IF_ERROR(MaybeSetOpName(div_op.get(), name));
TF_RETURN_IF_ERROR(div_op->AddInput(inputs[0])); // x
TF_RETURN_IF_ERROR(div_op->AddInput(inputs[1])); // y
int num_retvals = 1;
TF_RETURN_IF_ERROR(div_op->Execute(outputs, &num_retvals)); // z = x / y
return Status::OK();
}
Status DivNoNan(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
@ -172,5 +195,18 @@ Status SqrtGrad(AbstractContext* ctx,
return s;
}
Status Log1p(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr log1p_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(log1p_op->Reset("Log1p", /*raw_device_name=*/nullptr));
TF_RETURN_IF_ERROR(MaybeSetOpName(log1p_op.get(), name));
TF_RETURN_IF_ERROR(log1p_op->AddInput(inputs[0]));
int num_retvals = 1;
Status s = log1p_op->Execute(outputs, &num_retvals);
return s;
}
} // namespace ops
} // namespace tensorflow

View File

@ -44,6 +44,9 @@ Status Sum(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
Status Sub(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status Div(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status DivNoNan(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
@ -59,6 +62,10 @@ Status SqrtGrad(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status Log1p(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
} // namespace ops
} // namespace tensorflow

View File

@ -115,8 +115,8 @@ genrule(
# have control of the full GPU.
cmd = "CUDA_VISIBLE_DEVICES='' " +
"$(location :make_test_graphs) --out_dir $(@D)",
exec_tools = [":make_test_graphs"],
tags = ["manual"],
tools = [":make_test_graphs"],
)
tf_library(

View File

@ -127,7 +127,7 @@ def tf_library(
"$(location " + tfcompile_tool + ")" +
" --config=$(location " + config + ")" +
" --dump_fetch_nodes > $@"),
exec_tools = [tfcompile_tool],
tools = [tfcompile_tool],
# Run tfcompile on the build host, rather than forge, since it's
# typically way faster on the local machine.
local = 1,
@ -162,7 +162,7 @@ def tf_library(
"//tensorflow/python/tools:freeze_graph)" +
freeze_args
),
exec_tools = ["//tensorflow/python/tools:freeze_graph"],
tools = ["//tensorflow/python/tools:freeze_graph"],
tags = tags,
)
tfcompile_graph = freeze_file
@ -242,7 +242,7 @@ def tf_library(
" --out_function_object=$(@D)/" + function_object_file +
" " + flags + " " + profiling_flag + " " + mlir_flag + " " + traceme_flag
),
exec_tools = [tfcompile_tool],
tools = [tfcompile_tool],
visibility = visibility,
testonly = testonly,
# Run tfcompile on the build host since it's typically faster on the
@ -281,7 +281,7 @@ def tf_library(
" --out_session_module=$(@D)/" + session_module_pb +
" " + flags
),
exec_tools = [tfcompile_tool],
tools = [tfcompile_tool],
visibility = visibility,
testonly = testonly,
local = 1,

View File

@ -508,7 +508,7 @@ cc_library(
":flags",
":jit_compilation_passes",
"//tensorflow/compiler/jit/kernels:xla_ops_no_jit_rewrite_registration",
"//tensorflow/compiler/mlir:mlir_bridge_rollout_policy",
"//tensorflow/compiler/tf2xla:mlir_bridge_pass",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla:xla_op_registry",
"//tensorflow/core:core_cpu_internal",

View File

@ -115,7 +115,7 @@ xla::StatusOr<std::string> GetCompilerIr(
xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
XlaComputationLaunchContext::BuildXlaCompilerArguments(
constant_arg_indices, inputs, variable_infos);
constant_arg_indices, inputs, variable_infos, dev);
TF_RETURN_IF_ERROR(args.status());
switch (stage) {

View File

@ -206,8 +206,9 @@ static Status CompileToLocalExecutable(
may_alias_resource_update;
xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
XlaComputationLaunchContext::BuildXlaCompilerArguments(constants, inputs,
variable_infos);
XlaComputationLaunchContext::BuildXlaCompilerArguments(
constants, inputs, variable_infos,
static_cast<Device*>(ctx->device()));
TF_RETURN_IF_ERROR(args.status());
return cache->Compile(options, function, *args, compile_options,
lazy ? XlaCompilationCache::CompileMode::kLazy
@ -246,8 +247,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
VLOG(1) << "Executing XLA Computation...";
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
se::DeviceMemoryAllocator* allocator = GetAllocator(
&tf_allocator_adapter, ctx->device(),

View File

@ -1990,6 +1990,8 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
"StatelessCase",
"StatelessIf",
"StatelessMultinomial",
"StatelessRandomGetAlg",
"StatelessRandomGetKeyCounter",
"StatelessRandomGetKeyCounterAlg",
"StatelessRandomNormal",
"StatelessRandomNormalV2",
@ -2040,6 +2042,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
"UnsortedSegmentSum",
"VarIsInitializedOp",
"VariableShape",
"Where",
"While",
"XlaBroadcastHelper",
"XlaConv",

View File

@ -140,6 +140,7 @@ XlaCompilationCache::BuildSignature(
for (const XlaCompiler::Argument& arg : args) {
switch (arg.kind) {
case XlaCompiler::Argument::kConstant:
case XlaCompiler::Argument::kConstantResource:
signature.arg_values.push_back(arg.constant_value);
break;
case XlaCompiler::Argument::kParameter:
@ -288,7 +289,7 @@ Status XlaCompilationCache::CompileSingleOp(
const ConfigProto* config = ctx->function_library()->config_proto();
// TODO(b/171039585): Support tf.VarIsInitializedOp using MLIR.
bool use_mlir = config &&
GetMlirBridgeRolloutPolicy(*config) ==
GetMlirBridgeRolloutPolicy(*graph, *config) ==
MlirBridgeRolloutPolicy::kEnabledByUser &&
node_def.op() != "VarIsInitializedOp";
#ifdef LIBTPU_ON_GCE

View File

@ -153,7 +153,8 @@ Status XlaCompileOnDemandOp::Compile(
ctx, variables_indices, variable_infos, variable_args));
args = XlaComputationLaunchContext::BuildXlaCompilerArguments(
constant_input_indices, inputs, variable_infos);
constant_input_indices, inputs, variable_infos,
static_cast<Device*>(ctx->device()));
TF_RETURN_IF_ERROR(args.status());
}

View File

@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_kernel_creator.h"
#include "tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
@ -23,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/mlir_bridge_pass.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/node_def_builder.h"
@ -89,10 +89,21 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
// Make sure that kernels have been registered on the JIT device.
XlaOpRegistry::RegisterCompilationKernels();
// Get function body, constant args, and resource args.
NameAttrList function;
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
const FunctionBody* fbody = nullptr;
std::vector<int> constant_arg_indices;
std::vector<int> resource_arg_indices;
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));
// Only check for compilability if the MLIR bridge is not enabled.
MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy(absl::nullopt);
if (policy == MlirBridgeRolloutPolicy::kDisabledByUser ||
policy == MlirBridgeRolloutPolicy::kDisabledAfterGraphAnalysis) {
absl::optional<ConfigProto> config_proto;
if (flr->config_proto()) {
config_proto = *flr->config_proto();
}
if (!IsMlirBridgePassEnabled(*fbody->graph, config_proto)) {
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
@ -121,15 +132,6 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
}
}
// Get function body, constant args, and resource args.
NameAttrList function;
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
const FunctionBody* fbody = nullptr;
std::vector<int> constant_arg_indices;
std::vector<int> resource_arg_indices;
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));
MemoryTypeVector input_memory_types =
GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices);
MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody);
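
The hunk above copies the possibly-null `flr->config_proto()` pointer into an `absl::optional` before consulting the bridge pass. A small illustrative sketch of that pointer-to-optional pattern, using `std::optional` and a stand-in `ConfigProto`:

```cpp
#include <optional>
#include <string>

struct ConfigProto { std::string serialized; };

// Turn a possibly-null borrowed pointer into an owned optional copy, so
// later code can pass the config by value without lifetime concerns.
std::optional<ConfigProto> CopyConfig(const ConfigProto* raw) {
  std::optional<ConfigProto> config;
  if (raw != nullptr) config = *raw;
  return config;
}
```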

View File

@ -449,15 +449,14 @@ Status XlaComputationLaunchContext::PopulateOutputs(
auto transfer_manager,
xla::TransferManager::GetForPlatform(stream->parent()->platform()));
xla::Shape output_host_shape = output.on_host_shape();
xla::Shape output_device_shape = output.on_device_shape();
TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
stream, &output, &output_host_shape, &output_device_shape));
stream, &output, &output_device_shape));
output.set_shapes(output_host_shape, output_device_shape);
output.set_shapes(output_device_shape, output_device_shape);
for (int i = 0; i < ctx->num_outputs(); ++i) {
const xla::Shape& subshape =
xla::ShapeUtil::GetSubshape(output_host_shape, {i});
xla::ShapeUtil::GetSubshape(output_device_shape, {i});
TensorShape shape;
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
output_tensor_shapes.push_back(shape);
@ -564,11 +563,26 @@ xla::StatusOr<std::vector<XlaCompiler::Argument>>
XlaComputationLaunchContext::BuildXlaCompilerArguments(
absl::Span<int const> must_be_constant_idxs,
absl::Span<const Tensor* const> inputs,
absl::Span<VariableInfo const> variable_args) {
absl::Span<VariableInfo const> variable_args, Device* device) {
CHECK(absl::c_is_sorted(must_be_constant_idxs));
std::vector<XlaCompiler::Argument> out;
out.resize(inputs.size());
// TODO(cheshire): Avoid duplication with framework/op_kernel.h
DeviceContext* device_context = nullptr;
TF_RETURN_IF_ERROR(device->TryGetDeviceContext(&device_context));
bool using_default_context = false;
auto cleanup = xla::MakeCleanup([&] {
if (device_context != nullptr && !using_default_context) {
device_context->Unref();
}
});
if (device_context == nullptr) {
using_default_context = true;
auto* dev_info = device->tensorflow_gpu_device_info();
if (dev_info) device_context = dev_info->default_context;
}
absl::flat_hash_map<int, const VariableInfo*> variable_info_lookup;
for (const VariableInfo& info : variable_args) {
CHECK(!info.var() || info.lock_held())
@ -581,18 +595,7 @@ XlaComputationLaunchContext::BuildXlaCompilerArguments(
const Tensor* input = inputs[input_num];
XlaCompiler::Argument& arg = out[input_num];
if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
// Handles compile-time constants.
// TODO(b/157241314): Support constants located in resource variables.
TF_RET_CHECK(input->dtype() != DT_RESOURCE)
<< "tf2xla bridge does not support must-be-constants located in "
"resource variables; try moving them to a tensor";
arg.kind = XlaCompiler::Argument::kConstant;
arg.type = input->dtype();
arg.shape = input->shape();
arg.constant_value = *input;
} else if (variable_info_lookup.count(input_num)) {
if (variable_info_lookup.count(input_num)) {
// Handles resource variables.
TF_RET_CHECK(input->dtype() == DT_RESOURCE);
const VariableInfo& variable = *variable_info_lookup[input_num];
@ -613,6 +616,25 @@ XlaComputationLaunchContext::BuildXlaCompilerArguments(
arg.type = DT_INVALID;
arg.shape = TensorShape();
}
if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
TF_RET_CHECK(variable.var() && variable.var()->is_initialized);
const Tensor* value = variable.var()->tensor();
Tensor value_on_host(value->dtype(), value->shape());
if (!device_context) {
value_on_host = *value;
} else {
TF_RETURN_IF_ERROR(device_context->CopyDeviceTensorToCPUSync(
value, "", device, &value_on_host));
}
arg.kind = XlaCompiler::Argument::kConstantResource;
arg.constant_value = value_on_host;
}
} else if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
arg.kind = XlaCompiler::Argument::kConstant;
arg.type = input->dtype();
arg.shape = input->shape();
arg.constant_value = *input;
} else {
// Normal inputs.
TF_RET_CHECK(input->dtype() != DT_RESOURCE);
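
The rewritten loop above classifies each kernel input into one of four argument kinds; must-be-constant inputs backed by resource variables are now copied to the host and passed as constant resources rather than rejected. A rough stand-alone sketch of the classification order (toy types, not the real `XlaCompiler` API):

```cpp
enum class ArgKind { kConstant, kConstantResource, kResource, kParameter };

// Resource inputs are inspected first; a resource that must also be a
// compile-time constant becomes a "constant resource", whose current value
// is copied device-to-host before compilation.
ArgKind Classify(bool is_resource, bool must_be_constant) {
  if (is_resource)
    return must_be_constant ? ArgKind::kConstantResource : ArgKind::kResource;
  return must_be_constant ? ArgKind::kConstant : ArgKind::kParameter;
}
```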

View File

@ -143,7 +143,8 @@ class XlaComputationLaunchContext {
static xla::StatusOr<std::vector<XlaCompiler::Argument>>
BuildXlaCompilerArguments(absl::Span<int const> must_be_constant_idxs,
absl::Span<const Tensor* const> inputs,
absl::Span<VariableInfo const> variable_args);
absl::Span<VariableInfo const> variable_args,
Device* device);
// Add all inputs within `ctx` as XLA arguments (returned by arguments()).
// `variables` is a map from TensorFlow argument number to resource variable.

View File

@ -3,7 +3,11 @@
load("//tensorflow:tensorflow.bzl", "filegroup")
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_binary",
"tf_cc_test",
)
package(
default_visibility = [
@ -126,12 +130,14 @@ cc_library(
srcs = ["mlir_graph_optimization_pass.cc"],
hdrs = ["mlir_graph_optimization_pass.h"],
deps = [
"//tensorflow/compiler/mlir:mlir_bridge_rollout_policy",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
"//tensorflow/compiler/mlir/tensorflow:device_util",
"//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/core:core_cpu",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_set",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
@ -198,11 +204,22 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/jit:flags",
"//tensorflow/core:graph",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/types:optional",
],
)
tf_cc_test(
name = "mlir_graph_optimization_pass_test",
srcs = ["mlir_graph_optimization_pass_test.cc"],
deps = [
":mlir_graph_optimization_pass",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
filegroup(
name = "litfiles",
srcs = glob(["runlit*py"]),

View File

@ -51,10 +51,10 @@ filegroup(
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td",
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SideEffectTdFiles",
"@llvm-project//mlir:include/mlir/Interfaces/CopyOpInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
"@llvm-project//mlir:include/mlir/Interfaces/ViewLikeInterface.td",
],
)
@ -464,7 +464,6 @@ cc_library(
":hlo",
":lhlo",
":lhlo_gpu",
"@llvm-project//mlir:IR",
],
)
@ -500,7 +499,6 @@ cc_library(
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
)
@ -639,12 +637,10 @@ cc_library(
":lhlo",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:ViewLikeInterface",
],
@ -668,6 +664,7 @@ cc_library(
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:ShapeTransforms",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:StandardOpsTransforms",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
@ -702,12 +699,10 @@ cc_library(
deps = [
":cycle_detector",
":hlo",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
],
alwayslink = 1,
@ -738,7 +733,6 @@ cc_library(
deps = [
":hlo",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
@ -759,7 +753,6 @@ cc_library(
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
],
alwayslink = 1,
@ -777,8 +770,6 @@ cc_library(
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
alwayslink = 1,
@ -797,7 +788,6 @@ cc_library(
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
alwayslink = 1,
@ -838,11 +828,9 @@ cc_library(
":hlo",
":lower_complex_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
alwayslink = 1,

View File

@ -43,18 +43,6 @@ list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules")
option(MHLO_BUILD_EMBEDDED "Build MHLO as part of another project" OFF)
#-------------------------------------------------------------------------------
# MSVC defaults
#-------------------------------------------------------------------------------
if(MSVC)
add_compile_options(
$<$<CONFIG:>:/MD>
$<$<CONFIG:Debug>:/MD>
$<$<CONFIG:Release>:/MD>
)
endif()
#-------------------------------------------------------------------------------
# MLIR/LLVM Configuration
#-------------------------------------------------------------------------------

View File

@ -925,7 +925,7 @@ def HLO_CustomCallOp: HLO_Op<"custom_call", []>, BASE_HLO_CustomCallOp {
DefaultValuedAttr<BoolAttr, "false">:$has_side_effect,
DefaultValuedAttr<StrAttr, "">:$backend_config
);
let results = (outs HLO_Tensor);
let results = (outs Variadic<HLO_Tensor>);
let hasCustomHLOConverter = 1;
}

View File

@ -264,10 +264,11 @@ def LHLO_WhileOp: LHLO_Op<"while", [SameVariadicOperandSize]>,
let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
}
def LHLO_CustomCallOp : LHLO_Op<"custom_call", []>, BASE_HLO_CustomCallOp {
def LHLO_CustomCallOp : LHLO_Op<"custom_call", [AttrSizedOperandSegments]>,
BASE_HLO_CustomCallOp {
let arguments = (ins
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$args,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$output,
StrAttr:$call_target_name,
DefaultValuedAttr<BoolAttr, "false">:$has_side_effect,
DefaultValuedAttr<StrAttr, "">:$backend_config

View File

@ -1268,7 +1268,8 @@ class DynamicReshapeOpNotActuallyDynamic
void DynamicReshapeOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<DynamicReshapeOpNotActuallyDynamic>(context);
results.insert<DynamicReshapeOpNotActuallyDynamic, ShapeOfDynamicReshape>(
context);
}
//===----------------------------------------------------------------------===//

View File

@ -28,3 +28,6 @@ def DynamicBroadcastToOwnShape_2 : Pat<
(HLO_DynamicBroadcastInDimOp:$op $x, (Shape_ShapeOfOp $x), $attr),
(replaceWithValue $x)>;
def ShapeOfDynamicReshape : Pat<
(Shape_ShapeOfOp (HLO_DynamicReshapeOp $x, $shape)),
(replaceWithValue $shape)>;
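
The new pattern reads: the shape of a `mhlo.dynamic_reshape` is, by construction, its `$shape` operand, so the `shape.shape_of` can be folded away. A toy C++ rewrite illustrating the same idea (hypothetical expression tree, for illustration only):

```cpp
#include <memory>
#include <string>
#include <vector>

// Toy expression tree showing the rewrite
//   shape_of(dynamic_reshape(x, shape)) -> shape
struct Expr {
  std::string op;
  std::vector<std::shared_ptr<Expr>> operands;
};

std::shared_ptr<Expr> Simplify(const std::shared_ptr<Expr>& e) {
  if (e->op == "shape_of" && e->operands.size() == 1 &&
      e->operands[0]->op == "dynamic_reshape" &&
      e->operands[0]->operands.size() == 2) {
    // The reshape's second operand *is* the result's shape.
    return e->operands[0]->operands[1];
  }
  return e;
}
```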

View File

@ -13,6 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Enable the use of M_* math constants.
// NOTE: this must be first in the file to ensure that if cmath is transitively
// included by any other header it has the define set on first processing.
// https://docs.microsoft.com/en-us/cpp/c-runtime-library/math-constants
#define _USE_MATH_DEFINES
#include <cmath>
#include <numeric>
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
@ -87,6 +88,32 @@ Value InsertAlloc(Location loc, OpResult result,
return alloc;
}
/// Converts the results of the operation `op` to memref types and appends them
/// to the `results` vector.
LogicalResult ConvertResults(Operation* op, SmallVectorImpl<Value>& results,
ConversionPatternRewriter& rewriter) {
for (auto result : llvm::enumerate(op->getResults())) {
RankedTensorType resultType =
result.value().getType().dyn_cast<RankedTensorType>();
if (!resultType) return failure();
if (resultType.hasStaticShape()) {
results.push_back(InsertAlloc(op->getLoc(), result.value(), &rewriter));
continue;
}
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
if (!shape_type_op) return failure();
SmallVector<Value, 1> results_shape;
auto status = shape_type_op.reifyReturnTypeShapes(rewriter, results_shape);
if (failed(status)) return failure();
results.push_back(
InsertDynamicAllocAndDealloc(op->getLoc(), result.value(),
results_shape[result.index()], &rewriter));
}
return success();
}
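
`ConvertResults` factors out the per-result buffer allocation that the converters below previously duplicated. A self-contained sketch of its control flow (stand-in types; the real code allocates memrefs via `InsertAlloc`/`InsertDynamicAllocAndDealloc`):

```cpp
#include <optional>
#include <vector>

struct Buffer { int id; };
struct Result {
  bool has_static_shape;
  std::optional<std::vector<int>> reified_shape;  // set when reifiable
};

// Mirrors ConvertResults above: static shapes get a plain allocation;
// dynamic shapes need a reified shape first, and failure to reify (or a
// non-ranked result) fails the whole match.
bool ConvertResults(const std::vector<Result>& results,
                    std::vector<Buffer>* buffers) {
  int next_id = 0;
  for (const Result& r : results) {
    if (r.has_static_shape) {
      buffers->push_back(Buffer{next_id++});  // InsertAlloc analogue
      continue;
    }
    if (!r.reified_shape) return false;
    buffers->push_back(Buffer{next_id++});    // dynamic alloc analogue
  }
  return true;
}
```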
template <typename HloOpTy>
class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
public:
@ -95,29 +122,8 @@ class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
HloOpTy hloOp, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
Operation* op = hloOp.getOperation();
const auto& original_results = op->getResults();
SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
for (auto result : llvm::enumerate(original_results)) {
RankedTensorType resultType =
result.value().getType().dyn_cast<RankedTensorType>();
if (!resultType) {
return failure();
}
if (resultType.hasStaticShape()) {
buffer_args.push_back(
InsertAlloc(op->getLoc(), result.value(), &rewriter));
} else {
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
if (!shape_type_op) return failure();
SmallVector<Value, 1> results_shape;
auto status =
shape_type_op.reifyReturnTypeShapes(rewriter, results_shape);
if (failed(status)) return failure();
buffer_args.push_back(InsertDynamicAllocAndDealloc(
op->getLoc(), result.value(), results_shape.front(), &rewriter));
}
}
if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
buffer_args, op->getAttrs());
rewriter.replaceOp(
@ -139,28 +145,8 @@ class HloToLhloOpConverter<mhlo::DotOp> : public BaseOpConversion<mhlo::DotOp> {
mhlo::DotOp hloOp, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
Operation* op = hloOp.getOperation();
const auto& original_results = op->getResults();
SmallVector<Value, 2> buffer_args(operands.begin(), operands.end());
for (auto result : llvm::enumerate(original_results)) {
RankedTensorType resultType =
result.value().getType().dyn_cast<RankedTensorType>();
if (!resultType) {
return failure();
}
if (resultType.hasStaticShape()) {
buffer_args.push_back(
InsertAlloc(op->getLoc(), result.value(), &rewriter));
} else {
SmallVector<Value, 1> results_shape;
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
if (!shape_type_op) return failure();
if (failed(
shape_type_op.reifyReturnTypeShapes(rewriter, results_shape)))
return failure();
buffer_args.push_back(InsertDynamicAllocAndDealloc(
op->getLoc(), result.value(), results_shape.front(), &rewriter));
}
}
if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
// TODO(silvasean): Move this helper to MLIR core.
auto make_elements_attr = [&rewriter](ArrayRef<int64_t> integers) {
@ -180,6 +166,32 @@ class HloToLhloOpConverter<mhlo::DotOp> : public BaseOpConversion<mhlo::DotOp> {
}
};
struct HloToLhloCustomCallOpConverter
: public BaseOpConversion<mhlo::CustomCallOp> {
public:
using BaseOpConversion<mhlo::CustomCallOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
mhlo::CustomCallOp hloOp, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
Operation* op = hloOp.getOperation();
SmallVector<Value, 2> buffer_args(operands.begin(), operands.end());
if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
auto lhloOp = rewriter.create<lmhlo::CustomCallOp>(
op->getLoc(), llvm::None, buffer_args, op->getAttrs());
// Set up the AttrSizedOperandSegments attribute to indicate the number of
// operands for args and outputs.
const int32_t segments[2] = {static_cast<int32_t>(operands.size()),
static_cast<int32_t>(op->getNumResults())};
lhloOp.setAttr(lhloOp.getOperandSegmentSizeAttr(),
rewriter.getI32VectorAttr(segments));
rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
return success();
}
};
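
Since `lmhlo.custom_call` now takes variadic args and outputs in one flat operand list, the converter must record how that list is partitioned. A tiny illustration of the segment bookkeeping:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // One flat operand list [args..., outputs...] is partitioned by the
  // AttrSizedOperandSegments attribute into {num_args, num_outputs}.
  const std::int32_t num_args = 2, num_outputs = 1;
  const std::int32_t segments[2] = {num_args, num_outputs};
  std::printf("operand_segment_sizes = dense<[%d, %d]>\n",
              segments[0], segments[1]);
  return 0;
}
```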
struct HloToLhloDynamicBroadcastInDimOpConverter
: public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> {
public:
@ -194,8 +206,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
Value transformed_operand =
InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
rewriter.create<lmhlo::BroadcastInDimOp>(
loc, transformed_operand, resultBuffer, op.broadcast_dimensions());
rewriter.create<lmhlo::CopyOp>(loc, transformed_operand, resultBuffer);
rewriter.replaceOp(op, {resultBuffer});
@ -211,48 +222,76 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
auto loc = op.getLoc();
auto operand_type = operand.getType().cast<MemRefType>();
auto operand_shape = operand_type.getShape();
auto operand_rank = operand_type.getRank();
SmallVector<Value, 2> sizes, strides;
sizes.reserve(operand_shape.size());
strides.reserve(operand_shape.size());
auto result_type = op.getType().cast<RankedTensorType>();
auto result_rank = result_type.getRank();
Value zero = b->create<ConstantIndexOp>(loc, 0);
Value one = b->create<ConstantIndexOp>(loc, 1);
for (auto dim : llvm::enumerate(op.broadcast_dimensions())) {
Value broadcast_dim_value =
b->create<ConstantIndexOp>(loc, dim.value().getSExtValue());
Value result_dim_size = b->create<ExtractElementOp>(
loc, op.output_dimensions(), broadcast_dim_value);
Value operand_dim_size =
ShapedType::isDynamic(operand_shape[dim.index()])
? b->create<DimOp>(loc, operand, dim.index()).getResult()
: b->create<ConstantIndexOp>(loc, operand_shape[dim.index()])
.getResult();
// TODO(pifon): Revisit if this cast is needed. Maybe we can use
// tensor<index> for `output_dimensions` as well.
// Compute a reversed scan product: the stride for each dimension, working
// from minor to major dimensions. Additionally, save the operand shape
// Values to use in the next loop.
SmallVector<Value, 2> operand_strides(operand_rank, one);
SmallVector<Value, 2> operand_sizes(operand_rank, one);
Value stride_so_far = one;
for (int i = operand_rank - 1; i >= 0; --i) {
Value operand_dim_size =
ShapedType::isDynamic(operand_shape[i])
? b->create<DimOp>(loc, operand, i).getResult()
: b->create<ConstantIndexOp>(loc, operand_shape[i]).getResult();
operand_sizes[i] = operand_dim_size;
operand_strides[i] = stride_so_far;
if (i > 0) {
stride_so_far = b->create<MulIOp>(loc, stride_so_far, operand_dim_size);
}
}
SmallVector<Value, 2> sizes, strides;
sizes.reserve(result_rank);
strides.reserve(result_rank);
DenseMap<int, int> output_to_input_dim;
for (auto dim : llvm::enumerate(op.broadcast_dimensions())) {
output_to_input_dim[dim.value().getSExtValue()] = dim.index();
}
for (int i = 0; i < result_rank; ++i) {
Value i_val = b->create<ConstantIndexOp>(loc, i);
Value result_dim_size =
b->create<ExtractElementOp>(loc, op.output_dimensions(), i_val);
if (!result_dim_size.getType().isIndex()) {
result_dim_size =
b->create<IndexCastOp>(loc, result_dim_size, b->getIndexType());
}
sizes.push_back(result_dim_size);
auto it = output_to_input_dim.find(i);
// If the rank of the output is greater than the rank of the input, i.e.
// this output dimension has no entry in the inverse broadcast_dimensions
// map, we set the stride to 0 to emulate padding the shape with 1s and
// the corresponding expansion.
if (it == output_to_input_dim.end()) {
strides.push_back(zero);
continue;
}
// There can be two cases:
// 1) Operand dim == result dim => expansion is not needed => stride := 1.
// 1) Operand dim == result dim => expansion is not needed
// => stride := flattened buffer stride.
// 2) Operand dim < result dim => expansion is needed => stride := 0.
Value is_expansion = b->create<CmpIOp>(loc, CmpIPredicate::slt,
operand_dim_size, result_dim_size);
strides.push_back(
b->create<mlir::SelectOp>(loc, is_expansion, zero, one));
// Size of input dim can be set to the size of the corresponding output
// dimension for both cases.
sizes.push_back(result_dim_size);
int dim = it->second;
Value is_expansion = b->create<CmpIOp>(
loc, CmpIPredicate::slt, operand_sizes[dim], result_dim_size);
strides.push_back(b->create<mlir::SelectOp>(loc, is_expansion, zero,
operand_strides[dim]));
}
// Type-erased memref type with static rank, dynamic sizes and strides.
SmallVector<int64_t, 2> dynamic_layout(operand_shape.size(),
SmallVector<int64_t, 2> dynamic_layout(result_rank,
MemRefType::kDynamicStrideOrOffset);
SmallVector<int64_t, 2> dynamic_shape(operand_shape.size(),
SmallVector<int64_t, 2> dynamic_shape(result_rank,
MemRefType::kDynamicSize);
auto type_erased_memref_type = MemRefType::get(
dynamic_shape, operand_type.getElementType(),
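
The reinterpret-cast trick above encodes broadcasting purely in strides: compute ordinary row-major strides with a minor-to-major scan product, then zero the stride of every expanded dimension so all output indices along it map to the same element. A small worked example in plain C++, assuming matching operand and result ranks (the rank-mismatch case gets stride 0 via the inverse map, as the comments above explain):

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// Example: operand shape [5, 1] broadcast to result shape [5, 7].
int main() {
  std::vector<long> operand_shape = {5, 1};
  std::vector<long> result_shape = {5, 7};
  std::vector<long> strides(operand_shape.size());
  long stride_so_far = 1;
  // Reversed scan product: row-major strides, minor to major.
  for (int i = static_cast<int>(operand_shape.size()) - 1; i >= 0; --i) {
    strides[i] = stride_so_far;
    stride_so_far *= operand_shape[i];
  }
  // An expanded dimension gets stride 0.
  for (std::size_t i = 0; i < strides.size(); ++i) {
    if (operand_shape[i] < result_shape[i]) strides[i] = 0;
  }
  std::printf("strides = [%ld, %ld]\n", strides[0], strides[1]);  // [1, 0]
}
```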
@ -517,11 +556,8 @@ struct HloLegalizeToLhlo
ConversionTarget target(context);
target.addLegalDialect<lmhlo::LmhloDialect>();
target.addLegalDialect<StandardOpsDialect>();
target.addLegalOp<ModuleOp>();
target.addIllegalOp<mlir::TensorLoadOp>();
target.addIllegalOp<mlir::TensorStoreOp>();
target.addLegalOp<ModuleTerminatorOp>();
target.addLegalOp<TensorFromElementsOp>();
target.addIllegalDialect<mhlo::MhloDialect>();
BufferizeTypeConverter converter;
@ -543,9 +579,8 @@ struct HloLegalizeToLhlo
});
populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
populateWithBufferizeOpConversionPatterns<mlir::ReturnOp, mlir::ReturnOp,
lmhlo::CopyOp>(
&context, converter, patterns);
populateFuncOpTypeConversionPattern(patterns, &context, converter);
populateCallOpTypeConversionPattern(patterns, &context, converter);
populateShapeStructuralTypeConversionsAndLegality(&context, converter,
patterns, target);
if (failed(applyPartialConversion(getOperation(), target,
@ -560,6 +595,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) {
// clang-format off
patterns->insert<
HloToLhloCustomCallOpConverter,
HloToLhloDotGeneralOpConverter,
HloToLhloDynamicBroadcastInDimOpConverter,
HloToLhloDynamicReshapeConverter,
@ -576,7 +612,6 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
HloToLhloOpConverter<mhlo::ConvertOp>,
HloToLhloOpConverter<mhlo::CopyOp>,
HloToLhloOpConverter<mhlo::CosOp>,
HloToLhloOpConverter<mhlo::CustomCallOp>,
HloToLhloOpConverter<mhlo::DivOp>,
HloToLhloOpConverter<mhlo::DotOp>,
HloToLhloOpConverter<mhlo::ExpOp>,
@ -607,7 +642,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
HloToLhloReturnOpConverter,
HloToLhloTensorLoadOpConverter,
HloToLhloTensorStoreOpConverter
>(context);
>(*converter, context);
// clang-format on
}

View File

@ -85,13 +85,66 @@ class LhloFuseLinalgPass
if (!definingOp) {
continue;
}
if (auto viewLike = dyn_cast<ViewLikeOpInterface>(definingOp)) {
auto alias = viewLike.getViewSource();
if (result_buffers.insert(alias).second) {
worklist.push_back(alias);
}
continue;
}
if (auto tensor_load = dyn_cast<TensorLoadOp>(definingOp)) {
auto alias = tensor_load.memref();
if (result_buffers.insert(alias).second) {
worklist.push_back(alias);
}
continue;
}
if (auto tensor_to_memref = dyn_cast<TensorToMemrefOp>(definingOp)) {
auto alias = tensor_to_memref.tensor();
if (result_buffers.insert(alias).second) {
worklist.push_back(alias);
}
continue;
}
if (auto tensor_cast = dyn_cast<TensorCastOp>(definingOp)) {
auto alias = tensor_cast.source();
if (result_buffers.insert(alias).second) {
worklist.push_back(alias);
}
continue;
}
if (auto regionInterface =
dyn_cast<RegionBranchOpInterface>(definingOp)) {
for (Region& region : regionInterface.getOperation()->getRegions()) {
// Only consider regions that can return to the parent region.
SmallVector<RegionSuccessor, 2> successorRegions;
regionInterface.getSuccessorRegions(region.getRegionNumber(),
successorRegions);
if (llvm::none_of(successorRegions, [&](auto successorRegion) {
return successorRegion.isParent();
}))
continue;
// Iterate over all immediate terminators and record the values
// corresponding to result_buffers of interest.
for (Block& block : region) {
if (block.empty()) continue;
Operation& operation = block.back();
if (!operation.hasTrait<OpTrait::ReturnLike>()) continue;
auto idx = result.dyn_cast<OpResult>().getResultNumber();
if (result_buffers.insert(operation.getOperand(idx)).second) {
worklist.push_back(operation.getOperand(idx));
}
}
}
}
}
MLIRContext* ctx = func.getContext();
OpBuilder b(func);
func.walk([&](linalg::GenericOp generic_op) {
@ -114,10 +167,10 @@ class LhloFuseLinalgPass
// Fuse producers of tiled linalg ops.
llvm::SmallDenseSet<Operation*> erase_set;
SmallVector<Operation*, 8> linalg_ops;
SmallVector<LinalgOp, 8> linalg_ops;
func.walk([&](LinalgOp op) { linalg_ops.push_back(op); });
for (auto* op : llvm::reverse(linalg_ops)) {
for (unsigned id = 0, e = LinalgOp(op).getNumInputs(); id < e; ++id) {
for (LinalgOp op : llvm::reverse(linalg_ops)) {
for (unsigned id = 0, e = op.getNumInputs(); id < e; ++id) {
linalg::Aliases aliases;
linalg::LinalgDependenceGraph graph(aliases, linalg_ops);
if (auto info = fuseProducerOfBuffer(b, op, id, graph)) {
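
The pass now chases result buffers through view-like, tensor_load, tensor_to_memref, tensor_cast and region-branching ops using a worklist. A generic sketch of that deduplicated worklist traversal, with integers standing in for SSA values:

```cpp
#include <unordered_set>
#include <vector>

// Deduplicated worklist: follow each value to the buffer it aliases and
// record every buffer exactly once, as the fusion pass does for the
// alias-forwarding defining ops listed above.
std::unordered_set<int> CollectAliases(int root,
                                       const std::vector<int>& alias_of) {
  std::unordered_set<int> seen{root};
  std::vector<int> worklist{root};
  while (!worklist.empty()) {
    int value = worklist.back();
    worklist.pop_back();
    int alias = alias_of[value];  // -1 means "defines its own buffer"
    if (alias >= 0 && seen.insert(alias).second) worklist.push_back(alias);
  }
  return seen;
}
```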

View File

@ -50,6 +50,8 @@ class SinkConstantsToControlFlowPass
} else if (auto if_op = llvm::dyn_cast<IfOp>(op)) {
SinkToRegion(&if_op.true_branch());
SinkToRegion(&if_op.false_branch());
} else if (auto reduce_window_op = llvm::dyn_cast<ReduceWindowOp>(op)) {
SinkToRegion(&reduce_window_op.body());
} else if (auto sort_op = llvm::dyn_cast<SortOp>(op)) {
SinkToRegion(&sort_op.comparator());
}

View File

@ -575,6 +575,16 @@ func @dynamic_reshape_not_actually_dynamic(%arg0: tensor<4xf32>, %shape: tensor<
return %0 : tensor<4x1xf32>
}
// CHECK-LABEL: func @shape_of_dynamic_reshape
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]]
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9]+]]
func @shape_of_dynamic_reshape(%arg0: tensor<*xf32>, %shape: tensor<2xindex>) -> tensor<2xindex> {
// CHECK: return [[ARG1]]
%0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
%1 = shape.shape_of %0 : tensor<?x?xf32> -> tensor<2xindex>
return %1 : tensor<2xindex>
}
// CHECK-LABEL: do_not_dce_while_with_outfeed
func @do_not_dce_while_with_outfeed(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK: mhlo.while

View File

@ -1,4 +1,6 @@
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck %s
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting \
// RUN: -buffer-deallocation -split-input-file -cse %s -o - \
// RUN: | FILECHECK_OPTS="" FileCheck %s
// CHECK-LABEL: func @attrs
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
@ -153,64 +155,41 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
// -----
func @external_func() -> tensor<3xi64>
// CHECK: #[[MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
// CHECK: #[[MAP:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + d2 * s2)>
// CHECK-LABEL: func @dyn_broadcast
func @dyn_broadcast(%operand: memref<?x?xf32>) {
// CHECK-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
func @dyn_broadcast(%operand: memref<?x?xf32>) -> index {
// CHECK-SAME: %[[OPERAND:.*]]: memref<?x?xf32>
%tensor_operand = tensor_load %operand : memref<?x?xf32>
%c1 = constant 1 : i64
%shape = tensor_from_elements %c1, %c1, %c1 : tensor<3xi64>
%tensor_result = "mhlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
} : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
// CHECK: %[[SHAPE:.*]] = tensor_from_elements
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
// CHECK: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<3xi64>
// CHECK: %[[IC1:.*]] = index_cast %[[EL1]] : i64 to index
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]][%[[C2]]] : tensor<3xi64>
// CHECK: %[[IC2:.*]] = index_cast %[[EL2]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[IC0]], %[[IC1]], %[[IC2]])
// CHECK: %[[C0_:.*]] = constant 0 : index
// CHECK: %[[C1_:.*]] = constant 1 : index
// CHECK: %[[C1__:.*]] = constant 1 : index
// CHECK: %[[EL1_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1__]]] : tensor<3xi64>
// CHECK: %[[C0___:.*]] = constant 0 : index
// CHECK: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], %[[C0___]] : memref<?x?xf32>
// CHECK: %[[RESULT_DIM_1:.*]] = index_cast %[[EL1_]] : i64 to index
// CHECK: %[[EXPAND_0:.*]] = cmpi "slt", %[[OPERAND_DIM_0]], %[[RESULT_DIM_1]]
// CHECK: %[[STRIDE_0:.*]] = select %[[EXPAND_0]], %[[C0_]], %[[C1_]] : index
// CHECK: %[[C2_:.*]] = constant 2 : index
// CHECK: %[[EL2_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2_]]] : tensor<3xi64>
// CHECK: %[[C1___:.*]] = constant 1 : index
// CHECK: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], %[[C1___]] : memref<?x?xf32>
// CHECK: %[[RESULT_DIM_2:.*]] = index_cast %[[EL2_]] : i64 to index
// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to
// CHECK-SAME: offset: [0],
// CHECK-SAME: sizes: {{\[}}%[[RESULT_DIM_1]], %[[RESULT_DIM_2]]]
// CHECK-SAME: strides: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
// CHECK-SAME: : memref<?x?xf32> to memref<?x?xf32, #map>
// CHECK: "lmhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
// CHECK-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
// CHECK-SAME: } : (memref<?x?xf32, #[[MAP]]>, memref<?x?x?xf32>) -> ()
// Do not store the value back to avoid the tensor-store being rewritten to
// a copy into the pre-allocated argument.
return
%rank = rank %tensor_result : tensor<?x?x?xf32>
return %rank : index
}
// CHECK: %[[SHAPE:.*]] = tensor_from_elements
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64>
// CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1]]] : tensor<3xi64>
// CHECK: %[[SIZE_1:.*]] = index_cast %[[EL1]] : i64 to index
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2]]] : tensor<3xi64>
// CHECK: %[[SIZE_2:.*]] = index_cast %[[EL2]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>
// CHECK: %[[OPER_DIM_1:.*]] = dim %[[OPERAND]], %[[C1]] : memref<?x?xf32>
// CHECK: %[[OP_STRIDE_0:.*]] = muli %[[C1]], %[[OPER_DIM_1]] : index
// CHECK: %[[OPER_DIM_0:.*]] = dim %[[OPERAND]], %[[C0]] : memref<?x?xf32>
// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPER_DIM_0]], %[[SIZE_1]] : index
// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0]], %[[OP_STRIDE_0]] : index
// CHECK: %[[EXPAND_2:.*]] = cmpi "slt", %[[OPER_DIM_1]], %[[SIZE_2]] : index
// CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]]: memref<?x?xf32> to memref<?x?x?xf32, #map>
// CHECK: "lmhlo.copy"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) : (memref<?x?x?xf32, #map>, memref<?x?x?xf32>) -> ()
// CHECK: dealloc %[[RESULT]] : memref<?x?x?xf32>
// -----
@ -483,11 +462,9 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
// CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
// CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
// CHECK: %[[C0_:.*]] = constant 0 : index
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64>
// CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
// CHECK: %[[C1_:.*]] = constant 1 : index
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64>
// CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// CHECK: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
@ -508,11 +485,9 @@ func @tanh_dyn(%arg0: tensor<?x?xf32>) {
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
// CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
// CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
// CHECK: %[[C0_:.*]] = constant 0 : index
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64>
// CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
// CHECK: %[[C1_:.*]] = constant 1 : index
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64>
// CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// CHECK: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
@ -613,7 +588,7 @@ func @transpose(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
func @custom_call(%arg0: memref<2x2xf32>, %arg1: memref<2x3xf32>, %result: memref<4x4xf16>) {
%arg0_tensor = tensor_load %arg0 : memref<2x2xf32>
%arg1_tensor = tensor_load %arg1 : memref<2x3xf32>
// CHECK: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false}
// CHECK: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false, operand_segment_sizes = dense<[2, 1]> : vector<2xi32>}
%result_tensor = "mhlo.custom_call"(%arg0_tensor, %arg1_tensor)
{backend_config = "", call_target_name = "foo", has_side_effect = false}
: (tensor<2x2xf32>, tensor<2x3xf32>) -> tensor<4x4xf16>
@ -623,6 +598,22 @@ func @custom_call(%arg0: memref<2x2xf32>, %arg1: memref<2x3xf32>, %result: memre
// -----
// CHECK-LABEL: func @custom_call_multiout
// CHECK-SAME:([[ARG0:%.*]]: memref<2x2xf32>, [[ARG1:%.*]]: memref<2x3xf32>, [[RESULT:%.*]]: memref<4x4xf16>)
func @custom_call_multiout(%arg0: memref<2x2xf32>, %arg1: memref<2x3xf32>, %result: memref<4x4xf16>) {
%arg0_tensor = tensor_load %arg0 : memref<2x2xf32>
%arg1_tensor = tensor_load %arg1 : memref<2x3xf32>
// CHECK: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}, %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false, operand_segment_sizes = dense<2> : vector<2xi32>}
%temp:2 = "mhlo.custom_call"(%arg0_tensor, %arg1_tensor)
{backend_config = "", call_target_name = "foo", has_side_effect = false}
: (tensor<2x2xf32>, tensor<2x3xf32>) -> (tensor<4x4xf16>, tensor<4x4xf16>)
%result_tensor = "mhlo.add"(%temp#0, %temp#1) : (tensor<4x4xf16>, tensor<4x4xf16>) -> tensor<4x4xf16>
tensor_store %result_tensor, %result: memref<4x4xf16>
return
}
// -----
// CHECK-LABEL: func @isfinite
func @isfinite(%arg0: memref<2x2xf32>, %result: memref<2x2xi1>) {
%arg0_tensor = tensor_load %arg0 : memref<2x2xf32>
@ -645,7 +636,7 @@ func @shape_assuming_memref(%arg0: tensor<?xf16>) -> tensor<?xf16> {
%4 = tensor_cast %3 : tensor<?xindex> to tensor<1xindex>
%5 = "mhlo.dynamic_broadcast_in_dim"(%0, %4) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f16>, tensor<1xindex>) -> tensor<?xf16>
%6 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %4) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf16>, tensor<1xindex>) -> tensor<?xf16>
// CHECK: "lmhlo.maximum"(%6, %9, %20) : (memref<?xf16>, memref<?xf16>, memref<?xf16>) -> ()
// CHECK: "lmhlo.maximum"(%{{.*}}, %{{.*}}, %{{.*}}) : (memref<?xf16>, memref<?xf16>, memref<?xf16>) -> ()
%7 = mhlo.maximum %5, %6 : tensor<?xf16>
// CHECK: shape.assuming_yield %{{.*}} : memref<?xf16>
shape.assuming_yield %7 : tensor<?xf16>

View File

@ -299,3 +299,131 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
// PLOOP: absf
// PLOOP: memref_reshape
// -----
// Confirm that tiling information is passed through RegionBranchOpInterfaces.
// This test also uses memref_reshape, just to have a value to return through
// the if statement.
func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
-> memref<*xf32> {
%c1 = constant 1 : index
%c0 = constant 0 : index
%1 = alloc(%arg2) : memref<?xf32>
linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]}
ins(%arg0 : memref<?xf32>) outs(%1 : memref<?xf32>) {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
%13 = absf %arg3 : f32
linalg.yield %13 : f32
}
%true = constant 1 : i1
%3 = scf.if %true -> memref<*xf32> {
%2 = memref_reshape %1(%arg1)
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
scf.yield %2 : memref<*xf32>
} else {
%2 = memref_reshape %1(%arg1)
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
scf.yield %2 : memref<*xf32>
}
return %3 : memref<*xf32>
}
// CHECK-LABEL: func @branching_result
// CHECK: %[[C1:.*]] = constant 1
// CHECK-NOT: linalg.generic
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: absf
// CHECK: scf.if
// CHECK: memref_reshape
// CHECK: scf.yield
// CHECK: else
// CHECK: memref_reshape
// CHECK: scf.yield
// TILED-LABEL: func @branching_result
// TILED-DAG: %[[C2:.*]] = constant 2
// TILED-NOT: linalg.generic
// TILED: scf.for {{.*}} step %[[C2]]
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: absf
// TILED: scf.if
// TILED: memref_reshape
// TILED: scf.yield
// TILED: else
// TILED: memref_reshape
// TILED: scf.yield
// PLOOP-LABEL: func @branching_result
// PLOOP-NOT: linalg.generic
// PLOOP: scf.parallel
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: absf
// PLOOP: scf.if
// PLOOP: memref_reshape
// PLOOP: scf.yield
// PLOOP: else
// PLOOP: memref_reshape
// PLOOP: scf.yield
// -----
// Confirm that tiling information is passed through tensor_load, tensor_cast
// and tensor_to_memref operations.
func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
-> memref<?xf32> {
%c1 = constant 1 : index
%1 = alloc() : memref<32xf32>
linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]}
ins(%arg0 : memref<32xf32>) outs(%1 : memref<32xf32>) {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
%13 = absf %arg3 : f32
linalg.yield %13 : f32
}
%2 = tensor_load %1 : memref<32xf32>
%3 = tensor_cast %2 : tensor<32xf32> to tensor<?xf32>
%4 = tensor_to_memref %3 : memref<?xf32>
return %4 : memref<?xf32>
}
// CHECK-LABEL: func @tensor_ops
// CHECK: %[[C1:.*]] = constant 1
// CHECK-NOT: linalg.generic
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: absf
// CHECK: tensor_load
// CHECK: tensor_cast
// CHECK: tensor_to_memref
// TILED-LABEL: func @tensor_ops
// TILED-DAG: %[[C2:.*]] = constant 2
// TILED-NOT: linalg.generic
// TILED: scf.for {{.*}} step %[[C2]]
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: absf
// TILED: tensor_load
// TILED: tensor_cast
// TILED: tensor_to_memref
// PLOOP-LABEL: func @tensor_ops
// PLOOP-NOT: linalg.generic
// PLOOP: scf.parallel
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: absf
// PLOOP: tensor_load
// PLOOP: tensor_cast
// PLOOP: tensor_to_memref

View File

@ -3,13 +3,13 @@
// Tests for types, ops with custom constraints, verifiers, printer or parser
// methods.
// CHECK-LABEL: func @token_type() -> !mhlo.token
func @token_type() -> !mhlo.token
// CHECK-LABEL: func private @token_type() -> !mhlo.token
func private @token_type() -> !mhlo.token
// -----
// expected-error@+1 {{unknown mhlo type: foobar}}
func @invalid_type() -> !mhlo.foobar
func private @invalid_type() -> !mhlo.foobar
// -----
@ -1281,3 +1281,12 @@ func @set_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<1x128x512xf32> {
%result = "mhlo.set_dimension_size"(%I, %dim) {dimension = 3 : i64} : (tensor<1x128x512xf32>, tensor<i32>) -> tensor<1x128x512xf32>
return %result : tensor<1x128x512xf32>
}
// -----
// CHECK: func @custom_call_multiple_outputs
func @custom_call_multiple_outputs(%x: tensor<2xf32>) -> tensor<2xf32> {
%0:2 = "mhlo.custom_call"(%x) {backend_config="", call_target_name = "foo", has_side_effect = false} : (tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>)
%1 = "mhlo.add"(%0#0, %0#1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
return %1 : tensor<2xf32>
}

View File

@ -35,9 +35,9 @@ filegroup(
"ir/tfl_ops.td",
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SideEffectTdFiles",
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
],
)
@ -390,6 +390,7 @@ cc_library(
"transforms/generated_legalize_tf.inc",
"transforms/generated_lower_static_tensor_list.inc",
"transforms/generated_prepare_tf.inc",
"transforms/insert_call_once_op.cc",
"transforms/legalize_tf.cc",
"transforms/legalize_tf_while.cc",
"transforms/lower_static_tensor_list.cc",
@ -427,6 +428,7 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
"//tensorflow/compiler/mlir/tensorflow:unroll_batch_matmul_pass",
"//tensorflow/compiler/mlir/tensorflow:verification_utils",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
"//tensorflow/compiler/xla:status",
@ -464,6 +466,7 @@ cc_library(
":validators",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:verification_utils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",

View File

@ -167,7 +167,8 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
case 32:
return tflite::TensorType_INT32;
case 64:
return tflite::TensorType_INT64;
return itype.isUnsigned() ? tflite::TensorType_UINT64
: tflite::TensorType_INT64;
}
} else if (auto q_uniform_type =
type.dyn_cast<mlir::quant::UniformQuantizedType>()) {
@ -453,6 +454,11 @@ class Translator {
mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
// Build call once operator.
BufferOffset<tflite::Operator> BuildCallOnceOperator(
mlir::TFL::CallOnceOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
// Builds custom operators.
// Templated on a) data type of custom_option to be stored into flatbuffer,
// and b) TFL custom op type.
@ -787,6 +793,22 @@ BufferOffset<tflite::Operator> Translator::BuildIfOperator(
builtin_options);
}
BufferOffset<tflite::Operator> Translator::BuildCallOnceOperator(
mlir::TFL::CallOnceOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results) {
auto opcode_index =
GetOpcodeIndex("call_once", tflite::BuiltinOperator_CALL_ONCE);
int init_subgraph_index =
subgraph_index_map_.at(op.session_init_function().str());
auto builtin_options =
tflite::CreateCallOnceOptions(builder_, init_subgraph_index).Union();
auto inputs = builder_.CreateVector(operands);
auto outputs = builder_.CreateVector(results);
return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
tflite::BuiltinOptions_CallOnceOptions,
builtin_options);
}
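
`BuildCallOnceOperator` relies on `GetOpcodeIndex` to assign each builtin op name a stable slot in the flatbuffer's operator-code table. A minimal sketch of that lookup-or-insert caching pattern (illustrative class, not the Translator's real member):

```cpp
#include <cstdint>
#include <map>
#include <string>

// Lookup-or-insert cache: each distinct op name gets one stable slot in
// the operator-code table, reused on every later call.
class OpcodeTable {
 public:
  std::uint32_t GetOpcodeIndex(const std::string& op_name) {
    auto it = index_.find(op_name);
    if (it != index_.end()) return it->second;
    const std::uint32_t idx = static_cast<std::uint32_t>(index_.size());
    index_.emplace(op_name, idx);
    return idx;
  }

 private:
  std::map<std::string, std::uint32_t> index_;
};
```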
BufferOffset<tflite::Operator> Translator::BuildWhileOperator(
mlir::TF::WhileOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results) {
@ -1026,6 +1048,12 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
return llvm::None;
}
if (*builtin_code == tflite::BuiltinOperator_CALL_ONCE) {
if (auto initOp = dyn_cast<mlir::TFL::CallOnceOp>(inst)) {
return BuildCallOnceOperator(initOp, operands, results);
}
}
std::string op_name = inst->getName().getStringRef().str();
uint32_t opcode_index = GetOpcodeIndex(op_name, *builtin_code);

View File

@ -448,13 +448,54 @@ StatusOr<Operation*> BuildExternalConstOp(const tflite::TensorT& tensor,
return op.getOperation();
}
// Gets a constant splat for the given type. Requires the type to be a
// statically shaped RankedTensorType. `unique_index` is used as the unique
// value for the attribute.
static mlir::ElementsAttr GetSplat(RankedTensorType type, int unique_index,
OpBuilder builder) {
mlir::Type element_ty = getElementTypeOrSelf(type);
if (element_ty.isSignlessInteger())
return DenseElementsAttr::get(
type, builder.getIntegerAttr(element_ty, unique_index));
if (element_ty.isa<mlir::FloatType>())
return DenseElementsAttr::get(
type, builder.getFloatAttr(element_ty, unique_index));
if (auto qtype = element_ty.dyn_cast<QuantizedType>()) {
mlir::RankedTensorType new_type =
RankedTensorType::get(type.getShape(), qtype.getStorageType());
return DenseElementsAttr::get(
new_type, builder.getIntegerAttr(qtype.getStorageType(), unique_index));
}
llvm_unreachable("unhandled element type");
}
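
The reason `GetSplat` takes a `unique_index` is spelled out in the comment below: identical splats would be merged into one constant. A tiny stand-alone illustration of the function-static counter idea:

```cpp
#include <cstdio>

// Identical splat constants would be merged by CSE, but each stateful
// variable must keep its own op, so a function-static counter makes every
// splat value distinct.
int NextVariableSplat() {
  static int stateful_variable_idx = 0;
  return stateful_variable_idx++;
}

int main() {
  std::printf("%d %d %d\n", NextVariableSplat(), NextVariableSplat(),
              NextVariableSplat());  // 0 1 2
}
```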
// TODO(b/172664358): Create a new op instead of reusing the constant op.
// Creates a constant op to represent a stateful variable. The function-static
// variable `stateful_variable_idx` provides a unique value for each constant
// so that they cannot be CSE'd. `tensor` is the flatbuffer data structure;
// `shaped_type` is the ShapedType for the const op.
Operation* BuildVariableOp(const tflite::TensorT& tensor,
mlir::RankedTensorType shaped_type,
OpBuilder builder, Location loc) {
static int stateful_variable_idx = 0;
mlir::ElementsAttr value =
GetSplat(shaped_type, stateful_variable_idx++, builder);
if (IsQuantized(tensor)) {
auto op = builder.create<tfl::QConstOp>(
loc, mlir::TypeAttr::get(shaped_type), value);
return op.getOperation();
}
auto op = builder.create<tfl::ConstOp>(loc, value);
return op.getOperation();
}
StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
const std::vector<uint8_t>& buffer,
OpBuilder builder, Location loc) {
if (buffer.empty()) {
return errors::InvalidArgument("Constant's buffer may not be empty");
}
bool is_variable, OpBuilder builder,
Location loc) {
TF_ASSIGN_OR_RETURN(auto type, GetTensorType(tensor, builder,
/*shapeless_are_scalars=*/true,
/*is_constant=*/true));
@ -466,7 +507,9 @@ StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
auto elem_type = shaped_type.getElementType();
mlir::ElementsAttr value;
if (auto float_type = elem_type.dyn_cast<mlir::FloatType>()) {
if (is_variable) {
return BuildVariableOp(tensor, shaped_type, builder, loc);
} else if (auto float_type = elem_type.dyn_cast<mlir::FloatType>()) {
TF_ASSIGN_OR_RETURN(value,
ConvertFloatBuffer(shaped_type, float_type, buffer));
} else if (elem_type.isa<mlir::IntegerType, QuantizedType>()) {
@ -846,19 +889,8 @@ StatusOr<FuncOp> ConvertSubgraph(
GetTensorIndices(subgraph, ordered_input_arrays));
}
// Add state variables to inputs.
absl::flat_hash_set<int32_t> input_index_set(func_inputs.begin(),
func_inputs.end());
for (int i = 0, end = subgraph.tensors.size(); i < end; i++) {
auto& tensor = *subgraph.tensors.at(i);
if (tensor.is_variable && !input_index_set.contains(i)) {
func_inputs.emplace_back(i);
input_index_set.insert(i);
}
}
for (auto input_or_variable : func_inputs) {
auto& tensor = *subgraph.tensors.at(input_or_variable);
for (int input : func_inputs) {
auto& tensor = *subgraph.tensors.at(input);
// TODO(b/138222071) Graph inputs must have static shape per the exporter,
// but we cannot differentiate scalars from unranked tensors.
// Here we reverse the default assumption that shape = [] means unranked.
@ -889,7 +921,8 @@ StatusOr<FuncOp> ConvertSubgraph(
}
for (auto output : func_outputs) {
const bool is_func_input = input_index_set.contains(output);
const bool is_func_input = std::find(func_inputs.begin(), func_inputs.end(),
output) != func_inputs.end();
bool is_constant = !is_op_output[output] && !is_func_input;
// There are 2 cases where a tensor is scalar when it doesn't have a shape
// in the flatbuffer:
@ -955,7 +988,7 @@ StatusOr<FuncOp> ConvertSubgraph(
}
func.setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
} else {
func.setVisibility(FuncOp::Visibility::Private);
func.setPrivate();
}
absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;
@ -991,7 +1024,7 @@ StatusOr<FuncOp> ConvertSubgraph(
? BuildExternalConstOp(const_tensor, const_tensor.buffer,
op_builder, const_loc)
: BuildConstOp(const_tensor, buffers[const_tensor.buffer]->data,
op_builder, const_loc);
const_tensor.is_variable, op_builder, const_loc);
if (!op_or_err.ok()) {
return emitError(const_loc, op_or_err.status().ToString()),
op_or_err.status();
@ -1051,7 +1084,7 @@ StatusOr<FuncOp> ConvertSubgraph(
? BuildExternalConstOp(const_tensor, const_tensor.buffer,
op_builder, const_loc)
: BuildConstOp(const_tensor, buffers[const_tensor.buffer]->data,
op_builder, const_loc);
const_tensor.is_variable, op_builder, const_loc);
if (!op_or_err.ok()) {
return emitError(const_loc, op_or_err.status().ToString()),
op_or_err.status();

View File

@ -1972,6 +1972,43 @@ OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
return value();
}
//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 1);
// For now, only supports casts between integer types.
auto elements_attr = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
if (!elements_attr) {
return nullptr;
}
auto result_element_type =
getType().cast<ShapedType>().getElementType().dyn_cast<IntegerType>();
auto operand_element_type = input()
.getType()
.cast<ShapedType>()
.getElementType()
.dyn_cast<IntegerType>();
// Returns nullptr if either result/operand element type is not integer.
if (!result_element_type || !operand_element_type) {
return nullptr;
}
const bool is_input_unsigned = operand_element_type.isUnsigned();
const int output_bitwidth = result_element_type.getWidth();
// The integer cast op behaves like a C integer cast. Depending on the
// operand type's signedness, we determine whether sign extension is needed.
auto cast = [&](APInt value) {
return is_input_unsigned ? value.zextOrTrunc(output_bitwidth)
: value.sextOrTrunc(output_bitwidth);
};
return elements_attr.mapValues(result_element_type, cast);
}
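
The fold above mirrors C's integer conversion rules: the source type's signedness selects zero- versus sign-extension when widening, and narrowing truncates. A quick stand-alone check in plain C++ (assumes a two's-complement target, which matches LLVM's APInt semantics):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  std::uint8_t u = 200;  // unsigned source: zero-extend when widening
  std::int8_t s = -56;   // signed source: sign-extend when widening
  std::printf("%d\n", static_cast<std::int32_t>(u));  // 200
  std::printf("%d\n", static_cast<std::int32_t>(s));  // -56
  // Narrowing truncates, matching zextOrTrunc/sextOrTrunc; the value below
  // is -128 on two's-complement targets (implementation-defined pre-C++20).
  std::printf("%d\n",
              static_cast<std::int32_t>(static_cast<std::int8_t>(128)));
}
```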
//===----------------------------------------------------------------------===//
// SelectV2Op
//===----------------------------------------------------------------------===//

View File

@ -3405,7 +3405,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", [
}];
let arguments = (ins
TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8]>:$input,
TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8, TFL_Str]>:$input,
TFL_I32Tensor:$begin,
TFL_I32Tensor:$end,
TFL_I32Tensor:$strides,
@ -3418,7 +3418,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", [
);
let results = (outs
TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8]>:$output
TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8, TFL_Str]>:$output
);
let hasOptions = 1;
@ -3443,6 +3443,8 @@ def TFL_CastOp : TFL_Op<"cast", [
// TFLite's cast op does not utilize CastOptions, instead derives types
// from the TfLiteTensors.
let hasOptions = 0;
let hasFolder = 1;
}
def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [
@ -3877,7 +3879,7 @@ def TFL_UnidirectionalSequenceLSTMOp :
TFL_OperandHasRank<14, 1>, // cell_gate_bias
TFL_OperandHasRank<15, 1>, // output_gate_bias
TFL_OperandIsNoneOrHasRank<16, 2>, // projection_weights
TFL_OperandIsNoneOrHasRank<17, 2>, // projection_bias
TFL_OperandIsNoneOrHasRank<17, 1>, // projection_bias
TFL_StatefulOp]> {
let summary = "Unidirectional sequence lstm operator";
@ -4358,6 +4360,21 @@ def TFL_WhileOp : Op<TFL_Dialect, "while", [
let hasCanonicalizer = 1;
}
def TFL_CallOnceOp : TFL_Op<"call_once", []> {
let summary = "Invokes an initialization function";
let description = [{
This operation invokes the given initialization function for the session
initializer in the tf_saved_model dialect.
}];
let arguments = (ins
StrAttr:$session_init_function
);
let results = (outs);
}
def TFL_CustomOp : Op<TFL_Dialect, "custom", [
NoSideEffect, NoQuantizableResult]> {
let summary = "Custom op";

View File

@ -55,7 +55,7 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
// Parse input arrays.
std::vector<string> node_names;
std::vector<string> node_dtypes;
std::vector<std::vector<int>> node_shapes;
std::vector<llvm::Optional<std::vector<int>>> node_shapes;
std::vector<llvm::Optional<double>> node_mins;
std::vector<llvm::Optional<double>> node_maxs;

View File

@ -128,7 +128,7 @@ Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
// Parse input arrays.
std::vector<string> node_names;
std::vector<string> node_dtypes;
std::vector<std::vector<int>> node_shapes;
std::vector<llvm::Optional<std::vector<int>>> node_shapes;
std::vector<llvm::Optional<double>> node_mins;
std::vector<llvm::Optional<double>> node_maxs;

View File

@ -119,6 +119,8 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
return DT_INT32;
case toco::IODataType::INT64:
return DT_INT64;
case toco::IODataType::UINT64:
return DT_UINT64;
case toco::IODataType::STRING:
return DT_STRING;
case toco::IODataType::BOOL:
@ -185,7 +187,7 @@ Status PopulateQuantizationSpecs(
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
std::vector<string>* node_dtypes,
std::vector<std::vector<int>>* node_shapes,
std::vector<llvm::Optional<std::vector<int>>>* node_shapes,
std::vector<llvm::Optional<double>>* node_mins,
std::vector<llvm::Optional<double>>* node_maxs) {
quant_specs->inference_input_type =
@ -210,8 +212,12 @@ Status PopulateQuantizationSpecs(
node_dtypes->push_back(
DataType_Name(ConvertIODataTypeToDataType(toco_data_type)));
}
node_shapes->push_back(std::vector<int>(flag.shape().dims().begin(),
flag.shape().dims().end()));
if (flag.shape().unknown_rank()) {
node_shapes->push_back(llvm::None);
} else {
node_shapes->push_back(std::vector<int>(flag.shape().dims().begin(),
flag.shape().dims().end()));
}
// Currently, only UINT8 and INT8 require inputs stats
if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) {
if (flag.has_mean_value() && flag.has_std_value()) {

View File

@ -41,7 +41,7 @@ Status PopulateQuantizationSpecs(
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
std::vector<string>* node_dtypes,
std::vector<std::vector<int>>* node_shapes,
std::vector<llvm::Optional<std::vector<int>>>* node_shapes,
std::vector<llvm::Optional<double>>* node_mins,
std::vector<llvm::Optional<double>>* node_maxs);

View File

@ -52,6 +52,12 @@ struct QuantizationSpecs {
// weight FakeQuant).
bool disable_per_channel = false;
// When set to true, the fixed output ranges of the activation ops (tanh,
// sigmoid, etc.) are not enforced. In that case, quantization emulation ops
// must be placed after these ops in the input graph so that they can be
// quantized. This flag should be set to false for post-training quantization.
bool disable_enforced_fixed_output_range = false;
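// Illustrative sketch of what a "fixed output range" means here (the scale
// below is the conventional TFLite int8 tanh range, not a value taken from
// this commit): with enforcement on, a quantized tanh always produces
//   %0 = "tfl.tanh"(%arg0) : (tensor<4x!quant.uniform<i8:f32, 1.0>>)
//          -> tensor<4x!quant.uniform<i8:f32, 0.0078125>>
// whereas with this flag set, the output range comes from a trailing
// quantization-emulation op supplied in the input graph.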
// The node type when the model is exported. Currently this is limited to
// DT_FLOAT, DT_HALF, DT_QINT8, and DT_QUINT8. When DT_HALF is used, the
// `weight_quantization` flag needs to be set to true. When DT_QUINT8 is used,

View File

@ -587,3 +587,55 @@ func @rsqrt_bf16() -> tensor<bf16> {
// CHECK: %[[CST:.*]] = constant dense<5.000000e-01> : tensor<bf16>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @cast_i64_to_i32
func @cast_i64_to_i32() -> tensor<5xi32> {
%cst = constant dense<[-1, 0, 1, 2147483647, 2147483648]> : tensor<5xi64>
%0 = "tfl.cast"(%cst) : (tensor<5xi64>) -> tensor<5xi32>
return %0 : tensor<5xi32>
// CHECK: %[[CST:.*]] = constant dense<[-1, 0, 1, 2147483647, -2147483648]> : tensor<5xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @cast_i32_to_ui8
func @cast_i32_to_ui8() -> tensor<6xui8> {
%cst = constant dense<[0, -1, 256, 127, -128, -129]> : tensor<6xi32>
%0 = "tfl.cast"(%cst) : (tensor<6xi32>) -> tensor<6xui8>
return %0 : tensor<6xui8>
// CHECK: %[[CST:.*]] = constant dense<[0, 255, 0, 127, 128, 127]> : tensor<6xui8>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @cast_ui8_to_i8
func @cast_ui8_to_i8() -> tensor<4xi8> {
%cst = constant dense<[0, 255, 127, 128]> : tensor<4xui8>
%0 = "tfl.cast"(%cst) : (tensor<4xui8>) -> tensor<4xi8>
return %0 : tensor<4xi8>
// CHECK: %[[CST:.*]] = constant dense<[0, -1, 127, -128]> : tensor<4xi8>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @cast_i8_to_i32
func @cast_i8_to_i32() -> tensor<4xi32> {
%cst = constant dense<[0, 128, -1, -128]> : tensor<4xi8>
%0 = "tfl.cast"(%cst) : (tensor<4xi8>) -> tensor<4xi32>
return %0 : tensor<4xi32>
// CHECK: %[[CST:.*]] = constant dense<[0, -128, -1, -128]> : tensor<4xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @cast_ui8_to_i32
func @cast_ui8_to_i32() -> tensor<4xi32> {
%cst = constant dense<[0, 128, 129, 255]> : tensor<4xui8>
%0 = "tfl.cast"(%cst) : (tensor<4xui8>) -> tensor<4xi32>
return %0 : tensor<4xi32>
// CHECK: %[[CST:.*]] = constant dense<[0, 128, 129, 255]> : tensor<4xi32>
// CHECK: return %[[CST]]
}

View File

@ -411,11 +411,11 @@ versions {
# CHECK-NEXT: constant dense<[5.000000e+00, 6.000000e+00, 7.000000e+00, 8.000000e+00]>
# CHECK: "tf.If"{{.+}}else_branch = @cond_false_10{{.+}}is_stateless = true{{.+}}then_branch = @cond_true_10
# CHECK: "tf.If"{{.+}}else_branch = @cond_false0{{.+}}is_stateless = false{{.+}}then_branch = @cond_true0
# CHECK: func @cond_false_10
# CHECK: func private @cond_false_10
# CHECK-NEXT: tfl.div
# CHECK: func @cond_true_10
# CHECK: func private @cond_true_10
# CHECK-NEXT: tfl.sub
# CHECK: func @cond_false0
# CHECK: func private @cond_false0
# CHECK-NEXT: tfl.mul
# CHECK: func @cond_true0
# CHECK: func private @cond_true0
# CHECK-NEXT: tfl.add

View File

@ -78,14 +78,14 @@ versions {
}
# CHECK: func @main(%[[VAL_0:.*]]: tensor<2x5x3xf32>, %[[VAL_1:.*]]: tensor<3x7xf32>) -> tensor<2x5x7xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "Placeholder,Placeholder_1", outputs = "MatMul"}} {
# CHECK: %[[VAL_2:.*]] = constant dense<[1, 0]> : tensor<2xi32>
# CHECK: %[[VAL_3:.*]] = constant dense<[5, 3]> : tensor<2xi32>
# CHECK: %[[VAL_4:.*]] = constant dense<[3, 7]> : tensor<2xi32>
# CHECK: %[[VAL_5:.*]] = constant unit
# CHECK: %[[VAL_6:.*]] = constant dense<[1, 0, 0]> : tensor<3xi32>
# CHECK: %[[VAL_7:.*]] = constant dense<[1, 5, 3]> : tensor<3xi32>
# CHECK: %[[VAL_8:.*]] = constant dense<0> : tensor<3xi32>
# CHECK: %[[VAL_9:.*]] = constant dense<[1, 3, 7]> : tensor<3xi32>
# CHECK-DAG: %[[VAL_2:.*]] = constant dense<[1, 0]> : tensor<2xi32>
# CHECK-DAG: %[[VAL_3:.*]] = constant dense<[5, 3]> : tensor<2xi32>
# CHECK-DAG: %[[VAL_4:.*]] = constant dense<[3, 7]> : tensor<2xi32>
# CHECK-DAG: %[[VAL_5:.*]] = constant unit
# CHECK-DAG: %[[VAL_6:.*]] = constant dense<[1, 0, 0]> : tensor<3xi32>
# CHECK-DAG: %[[VAL_7:.*]] = constant dense<[1, 5, 3]> : tensor<3xi32>
# CHECK-DAG: %[[VAL_8:.*]] = constant dense<0> : tensor<3xi32>
# CHECK-DAG: %[[VAL_9:.*]] = constant dense<[1, 3, 7]> : tensor<3xi32>
# CHECK: %[[VAL_10:.*]] = "tfl.slice"(%[[VAL_0]], %[[VAL_8]], %[[VAL_7]]) : (tensor<2x5x3xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x5x3xf32>
# CHECK: %[[VAL_11:.*]] = "tfl.reshape"(%[[VAL_10]], %[[VAL_3]]) : (tensor<1x5x3xf32>, tensor<2xi32>) -> tensor<5x3xf32>
# CHECK: %[[VAL_12:.*]] = "tfl.slice"(%[[VAL_0]], %[[VAL_6]], %[[VAL_7]]) : (tensor<2x5x3xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x5x3xf32>

View File

@ -8,9 +8,11 @@ func @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32
return %24 : tensor<1x4xf32>
// CHECK-LABEL: main
// separate lines since there is no region for this op; see third_party/tensorflow/compiler/mlir/lite/ir/tfl_ops.td:3252
// CHECK: %[[RES0:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg22, %arg23, %arg18, %arg19, %arg20, %arg21) ( {
// CHECK: %[[RES0:.*]] = "tfl.pseudo_const"() {value = dense<{{.*}}> : tensor<1x4xf32>} : () -> tensor<1x4xf32>
// CHECK: %[[RES1:.*]] = "tfl.pseudo_const"() {value = dense<{{.*}}> : tensor<1x4xf32>} : () -> tensor<1x4xf32>
// CHECK: %[[RES2:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %[[RES0]], %[[RES1]], %arg18, %arg19, %arg20, %arg21) ( {
// CHECK: }) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
// CHECK: return %[[RES0]]
// CHECK: return %[[RES2]]
}
@ -29,9 +31,9 @@ func @testFullyQuantizedLSTM(%arg0: tensor<1x528x!quant.uniform<i8:f32, 0.037248
// -----
// CHECK-LABEL: testUnidirectionalSequenceLstmWithIntermediates
func @testUnidirectionalSequenceLstmWithIntermediates(%arg0: tensor<? x ? x f32>, %arg1: tensor<? x ? x f32>, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x ? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
func @testUnidirectionalSequenceLstmWithIntermediates(%arg0: tensor<? x ? x f32>, %arg1: tensor<? x ? x f32>, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

View File

@ -5,8 +5,8 @@
// CHECK: func @main(%arg0: tensor<1xf32>) -> tensor<*xf32>
// CHECK: %0 = "tf.While"(%arg0) {body = @body, cond = @cond, is_stateless = false} : (tensor<1xf32>) -> tensor<*xf32>
// CHECK: return %0 : tensor<*xf32>
// CHECK: func @cond(%arg0: tensor<*xf32>) -> tensor<*xf32>
// CHECK: func @body(%arg0: tensor<*xf32>) -> tensor<*xf32>
// CHECK: func private @cond(%arg0: tensor<*xf32>) -> tensor<*xf32>
// CHECK: func private @body(%arg0: tensor<*xf32>) -> tensor<*xf32>
func @main(%arg0: tensor<1xf32>) -> tensor<*xf32> {
%0 = "tf.While"(%arg0) {cond = @cond, body = @body, is_stateless = false} : (tensor<1xf32>) -> tensor<*xf32>

View File

@ -1,6 +1,6 @@
// RUN: tf-opt -tfl-prepare-composite-funcs-tf -tfl-fuse-tftext=true %s | FileCheck %s
func @whitespace_tokenizer_rank1(%arg0: tensor<1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<1>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
func private @whitespace_tokenizer_rank1(%arg0: tensor<1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {tf._input_shapes = [#tf.shape<1>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
%1 = "tf.Const"() {value = dense<[]> : tensor<0xi64>} : () -> tensor<0xi64>
%2 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
@ -1026,11 +1026,11 @@ func @WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_true_23810(%arg0: t
return %1 : tensor<i1>
}
// CHECK: func @whitespace_tokenizer_rank1(%arg0: tensor<1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<1>], tf.signature.is_stateful} {
// CHECK: func private @whitespace_tokenizer_rank1(%arg0: tensor<1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<1>], tf.signature.is_stateful} {
// CHECK: %0:2 = "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor<1x!tf.string>) -> (tensor<?x!tf.string>, tensor<?xi64>)
// CHECK: return %0#0, %0#1 : tensor<?x!tf.string>, tensor<?xi64>
func @whitespace_tokenizer_rank2(%arg0: tensor<?x1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>) attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<?x1>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
func private @whitespace_tokenizer_rank2(%arg0: tensor<?x1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>) attributes {tf._input_shapes = [#tf.shape<?x1>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<[]> : tensor<0xi64>} : () -> tensor<0xi64>
%1 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
%2 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
@ -2160,11 +2160,11 @@ func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_As
// CHECK: func @whitespace_tokenizer_rank2(%arg0: tensor<?x1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>) attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<?x1>], tf.signature.is_stateful} {
// CHECK: func private @whitespace_tokenizer_rank2(%arg0: tensor<?x1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>) attributes {tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<?x1>], tf.signature.is_stateful} {
// CHECK: %0:3 = "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor<?x1x!tf.string>) -> (tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>)
// CHECK: return %0#0, %0#1, %0#2 : tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>
func @whitespace_tokenizer_rank0(%arg0: tensor<!tf.string> {tf._user_specified_name = "input"}) -> tensor<?x!tf.string> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
func private @whitespace_tokenizer_rank0(%arg0: tensor<!tf.string> {tf._user_specified_name = "input"}) -> tensor<?x!tf.string> attributes {tf._input_shapes = [#tf.shape<>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
%1 = "tf.Const"() {value = dense<[]> : tensor<0xi64>} : () -> tensor<0xi64>
%2 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
@ -3190,7 +3190,7 @@ func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_
return %1 : tensor<i1>
}
// CHECK: func @whitespace_tokenizer_rank0(%arg0: tensor<!tf.string> {tf._user_specified_name = "input"}) -> tensor<?x!tf.string> attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<>], tf.signature.is_stateful} {
// CHECK: func private @whitespace_tokenizer_rank0(%arg0: tensor<!tf.string> {tf._user_specified_name = "input"}) -> tensor<?x!tf.string> attributes {tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<>], tf.signature.is_stateful} {
// CHECK: %0 = "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor<!tf.string>) -> tensor<?x!tf.string>
// CHECK: return %0 : tensor<?x!tf.string>
@ -3213,7 +3213,7 @@ func @ngrams(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "input"}) ->
// CHECK: return %0 : tensor<?x!tf.string>
// CHECK: }
func @ngrams_ragged_rank_2(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<3xi64> {tf._user_specified_name = "args_0"}, %arg2: tensor<?xi64> {tf._user_specified_name = "args_1"}) -> (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>) attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = "", width = 2 : i64}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<3>, #tf.shape<?>], tf.signature.is_stateful} {
func private @ngrams_ragged_rank_2(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<3xi64> {tf._user_specified_name = "args_0"}, %arg2: tensor<?xi64> {tf._user_specified_name = "args_1"}) -> (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>) attributes {tf._implements = #tf.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = "", width = 2 : i64}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<3>, #tf.shape<?>], tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
%1 = "tf.Const"() {value = dense<-1> : tensor<i64>} : () -> tensor<i64>
%2 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
@ -3330,12 +3330,12 @@ func @ngrams_ragged_rank_2(%arg0: tensor<?x!tf.string> {tf._user_specified_name
%71 = "tf.Identity"(%70) {device = ""} : (tensor<3xi64>) -> tensor<3xi64>
return %68, %71, %64 : tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>
}
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_27770(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_27770(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
return %1 : tensor<i1>
}
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_27780(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_27780(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
@ -3345,12 +3345,12 @@ func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_as
%5 = "tf.Identity"(%4) {device = ""} : (tensor<i1>) -> tensor<i1>
return %5 : tensor<i1>
}
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_28130(%arg0: tensor<i1>, %arg1: tensor<?xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<?>]} {
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_28130(%arg0: tensor<i1>, %arg1: tensor<?xi64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<?>]} {
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
return %1 : tensor<i1>
}
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_28140(%arg0: tensor<i1>, %arg1: tensor<?xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<?>], tf.signature.is_stateful} {
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_28140(%arg0: tensor<i1>, %arg1: tensor<?xi64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<?>], tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/sub:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
@ -3359,12 +3359,12 @@ func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_as
%4 = "tf.Identity"(%3) {device = ""} : (tensor<i1>) -> tensor<i1>
return %4 : tensor<i1>
}
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_28500(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_28500(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
return %1 : tensor<i1>
}
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_28510(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_28510(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits/strided_slice_1:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
@ -3374,12 +3374,12 @@ func @RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_Assert
%5 = "tf.Identity"(%4) {device = ""} : (tensor<i1>) -> tensor<i1>
return %5 : tensor<i1>
}
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_28900(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_28900(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
return %1 : tensor<i1>
}
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_28910(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_28910(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
@ -3389,12 +3389,12 @@ func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_
%5 = "tf.Identity"(%4) {device = ""} : (tensor<i1>) -> tensor<i1>
return %5 : tensor<i1>
}
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_29260(%arg0: tensor<i1>, %arg1: tensor<2xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<2>]} {
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_29260(%arg0: tensor<i1>, %arg1: tensor<2xi64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<2>]} {
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
return %1 : tensor<i1>
}
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_29270(%arg0: tensor<i1>, %arg1: tensor<2xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<2>], tf.signature.is_stateful} {
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_29270(%arg0: tensor<i1>, %arg1: tensor<2xi64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<2>], tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/sub:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
@ -3403,12 +3403,12 @@ func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_
%4 = "tf.Identity"(%3) {device = ""} : (tensor<i1>) -> tensor<i1>
return %4 : tensor<i1>
}
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_true_29650(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_true_29650(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
return %1 : tensor<i1>
}
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_false_29660(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_false_29660(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits_1/strided_slice:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
@ -3418,12 +3418,12 @@ func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_Asse
%5 = "tf.Identity"(%4) {device = ""} : (tensor<i1>) -> tensor<i1>
return %5 : tensor<i1>
}
func @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_true_30330(%arg0: tensor<i1>, %arg1: tensor<?xi64>, %arg2: tensor<?xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<?>, #tf.shape<?>]} {
func private @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_true_30330(%arg0: tensor<i1>, %arg1: tensor<?xi64>, %arg2: tensor<?xi64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<?>, #tf.shape<?>]} {
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
return %1 : tensor<i1>
}
func @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_false_30340(%arg0: tensor<i1>, %arg1: tensor<?xi64>, %arg2: tensor<?xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<?>, #tf.shape<?>], tf.signature.is_stateful} {
func private @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_false_30340(%arg0: tensor<i1>, %arg1: tensor<?xi64>, %arg2: tensor<?xi64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<?>, #tf.shape<?>], tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<"Inputs must have identical ragged splits"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%2 = "tf.Const"() {value = dense<"x (NGrams/SlidingWindow/RaggedGetItem/RaggedRange:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
@ -3433,12 +3433,12 @@ func @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_false_
%5 = "tf.Identity"(%4) {device = ""} : (tensor<i1>) -> tensor<i1>
return %5 : tensor<i1>
}
// CHECK: func @ngrams_ragged_rank_2(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<3xi64> {tf._user_specified_name = "args_0"}, %arg2: tensor<?xi64> {tf._user_specified_name = "args_1"}) -> (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>) attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = "", width = 2 : i64}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<3>, #tf.shape<?>], tf.signature.is_stateful} {
// CHECK: func private @ngrams_ragged_rank_2(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<3xi64> {tf._user_specified_name = "args_0"}, %arg2: tensor<?xi64> {tf._user_specified_name = "args_1"}) -> (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>) attributes {tf._implements = #tf.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = "", width = 2 : i64}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<3>, #tf.shape<?>], tf.signature.is_stateful} {
// CHECK: %0:3 = "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = "tftext:Ngrams", custom_option = opaque<"tfl", "0x776964746800737472696E675F736570617261746F720000006178697300726564756374696F6E5F74797065000B535452494E475F4A4F494E0004221E373E040104FF152C0204141404082401"> : tensor<77xi8>} : (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>) -> (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>)
// CHECK: return %0#0, %0#1, %0#2 : tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>
func @sgnn_projection(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<?xi64> {tf._user_specified_name = "row_splits"}) -> tensor<?x10xf64> attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:custom:SgnnProjection", {buckets = 2147483647 : i64, hash_seed = [1902835825, -1475704015, 473120514, 1254202069, 1558833093, 1756181982, 1906603252, -1034142694, 542842690, 535515822]}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<?>], tf.signature.is_stateful} {
func private @sgnn_projection(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<?xi64> {tf._user_specified_name = "row_splits"}) -> tensor<?x10xf64> attributes {tf._implements = #tf.func<@"tftext:custom:SgnnProjection", {buckets = 2147483647 : i64, hash_seed = [1902835825, -1475704015, 473120514, 1254202069, 1558833093, 1756181982, 1906603252, -1034142694, 542842690, 535515822]}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<?>], tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<[[1902835825], [-1475704015], [473120514], [1254202069], [1558833093], [1756181982], [1906603252], [-1034142694], [542842690], [535515822]]> : tensor<10x1xi64>} : () -> tensor<10x1xi64>
%1 = "tf.StringToHashBucketFast"(%arg0) {device = "", num_buckets = 2147483647 : i64} : (tensor<?x!tf.string>) -> tensor<?xi64>
%2 = "tf.Sgnn"(%1, %0) {device = ""} : (tensor<?xi64>, tensor<10x1xi64>) -> tensor<10x?xf64>
@ -3448,6 +3448,6 @@ func @sgnn_projection(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "va
}
// CHECK: func @sgnn_projection(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<?xi64> {tf._user_specified_name = "row_splits"}) -> tensor<?x10xf64> attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:custom:SgnnProjection", {buckets = 2147483647 : i64, hash_seed = [1902835825, -1475704015, 473120514, 1254202069, 1558833093, 1756181982, 1906603252, -1034142694, 542842690, 535515822]}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<?>], tf.signature.is_stateful} {
// CHECK: func private @sgnn_projection(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<?xi64> {tf._user_specified_name = "row_splits"}) -> tensor<?x10xf64> attributes {tf._implements = #tf.func<@"tftext:custom:SgnnProjection", {buckets = 2147483647 : i64, hash_seed = [1902835825, -1475704015, 473120514, 1254202069, 1558833093, 1756181982, 1906603252, -1034142694, 542842690, 535515822]}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<?>], tf.signature.is_stateful} {
// CHECK: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "tftext:custom:SgnnProjection", custom_option = opaque<"tfl", "0x686173685F736565640000000A00000071F86A71318B0AA8023F331CD59AC14AC5E7E95CDE35AD68F474A4711A3C5CC2421F5B20AE52EB1F6275636B6574730002094200030000000100000002000000FFFFFF7F44000000062E0A2601"> : tensor<93xi8>} : (tensor<?x!tf.string>, tensor<?xi64>) -> tensor<?x10xf64>
// CHECK: return %0 : tensor<?x10xf64>

View File

@ -0,0 +1,40 @@
// RUN: tf-opt -split-input-file -tfl-insert-call-once-op %s | FileCheck %s
// Tests that a new call_once op is added when there is a session initializer.
module attributes {tf_saved_model.semantics} {
"tf_saved_model.session_initializer"() {initializers = [@init_all_tables]} : () -> ()
func @init_all_tables()
attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer"]} {
%cst = constant dense<[1, 2, 3, 4]> : tensor<4xi64>
%cst_0 = constant dense<["a", "b", "c", "d"]> : tensor<4x!tf.string>
%0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = i64, shared_name = "hash_table_dba2ccaa-f1b1-46d6-b276-98008f69da71", use_node_name_sharing = false, value_dtype = !tf.string} : () -> tensor<!tf.resource>
"tf.LookupTableImportV2"(%0, %cst, %cst_0) {device = ""} : (tensor<!tf.resource>, tensor<4xi64>, tensor<4x!tf.string>) -> ()
return
// CHECK-LABEL: @init_all_tables
}
func @serving_default(%arg0: tensor<i64> {tf_saved_model.index_path = ["x"]}) -> (tensor<*x!tf.string> {tf_saved_model.index_path = ["r"]})
attributes {tf.entry_function = {control_outputs = "", inputs = "input:0", outputs = "hash_table_Lookup/LookupTableFindV2:0"}, tf_saved_model.exported_names = ["serving_default"]} {
%cst = constant dense<"f"> : tensor<!tf.string>
%0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = i64, shared_name = "hash_table_dba2ccaa-f1b1-46d6-b276-98008f69da71", use_node_name_sharing = false, value_dtype = !tf.string} : () -> tensor<!tf.resource>
%1 = "tf.LookupTableFindV2"(%0, %arg0, %cst) {device = ""} : (tensor<!tf.resource>, tensor<i64>, tensor<!tf.string>) -> tensor<*x!tf.string>
return %1 : tensor<*x!tf.string>
// CHECK-LABEL: @serving_default
// CHECK: "tfl.call_once"() {session_init_function = "init_all_tables"} : () -> ()
}
}
// -----
// Tests that no call_once op is added.
module attributes {tf_saved_model.semantics} {
func @no_call_once(%arg0: tensor<i64> {tf_saved_model.index_path = ["x"]}) -> (tensor<i64> {tf_saved_model.index_path = ["r"]})
attributes {tf.entry_function = {control_outputs = "", inputs = "input:0", outputs = "output:0"}, tf_saved_model.exported_names = ["serving_default"]} {
return %arg0 : tensor<i64>
// CHECK-LABEL: no_call_once
// CHECK-NOT: "tfl.call_once"
}
}

View File

@ -435,6 +435,16 @@ func @scatterNdHigherRankIndices(%arg0: tensor<4x2x2xi32>, %arg1: tensor<4x2x3xf
// CHECK: return %[[RES]]
}
func @scatter_nd_i64(%arg0: tensor<4x2x2xi64>, %arg1: tensor<4x2x3xf32>, %arg2: tensor<3xi64>) -> tensor<10x2x3xf32> {
%0 = "tf.ScatterNd"(%arg0, %arg1, %arg2) : (tensor<4x2x2xi64>, tensor<4x2x3xf32>, tensor<3xi64>) -> tensor<10x2x3xf32>
return %0 : tensor<10x2x3xf32>
// CHECK-LABEL:scatter_nd_i64
// CHECK: "tfl.cast"
// CHECK: "tfl.cast"
// CHECK: "tfl.scatter_nd"
}
func @gatherV2VectorIndices(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x3x5x20xf32> {
%0 = "tf.Const"() { value = dense<[1]> : tensor<1xi32> } : () -> tensor<1xi32>
%1 = "tf.GatherV2"(%arg0, %arg1, %0) : (tensor<1x2x20xf32>, tensor<3x5xi32>, tensor<1xi32>) -> tensor<1x3x5x20xf32>
@ -689,6 +699,16 @@ func @reverse_v2(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1xi32>) -> tensor<1x2
// CHECK: return
}
func @reverse_v2_i64(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1xi64>) -> tensor<1x2x3x4xf32> {
%0 = "tf.ReverseV2"(%arg0, %arg1) : (tensor<1x2x3x4xf32>, tensor<1xi64>) -> tensor<1x2x3x4xf32>
return %0 : tensor<1x2x3x4xf32>
// CHECK-LABEL:reverse_v2_i64
// CHECK: "tfl.cast"
// CHECK: "tfl.reverse_v2"
// CHECK: return
}
func @matrix_diag(%arg0: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
%0 = "tf.MatrixDiag"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16x16xf32>
return %0 : tensor<8x16x16xf32>
@ -763,13 +783,31 @@ func @matrix_diag_v3(%arg0: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
// CHECK: return [[VAL_6]] : tensor<8x16x16xf32>
}
func @matrix_set_diag(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
%0 = "tf.MatrixSetDiag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
return %0 : tensor<3x3xi32>
func @matrix_set_diag_v3(%arg0: tensor<3x3xi64>, %arg1: tensor<3xi32>) -> tensor<3x3xi64> {
%cst = constant dense<0> : tensor<i32>
%0 = "tf.MatrixSetDiagV3"(%arg0, %arg1, %cst) {align = "RIGHT_LEFT"} : (tensor<3x3xi64>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi64>
return %0 : tensor<3x3xi64>
// CHECK-LABEL: func @matrix_set_diag(
// CHECK: [[VAL_0:%.*]] = "tfl.matrix_set_diag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
// CHECK: return [[VAL_0]]
// CHECK-LABEL: func @matrix_set_diag_v3
// CHECK: "tfl.matrix_set_diag"(%arg0, %arg1) : (tensor<3x3xi64>, tensor<3xi32>) -> tensor<3x3xi64>
}
func @matrix_set_diag_v3_non_zero_k(%arg0: tensor<3x3xi64>, %arg1: tensor<3xi32>) -> tensor<3x3xi64> {
%cst = constant dense<1> : tensor<i32>
%0 = "tf.MatrixSetDiagV3"(%arg0, %arg1, %cst) : (tensor<3x3xi64>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi64>
return %0 : tensor<3x3xi64>
// CHECK-LABEL: @matrix_set_diag_v3_non_zero_k
// CHECK: tf.MatrixSetDiagV3
}
func @matrix_set_diag_v3_default_align(%arg0: tensor<3x3xi64>, %arg1: tensor<3xi32>) -> tensor<3x3xi64> {
%cst = constant dense<0> : tensor<i32>
%0 = "tf.MatrixSetDiagV3"(%arg0, %arg1, %cst) : (tensor<3x3xi64>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi64>
return %0 : tensor<3x3xi64>
// CHECK-LABEL: @matrix_set_diag_v3_default_align
// CHECK: "tfl.matrix_set_diag"(%arg0, %arg1) : (tensor<3x3xi64>, tensor<3xi32>) -> tensor<3x3xi64>
}
func @maximum(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> {
@ -996,6 +1034,15 @@ func @batch_to_space_nd_unsupported(%arg0: tensor<?x1x1x1x4xf32>, %arg1: tensor<
// CHECK: "tf.BatchToSpaceND"
}
func @batch_to_space_nd_i64(%arg0: tensor<4x2x2x3xf32>, %arg1: tensor<2xi64>, %arg2: tensor<2x2xi64>) -> tensor<?xf32> {
%0 = "tf.BatchToSpaceND"(%arg0, %arg1, %arg2) : (tensor<4x2x2x3xf32>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<?xf32>
return %0 : tensor<?xf32>
// CHECK-LABEL: batch_to_space_nd_i64
// CHECK: "tfl.cast"
// CHECK: "tfl.cast"
// CHECK: "tfl.batch_to_space_nd"
}
func @space_to_batch_nd(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<2x2xi32>) -> tensor<*xf32> {
%0 = "tf.SpaceToBatchND"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
@ -1003,6 +1050,15 @@ func @space_to_batch_nd(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<2xi32>, %arg2:
// CHECK: "tfl.space_to_batch_nd"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<*xf32>
}
func @space_to_batch_nd_i64(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<2xi64>, %arg2: tensor<2x2xi64>) -> tensor<*xf32> {
%0 = "tf.SpaceToBatchND"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<*xf32>
return %0 : tensor<*xf32>
// CHECK-LABEL: space_to_batch_nd_i64
// CHECK: "tfl.cast"
// CHECK: "tfl.cast"
// CHECK: "tfl.space_to_batch_nd"
}
func @split(%arg0: tensor<i32>, %arg1: tensor<1x4x3x3xf32>) -> tensor<1x4x3xf32> {
%0:3 = "tf.Split"(%arg0, %arg1) : (tensor<i32>, tensor<1x4x3x3xf32>) -> (tensor<1x4x3xf32>, tensor<1x4x3xf32>, tensor<1x4x3xf32>)
return %0#0 : tensor<1x4x3xf32>
@ -1122,6 +1178,13 @@ func @strided_slice_with_constant_attributes(%arg0: tensor<10x10x10xf32>, %arg1:
// CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 6 : i32, ellipsis_mask = 0 : i32, end_mask = 6 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32} : (tensor<10x10x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<10x10xf32>
}
func @strided_slice_with_string(%arg0: tensor<12x2x2x5x!tf.string>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf.string> {
%0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<12x2x2x5x!tf.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.string>
return %0 : tensor<1x2x2x5x!tf.string>
// CHECK-LABEL: strided_slice_with_string
// CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!tf.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.string>
}
func @slice1Tensor(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>) -> tensor<?x3x5xf32> {
%0 = "tf.Slice"(%arg0, %arg1, %arg2) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x3x5xf32>
return %0 : tensor<?x3x5xf32>
@ -1354,8 +1417,7 @@ func @conv2d_backprop_input(%arg0: tensor<4xi32>, %arg1: tensor<3x3x1x32xf32>, %
// CHECK-LABEL: conv2d_backprop_input
// CHECK: %[[CST:.*]] = constant dense<[2, 0, 1, 3]> : tensor<4xi32>
// CHECK: %[[CAST:.*]] = "tfl.cast"(%[[CST]]) : (tensor<4xi32>) -> tensor<4xi32>
// CHECK: %[[ARG0:.*]] = "tfl.transpose"(%arg1, %[[CAST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32>
// CHECK: %[[ARG0:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32>
// CHECK: %[[CST_0:.*]] = constant unit
// CHECK: %[[ARG1:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32>
// CHECK: %[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32>
@ -1790,10 +1852,25 @@ func @cumsum(%arg0: tensor<3x3xf32>, %arg1: tensor<i32>) -> tensor<3x3xf32> {
// CHECK: "tfl.cumsum"(%arg0, %arg1) {exclusive = false, reverse = false} : (tensor<3x3xf32>, tensor<i32>) -> tensor<3x3xf32>
}
func @cumsum_invalid(%arg0: tensor<3x3xf32>, %arg1: tensor<i64>) -> tensor<3x3xf32> {
func @cumsum_i64(%arg0: tensor<3x3xf32>, %arg1: tensor<i64>) -> tensor<3x3xf32> {
%0 = "tf.Cumsum"(%arg0, %arg1) {exclusive = false, reverse = false} : (tensor<3x3xf32>, tensor<i64>) -> tensor<3x3xf32>
return %0 : tensor<3x3xf32>
// CHECK-LABEL: cumsum_invalid
// CHECK-NOT: "tfl.cumsum"
// CHECK-LABEL: cumsum_i64
// CHECK: "tfl.cast"
// CHECK: "tfl.cumsum"
}
func @segmentsum(%arg0: tensor<3x3xf32>, %arg1: tensor<i32>) -> tensor<*xf32> {
%0 = "tf.SegmentSum"(%arg0, %arg1) : (tensor<3x3xf32>, tensor<i32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
// CHECK-LABEL: segmentsum
// CHECK: "tfl.segment_sum"(%arg0, %arg1) : (tensor<3x3xf32>, tensor<i32>) -> tensor<*xf32>
}
func @segmentsum_i64(%arg0: tensor<3x3xf32>, %arg1: tensor<i64>) -> tensor<*xf32> {
%0 = "tf.SegmentSum"(%arg0, %arg1) : (tensor<3x3xf32>, tensor<i64>) -> tensor<*xf32>
return %0 : tensor<*xf32>
// CHECK-LABEL: segmentsum_i64
// CHECK: "tfl.cast"
// CHECK: "tfl.segment_sum"
}

View File

@ -1,6 +1,6 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s --dump-input=always
func @main(tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> {
func @main(tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
@ -129,7 +129,7 @@ func @main(tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 18,
// CHECK-NEXT: name: "arg17",
// CHECK-NEXT: quantization: {
@ -282,9 +282,36 @@ func @main(tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t
// CHECK-NEXT: }
// CHECK-EMPTY:
^bb0(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4xf32>, %arg10: tensor<4xf32>, %arg11: tensor<4xf32>, %arg12: tensor<4xf32>, %arg13: tensor<4xf32>, %arg14: tensor<4xf32>, %arg15: tensor<4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<4x4xf32>, %arg18: tensor<4x4xf32>, %arg19: tensor<4x4xf32>, %arg20: tensor<4x4xf32>, %arg21: tensor<4x4xf32>):
^bb0(%arg0: tensor<4x4xf32>,
%arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>,
%arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>,
%arg9: tensor<4xf32>, %arg10: tensor<4xf32>, %arg11: tensor<4xf32>,
%arg12: tensor<4xf32>, %arg13: tensor<4xf32>, %arg14: tensor<4xf32>, %arg15: tensor<4xf32>,
%arg16: tensor<4x4xf32>, %arg17: tensor<4xf32>,
%arg18: tensor<4x4xf32>, %arg19: tensor<4x4xf32>, %arg20: tensor<4x4xf32>, %arg21: tensor<4x4xf32>):
%0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> loc("Const")
%1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> loc("Const")
%2 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %0, %1, %arg18, %arg19, %arg20, %arg21) {effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "NONE", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
%2 = "tfl.unidirectional_sequence_lstm"(%arg0,
%arg1, %arg2, %arg3, %arg4,
%arg5, %arg6, %arg7, %arg8,
%arg9, %arg10, %arg11,
%arg12, %arg13, %arg14, %arg15,
%arg16, %arg17,
%0, %1,
%arg18, %arg19, %arg20, %arg21) {
effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>,
fused_activation_function = "NONE",
input_to_cell_intermediate = tensor<0xf32>,
input_to_forget_intermediate = tensor<0xf32>,
input_to_input_intermediate = tensor<0xf32>,
input_to_output_intermediate = tensor<0xf32>, time_major = true}
: (tensor<4x4xf32>,
tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>,
tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>,
tensor<4xf32>, tensor<4xf32>, tensor<4xf32>,
tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>,
tensor<4x4xf32>, tensor<4xf32>,
tensor<4x4xf32>, tensor<4x4xf32>,
tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
return %2 : tensor<4x4xf32>
}

View File

@ -663,25 +663,25 @@ func @testUnidirectionalSequenceLstmWithoutProjection(%arg0: tensor<? x ? x f32>
// -----
// CHECK-LABEL: testUnidirectionalSequenceLstm
func @testUnidirectionalSequenceLstm(%arg0: tensor<? x ? x f32>, %arg1: tensor<? x ? x f32>, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x ? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
func @testUnidirectionalSequenceLstm(%arg0: tensor<? x ? x f32>, %arg1: tensor<? x ? x f32>, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// CHECK-LABEL: testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr
func @testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x ? x f32>, %arg1: none, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x ? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
func @testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x ? x f32>, %arg1: none, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// CHECK-LABEL: testUnidirectionalSequenceLstmWithIntermediates
func @testUnidirectionalSequenceLstmWithIntermediates(%arg0: tensor<? x ? x f32>, %arg1: none, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x ? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
func @testUnidirectionalSequenceLstmWithIntermediates(%arg0: tensor<? x ? x f32>, %arg1: none, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -1458,6 +1458,12 @@ func @testStridedSliceTFType(%arg0: tensor<12x2x2x5xui8>, %arg1: tensor<1xi32>,
return %0 : tensor<1x2x2x5x!tf.quint8>
}
// CHECK-LABEL: testStridedSliceWithString
func @testStridedSliceWithString(%arg0: tensor<12x2x2x5x!tf.string>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf.string> {
%0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!tf.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.string>
return %0 : tensor<1x2x2x5x!tf.string>
}
// -----
func @testStridedSliceWithInvalidOutputType(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5xi32> {


@ -407,16 +407,16 @@ func @fuseMulIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112
}
// CHECK-LABEL: @notFuseMulIntoDepthwiseConv2d
func @notFuseMulIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32> {
func @notFuseMulIntoDepthwiseConv2d(%arg0: tensor<1x4x4x2xf32>) -> tensor<1x4x4x2xf32> {
%cst0 = constant dense<[[[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0], [17.0, 18.0]]]]> : tensor<1x3x3x2xf32>
%cst1 = constant dense<2.0> : tensor<2xf32>
%cst2 = constant dense<3.0> : tensor<112x2xf32>
%cst2 = constant dense<[[3.1, 3.2], [3.1, 3.2], [3.1, 3.2], [3.1, 3.2]]> : tensor<4x2xf32>
%0 = "tfl.depthwise_conv_2d"(%arg0, %cst0, %cst1) {depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x112x2xf32>, tensor<1x3x3x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
%0 = "tfl.depthwise_conv_2d"(%arg0, %cst0, %cst1) {depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x4x4x2xf32>, tensor<1x3x3x2xf32>, tensor<2xf32>) -> tensor<1x4x4x2xf32>
// We cannot fuse this tfl.mul into the preceding depthwise conv op because %cst2 is not broadcast-compatible with %cst0.
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x112x112x2xf32>, tensor<112x2xf32>) -> tensor<1x112x112x2xf32>
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x4x4x2xf32>, tensor<4x2xf32>) -> tensor<1x4x4x2xf32>
return %1 : tensor<1x112x112x2xf32>
return %1 : tensor<1x4x4x2xf32>
// CHECK: %0 = "tfl.depthwise_conv_2d"(%arg0, %cst, %cst_0)
// CHECK: %1 = "tfl.mul"(%0, %cst_1)
@ -484,17 +484,17 @@ func @FuseFullyConnectedAddWithScalarRhs(%arg0: tensor<40x37xf32>, %arg1: tensor
}
// CHECK-LABEL: @FuseFullyConnectedAddWithUnfusableRhs
func @FuseFullyConnectedAddWithUnfusableRhs(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
func @FuseFullyConnectedAddWithUnfusableRhs(%arg0: tensor<4x37xf32>, %arg1: tensor<4x37xf32>) -> tensor<4x4xf32> {
%cst = constant unit
%cst2 = constant dense<2.0> : tensor<40x40xf32>
%cst2 = constant dense<[[2.0, 2.1, 2.2, 2.3], [2.0, 2.1, 2.2, 2.3], [2.0, 2.1, 2.2, 2.3], [2.0, 2.1, 2.2, 2.3]]> : tensor<4x4xf32>
%0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> (tensor<40x40xf32>)
%1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<40x40xf32>) -> tensor<40x40xf32>
%0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x37xf32>, tensor<4x37xf32>, none) -> (tensor<4x4xf32>)
%1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
return %1 : tensor<40x40xf32>
return %1 : tensor<4x4xf32>
// CHECK: %[[unit:.*]] = constant unit
// CHECK: %[[filter:.*]] = constant dense<2.000000e+00> : tensor<40x40xf32>
// CHECK: %[[filter:.*]] = constant dense<{{.*}}> : tensor<4x4xf32>
// CHECK: %[[fc_result:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[unit]])
// CHECK: %[[add_result:.*]] = tfl.add %[[fc_result]], %[[filter]]
// CHECK: return %[[add_result]]
@ -578,6 +578,32 @@ func @NotReorderReshapeAddIfNotTailingDimAfter(%arg0: tensor<1x30x1x96xf32>) ->
// CHECK: return %[[rs2]]
}
// CHECK-LABEL: @NotReorderReshapeAddIf5DInputs
func @NotReorderReshapeAddIf5DInputs(%arg0: tensor<1x1x1x1x1xf32>) -> tensor<1x1x1x1x2xf32> {
%cst = constant dense<2.0> : tensor<1x1x1x1x2xf32>
%shape = constant dense<[1, 1, 1, 1, 2]> : tensor<5xi32>
%1 = "tfl.reshape"(%arg0, %shape) : (tensor<1x1x1x1x1xf32>, tensor<5xi32>) -> tensor<1x1x1x1x2xf32>
%2 = "tfl.add"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<1x1x1x1x2xf32>, tensor<1x1x1x1x2xf32>) -> tensor<1x1x1x1x2xf32>
return %2 : tensor<1x1x1x1x2xf32>
// CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0
// CHECK: %[[rs2:.*]] = tfl.add %[[rs1]]
// CHECK: return %[[rs2]]
}
// CHECK-LABEL: @NotReorderReshapeFloorDivIf5DInputs
func @NotReorderReshapeFloorDivIf5DInputs(%arg0: tensor<1x1x1x1x1xf32>) -> tensor<1x1x1x1x2xf32> {
%cst = constant dense<2.0> : tensor<1x1x1x1x2xf32>
%shape = constant dense<[1, 1, 1, 1, 2]> : tensor<5xi32>
%1 = "tfl.reshape"(%arg0, %shape) : (tensor<1x1x1x1x1xf32>, tensor<5xi32>) -> tensor<1x1x1x1x2xf32>
%2 = "tfl.floor_div"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<1x1x1x1x2xf32>, tensor<1x1x1x1x2xf32>) -> tensor<1x1x1x1x2xf32>
return %2 : tensor<1x1x1x1x2xf32>
// CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0
// CHECK: %[[rs2:.*]] = tfl.floor_div %[[rs1]]
// CHECK: return %[[rs2]]
}
// CHECK-LABEL: @NotReorderReshapeAddIfNotTailingDim
func @NotReorderReshapeAddIfNotTailingDim(%arg0: tensor<40x40x1xf32>) -> tensor<40x40xf32> {
%cst = constant dense<2.0> : tensor<1x40xf32>
@ -851,17 +877,17 @@ func @fuseDivIntoConv2d_Scalar(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x1
}
// CHECK-LABEL: @fuseMulIntoConv2d_Scalar
func @fuseMulIntoConv2d_Scalar(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32> {
func @fuseMulIntoConv2d_Scalar(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x1xf32> {
%cst0 = constant dense<[[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]]> : tensor<1x2x2x2xf32>
%cst1 = constant dense<1.0> : tensor<2xf32>
%cst1 = constant dense<1.0> : tensor<1xf32>
%cst2 = constant dense<2.0> : tensor<f32>
%0 = "tfl.conv_2d"(%arg0, %cst0, %cst1) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x2xf32>, tensor<1x2x2x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<1x112x112x2xf32>, tensor<f32>) -> tensor<1x112x112x2xf32>
%0 = "tfl.conv_2d"(%arg0, %cst0, %cst1) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x2xf32>, tensor<1x2x2x2xf32>, tensor<1xf32>) -> tensor<1x112x112x1xf32>
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<1x112x112x1xf32>, tensor<f32>) -> tensor<1x112x112x1xf32>
return %1 : tensor<1x112x112x2xf32>
return %1 : tensor<1x112x112x1xf32>
// CHECK: %[[CST1:.*]] = constant dense<{{\[\[\[\[}}2.000000e+00, 4.000000e+00], [6.000000e+00, 8.000000e+00]], {{\[\[}}1.000000e+01, 1.200000e+01], [1.400000e+01, 1.600000e+01]]]]> : tensor<1x2x2x2xf32>
// CHECK: %[[CST2:.*]] = constant dense<2.000000e+00> : tensor<2xf32>
// CHECK: %[[RES:[0-9].*]] = "tfl.conv_2d"(%arg0, %[[CST1]], %[[CST2]]) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x2xf32>, tensor<1x2x2x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
// CHECK: %[[CST2:.*]] = constant dense<2.000000e+00> : tensor<1xf32>
// CHECK: %[[RES:[0-9].*]] = "tfl.conv_2d"(%arg0, %[[CST1]], %[[CST2]]) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x2xf32>, tensor<1x2x2x2xf32>, tensor<1xf32>) -> tensor<1x112x112x1xf32>
// CHECK: return %[[RES]]
}
@ -896,6 +922,36 @@ func @fuseTileWithBinaryOp1(%arg0: tensor<1x1xf32>, %arg1: tensor<1x128xf32>) ->
// CHECK: return %[[RES]]
}
// CHECK-LABEL: notFuseTileWithBinaryOpOn5DInputs
func @notFuseTileWithBinaryOpOn5DInputs(%arg0: tensor<1x1xf32>) -> tensor<1x1x1x1x2xf32> {
%cst = constant dense<[1, 1, 1, 1, 2]> : tensor<5xi32>
%cst1 = constant dense<3.0> : tensor<1x1x1x1x2xf32>
%0 = "tfl.sqrt"(%arg0) : (tensor<1x1xf32>) -> tensor<1x1xf32>
%1 = "tfl.tile"(%0, %cst) : (tensor<1x1xf32>, tensor<5xi32>) -> tensor<1x1x1x1x2xf32>
%2 = "tfl.add"(%cst1, %1) {fused_activation_function = "NONE"} : (tensor<1x1x1x1x2xf32>, tensor<1x1x1x1x2xf32>) -> tensor<1x1x1x1x2xf32>
return %2 : tensor<1x1x1x1x2xf32>
// CHECK: "tfl.sqrt"
// CHECK: "tfl.tile"
// CHECK: tfl.add
}
// CHECK-LABEL: notFuseTileWithBinaryOp1On5DInputs
func @notFuseTileWithBinaryOp1On5DInputs(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1x1x1x128xf32>) -> tensor<1x1x1x1x128xf32> {
%cst_0 = constant dense<1.0> : tensor<f32>
%cst_1 = constant dense<[1, 1, 1, 1, 128]> : tensor<5xi32>
%0 = "tfl.add"(%arg0, %cst_0) {fused_activation_function = "NONE"} : (tensor<1x1xf32>, tensor<f32>) -> tensor<1x1xf32>
%1 = "tfl.sqrt"(%0) : (tensor<1x1xf32>) -> tensor<1x1xf32>
%2 = "tfl.tile"(%1, %cst_1) : (tensor<1x1xf32>, tensor<5xi32>) -> tensor<1x1x1x1x128xf32>
%3 = "tfl.div"(%2, %arg1) {fused_activation_function = "NONE"} : (tensor<1x1x1x1x128xf32>, tensor<1x1x1x1x128xf32>) -> tensor<1x1x1x1x128xf32>
return %3 : tensor<1x1x1x1x128xf32>
// CHECK: "tfl.add"
// CHECK: "tfl.sqrt"
// CHECK: "tfl.tile"
// CHECK: tfl.div
}
// CHECK-LABEL: InvalidFuseTileWithBinaryOp
func @InvalidFuseTileWithBinaryOp(%arg0: tensor<2x3xf32>) -> tensor<2x6xf32> {
%cst = constant dense<[[1,2]]> : tensor<1x2xi32>
@ -1155,6 +1211,18 @@ func @ReorderAddWithConstant(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[RESULT:.*]] = tfl.add %arg0, %[[CONST]] {fused_activation_function = "NONE"} : tensor<2x2xf32>
}
func @NotReorderAddWithConstantOn5D(%arg0: tensor<2x2x2x2x2xf32>) -> tensor<2x2x2x2x2xf32> {
%cst = constant dense<1.0> : tensor<2x2x2x2x2xf32>
%cst_1 = constant dense<2.0> : tensor<2x2x2x2x2xf32>
%0 = "tfl.add"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<2x2x2x2x2xf32>, tensor<2x2x2x2x2xf32>) -> tensor<2x2x2x2x2xf32>
%1 = "tfl.add"(%0, %cst_1) {fused_activation_function = "NONE"} : (tensor<2x2x2x2x2xf32>, tensor<2x2x2x2x2xf32>) -> tensor<2x2x2x2x2xf32>
return %1 : tensor<2x2x2x2x2xf32>
// CHECK-LABEL: NotReorderAddWithConstantOn5D
// CHECK: tfl.add
// CHECK: tfl.add
}
func @RemoveCast(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
%1 = "tfl.cast"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %1 : tensor<2x2xf32>
@ -1397,3 +1465,50 @@ func @fuseExpanded1DMulIntoConv2d(%arg0: tensor<1x8x8x207xf32>) -> tensor<1x8x8x
// CHECK: "tfl.conv_2d"(%arg0, %[[CST_0]], %[[CST_1]])
}
// CHECK-LABEL: @FuseFullyConnectedAddWithSplat2D
func @FuseFullyConnectedAddWithSplat2D(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
%cst = constant unit
%cst2 = constant dense<2.0> : tensor<40x40xf32>
%0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> (tensor<40x40xf32>)
%1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<40x40xf32>) -> tensor<40x40xf32>
return %1 : tensor<40x40xf32>
// CHECK: %[[BIAS:.*]] = constant dense<2.000000e+00> : tensor<40xf32>
// CHECK: %[[FC_RESULT:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[BIAS]])
// CHECK: return %[[FC_RESULT]]
}
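The splat case folds because adding a value that is identical across rows is the same as giving the fully connected op a 1-D bias; under our reading of the rewrite,
$y_{i,j} = \sum_k x_{i,k} w_{j,k} + 2 = \sum_k x_{i,k} w_{j,k} + b_j$ with $b_j = 2$,
which is why the CHECK lines expect a dense<2.000000e+00> : tensor<40xf32> bias instead of a trailing tfl.add.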
// CHECK-LABEL: @fuseMulIntoConv2d_Splat2D
func @fuseMulIntoConv2d_Splat2D(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32> {
%cst0 = constant dense<[[[[1.0, 2.0]]], [[[3.0, 4.0]]]]> : tensor<2x1x1x2xf32>
%cst1 = constant dense<1.0> : tensor<2xf32>
%cst2 = constant dense<2.0> : tensor<1x112x112x2xf32>
%0 = "tfl.conv_2d"(%arg0, %cst0, %cst1) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x2xf32>, tensor<2x1x1x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<1x112x112x2xf32>, tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32>
return %1 : tensor<1x112x112x2xf32>
// CHECK: %[[CST1:.*]] = constant dense<{{\[\[\[\[}}2.000000e+00, 4.000000e+00]]], {{\[\[\[}}6.000000e+00, 8.000000e+00]]]]> : tensor<2x1x1x2xf32>
// CHECK: %[[CST2:.*]] = constant dense<2.000000e+00> : tensor<2xf32>
// CHECK: %[[RES:[0-9].*]] = "tfl.conv_2d"(%arg0, %[[CST1]], %[[CST2]]) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x2xf32>, tensor<2x1x1x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
// CHECK: return %[[RES]]
}
// CHECK-LABEL: @AvoidFuseFullyConnectedAddWithSplat2D
func @AvoidFuseFullyConnectedAddWithSplat2D(%arg0: tensor<1x1x1x1x1xf32>, %arg1: tensor<1x1xf32>) -> tensor<1x1x1x1x1xf32> {
%cst = constant unit
%cst2 = constant dense<2.0> : tensor<1x1x1x1x1xf32>
%0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x1x1x1x1xf32>, tensor<1x1xf32>, none) -> tensor<1x1x1x1x1xf32>
%1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<1x1x1x1x1xf32>, tensor<1x1x1x1x1xf32>) -> tensor<1x1x1x1x1xf32>
return %1 : tensor<1x1x1x1x1xf32>
// CHECK: %[[CST1:.*]] = constant unit
// CHECK: %[[CST2:.*]] = constant dense<2.000000e+00> : tensor<1x1x1x1x1xf32>
// CHECK: %[[FC_RESULT:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[CST1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x1x1x1x1xf32>, tensor<1x1xf32>, none) -> tensor<1x1x1x1x1xf32>
// CHECK: %[[ADD:.*]] = tfl.add %[[FC_RESULT]], %[[CST2]] {fused_activation_function = "NONE"} : tensor<1x1x1x1x1xf32>
// CHECK: return %[[ADD]] : tensor<1x1x1x1x1xf32>
}


@ -77,3 +77,32 @@ func @HandleReturnedDequantizeWithAnotherUse(%arg0: tensor<128x16xf32>) -> (tens
// CHECK-NEXT: return %[[softmax]], %[[argmax]] : tensor<128x16xf32>, tensor<128xi32>
return %2, %3 : tensor<128x16xf32>, tensor<128xi32>
}
// CHECK-LABEL: PruneUnusedLstm
func @PruneUnusedLstm(%arg0: tensor<1x28x28xf32>) -> (tensor<1x28x28xf32>) {
%input = "tfl.quantize"(%arg0) {qtype = tensor<1x28x28x!quant.uniform<i8:f32, 0.003:-128>>} : (tensor<1x28x28xf32>) -> tensor<1x28x28x!quant.uniform<i8:f32, 0.003:-128>>
%cst_1 = "tfl.pseudo_qconst"() {qtype = tensor<1x20x!quant.uniform<i8:f32, 0.006:-34>>, value = dense<1> : tensor<1x20xi8>} : () -> tensor<1x20x!quant.uniform<i8:f32, 0.006:-34>>
%cst_2 = constant unit
%cst_3 = "tfl.pseudo_qconst"() {qtype = tensor<20x20x!quant.uniform<i8:f32, 0.006:-34>>, value = dense<1> : tensor<20x20xi8>} : () -> tensor<20x20x!quant.uniform<i8:f32, 0.006:-34>>
%cst_7 = "tfl.pseudo_qconst"() {qtype = tensor<20x!quant.uniform<i8:f32, 0.006:-34>>, value = dense<1> : tensor<20xi8>} : () -> tensor<20x!quant.uniform<i8:f32, 0.006:-34>>
%cst_11 = "tfl.pseudo_qconst"() {qtype = tensor<20x28x!quant.uniform<i8:f32, 0.006:-34>>, value = dense<1> : tensor<20x28xi8>} : () -> tensor<20x28x!quant.uniform<i8:f32, 0.006:-34>>
%cell_input = "tfl.pseudo_qconst"() {qtype = tensor<1x20x!quant.uniform<i16:f32, 0.006:-34>>, value = dense<1> : tensor<1x20xi6>} : () -> tensor<1x20x!quant.uniform<i16:f32, 0.006:-34>>
%0 = "tfl.unidirectional_sequence_lstm"(%input,
%cst_11, %cst_11, %cst_11, %cst_11,
%cst_3, %cst_3, %cst_3, %cst_3,
%cst_2, %cst_2, %cst_2,
%cst_7, %cst_7, %cst_7, %cst_7,
%cst_2, %cst_2,
%cst_1, %cell_input,
%cst_2, %cst_2, %cst_2, %cst_2) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false}
: ( tensor<1x28x28x!quant.uniform<i8:f32, 0.003:-128>>,
tensor<20x28x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x28x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x28x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x28x!quant.uniform<i8:f32, 0.006:-34>>,
tensor<20x20x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x20x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x20x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x20x!quant.uniform<i8:f32, 0.006:-34>>,
none, none, none,
tensor<20x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x!quant.uniform<i8:f32, 0.006:-34>>,
none, none,
tensor<1x20x!quant.uniform<i8:f32, 0.006:-34>>, tensor<1x20x!quant.uniform<i16:f32, 0.006:-34>>,
none, none, none, none) -> tensor<1x28x20x!quant.uniform<i8:f32, 0.006:-34>>
return %arg0 : tensor<1x28x28xf32>
// CHECK-NEXT: return %arg0
}


@ -500,21 +500,21 @@ func @nms_padded(%arg0: tensor<100x4xf32>, %arg1: tensor<100xf32>, %arg2: tensor
module {
// expected-error @+1 {{Invalid number of results from non_max_suppression_padded_v2}}
func @nms_padded_invalid_num_results(%arg0: tensor<100x4xf32>, %arg1: tensor<100xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<i1>, %arg6: tensor<i1>, %arg7: tensor<i1>, %arg8: tensor<i32>) -> () attributes {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"}
func private @nms_padded_invalid_num_results(%arg0: tensor<100x4xf32>, %arg1: tensor<100xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<i1>, %arg6: tensor<i1>, %arg7: tensor<i1>, %arg8: tensor<i32>) -> () attributes {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"}
// expected-error @+1 {{Invalid number of arguments to non_max_suppression_padded_v2}}
func @nms_padded_invalid_num_args(%arg0: tensor<100x4xf32>, %arg1: tensor<100xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>) -> (tensor<1x10xi32>, tensor<i32>) attributes {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"}
func private @nms_padded_invalid_num_args(%arg0: tensor<100x4xf32>, %arg1: tensor<100xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>) -> (tensor<1x10xi32>, tensor<i32>) attributes {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"}
// expected-error @+1 {{TFLite does not support batched input for non_max_suppression_padded}}
func @nms_padded_with_batches(%arg0: tensor<2x100x4xf32>, %arg1: tensor<2x100xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<i1>, %arg6: tensor<i1>, %arg7: tensor<i1>, %arg8: tensor<i32>) -> (tensor<2x10xi32>, tensor<i32>) attributes {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"}
func private @nms_padded_with_batches(%arg0: tensor<2x100x4xf32>, %arg1: tensor<2x100xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<i1>, %arg6: tensor<i1>, %arg7: tensor<i1>, %arg8: tensor<i32>) -> (tensor<2x10xi32>, tensor<i32>) attributes {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"}
}
// -----
module {
// CHECK-LABEL: func @some_func
// CHECK-LABEL: func private @some_func
// CHECK-LABEL: func @func_with_call
func @some_func(%arg0: tensor<100xf32>) -> tensor<100xf32> attributes {tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c"}
func private @some_func(%arg0: tensor<100xf32>) -> tensor<100xf32> attributes {tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c"}
func @func_with_call(%arg0: tensor<100xf32>) -> tensor<100xf32> {
%0 = call @some_func(%arg0) : (tensor<100xf32>) -> tensor<100xf32>
return %0 : tensor<100xf32>
@ -545,13 +545,13 @@ func @tflite_custom_nms(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>,
module {
// expected-error @+1 {{Invalid number of results from TFLite_Detection_PostProcess}}
func @tflite_custom_nms_invalid_results(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>, %arg2: tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) attributes {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, max_classes_per_detection = 1 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"}
func private @tflite_custom_nms_invalid_results(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>, %arg2: tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) attributes {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, max_classes_per_detection = 1 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"}
// expected-error @+1 {{Invalid number of arguments to TFLite_Detection_PostProcess}}
func @tflite_custom_nms_invalid_args(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) attributes {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, max_classes_per_detection = 1 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"}
func private @tflite_custom_nms_invalid_args(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) attributes {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, max_classes_per_detection = 1 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"}
// expected-error @+1 {{max_classes_per_detection attribute is not set or not an integer}}
func @tflite_custom_nms_missing_func_args(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>, %arg2: tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) attributes {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"} {
func private @tflite_custom_nms_missing_func_args(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>, %arg2: tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) attributes {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"} {
%0 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
%1 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
%2 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>


@ -166,3 +166,37 @@ func @QuantizeTransposeConv(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<4xi32>)
// PerTensor: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) : (tensor<1x32x42x128x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>) -> tensor<1x32x42x128xf32>
// PerTensor: "tfl.transpose_conv"(%arg1, %arg0, %[[DEQUANTIZE]]
}
// CHECK-LABEL: QuantizeLstmCellInput
func @QuantizeLstmCellInput(%arg0: tensor<1x28x28xf32>) -> tensor<1x28x20xf32> {
%cst_1 = constant dense<1.0> : tensor<1x20xf32>
%cst_2 = constant unit
%cst_3 = constant dense<1.0> : tensor<20x20xf32>
%cst_7 = constant dense<1.0> : tensor<20xf32>
%cst_11 = constant dense<1.0> : tensor<20x28xf32>
%cell_input = constant dense<0.0> : tensor<1x20xf32>
%cell_stats = "quant.stats"(%cell_input) {layerStats = dense<[-2.73090601, 7.94872093]> : tensor<2xf32>} : (tensor<1x20xf32>) -> tensor<1x20xf32>
%0 = "tfl.unidirectional_sequence_lstm"(%arg0,
%cst_11, %cst_11, %cst_11, %cst_11,
%cst_3, %cst_3, %cst_3, %cst_3,
%cst_2, %cst_2, %cst_2,
%cst_7, %cst_7, %cst_7, %cst_7,
%cst_2, %cst_2,
%cst_1, %cell_stats,
%cst_2, %cst_2, %cst_2, %cst_2) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false}
: ( tensor<1x28x28xf32>,
tensor<20x28xf32>, tensor<20x28xf32>, tensor<20x28xf32>, tensor<20x28xf32>,
tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>,
none, none, none,
tensor<20xf32>, tensor<20xf32>, tensor<20xf32>, tensor<20xf32>,
none, none,
tensor<1x20xf32>, tensor<1x20xf32>,
none, none, none, none) -> tensor<1x28x20xf32>
return %0 : tensor<1x28x20xf32>
// CHECK: %[[none:.*]] = constant unit
// CHECK: %[[cell_input:.*]] = constant dense<0.000000e+00> : tensor<1x20xf32>
// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cell_input]]) {qtype = tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>} : (tensor<1x20xf32>) -> tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) : (tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>) -> tensor<1x20xf32>
// Checks that input 19 is correctly fed from the dequantize op.
// CHECK: %[[lstm:.*]] = "tfl.unidirectional_sequence_lstm"(%arg0, {{(%[^%,]+, )+}}%[[dq]], %[[none]], %[[none]], %[[none]], %[[none]])
}
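The 2.44140625E-4 scale in the CHECK lines is consistent with choosing a power-of-two range for the symmetric 16-bit cell state; a sketch of the arithmetic, assuming that is how the quantizer derives it:
$\max(|-2.73090601|,\ 7.94872093) = 7.94872093 \le 2^3$, so $\text{scale} = 2^3 / 2^{15} = 2^{-12} = 2.44140625 \times 10^{-4}$.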


@ -520,6 +520,17 @@ func @PadStridedSliceNewAxisMask2(%arg0: tensor<4x64x64x1xf32>) -> tensor<1x4x64
return %1 : tensor<1x4x64x64xf32>
}
// CHECK-LABEL: @AvoidPadStridedSliceNewAxisMaskOnUnknownShapes
func @AvoidPadStridedSliceNewAxisMaskOnUnknownShapes(%arg0: tensor<?x?xf32>) -> tensor<1x?x?x1xf32> {
%cst = constant dense<0> : tensor<4xi32>
%cst_0 = constant dense<1> : tensor<4xi32>
%0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 9 : i64, shrink_axis_mask = 0 : i64} : (tensor<?x?xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x?x?x1xf32>
return %0 : tensor<1x?x?x1xf32>
// CHECK-NOT: "tf.Reshape"
// CHECK: "tf.StridedSlice"
}
// CHECK-LABEL: @StridedSliceRewriteMasks
func @StridedSliceRewriteMasks(%arg0: tensor<8x4x16x2xf32>) -> tensor<8x4x16x1xf32> {
%cst = "tf.Const"() {device = "", value = dense<[1, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
@ -540,37 +551,6 @@ func @StridedSliceRewriteMasks(%arg0: tensor<8x4x16x2xf32>) -> tensor<8x4x16x1xf
return %0 : tensor<8x4x16x1xf32>
}
// CHECK-LABEL: @MatrixSetDiagV2Conversion
func @MatrixSetDiagV2Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
%cst = constant dense<0> : tensor<i32>
%0 = "tf.MatrixSetDiagV2"(%arg0, %arg1, %cst) : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
return %0 : tensor<3x3xi32>
// CHECK: %[[RES:.*]] = "tf.MatrixSetDiag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
// CHECK: return %[[RES]]
}
// CHECK-LABEL: @MatrixSetDiagV2NonZeroK
func @MatrixSetDiagV2NonZeroK(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
%cst = constant dense<1> : tensor<i32>
%0 = "tf.MatrixSetDiagV2"(%arg0, %arg1, %cst) : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
return %0 : tensor<3x3xi32>
// CHECK: %[[CST:.*]] = constant dense<1> : tensor<i32>
// CHECK: %[[RES:.*]] = "tf.MatrixSetDiagV2"(%arg0, %arg1, %[[CST]]) : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
// CHECK: return %[[RES]]
}
// CHECK-LABEL: @MatrixSetDiagV3Conversion
func @MatrixSetDiagV3Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
%cst = constant dense<0> : tensor<i32>
%0 = "tf.MatrixSetDiagV3"(%arg0, %arg1, %cst) : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
return %0 : tensor<3x3xi32>
// CHECK: %[[RES:.*]] = "tf.MatrixSetDiag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
// CHECK: return %[[RES]]
}
func @broadcast_to_f32_low_dim(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
return %0: tensor<3x3xf32>


@ -4,10 +4,10 @@ func @testSingleLstm(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>) -> tensor<4x4
// CHECK-LABEL: testSingleLstm
// CHECK: %[[CST_0:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
// CHECK: %[[CST_1:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
// CHECK: %[[LSTM:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %[[CST_0]], %[[CST_1]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
// CHECK: %[[LSTM:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg1, %[[CST_0]], %[[CST_1]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
%0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> loc("Const")
%1 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
%1 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg1, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
return %1 : tensor<4x4xf32>
}
@ -15,13 +15,13 @@ func @testMultipleLstms(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>) -> tensor<
// CHECK-LABEL: testMultipleLstms
// CHECK: %[[CST_0:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
// CHECK: %[[CST_1:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
// CHECK: %[[LSTM_1:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %[[CST_0]], %[[CST_1]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
// CHECK: %[[LSTM_1:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg1, %[[CST_0]], %[[CST_1]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
// CHECK: %[[CST_2:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
// CHECK: %[[CST_3:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
// CHECK: %[[LSTM_2:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%[[LSTM_1]], %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %[[CST_2]], %[[CST_3]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
// CHECK: %[[LSTM_2:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%[[LSTM_1]], %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg1, %[[CST_2]], %[[CST_3]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
%0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> loc("Const")
%1 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
%2 = "tfl.unidirectional_sequence_lstm"(%1, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
%1 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg1, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
%2 = "tfl.unidirectional_sequence_lstm"(%1, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg1, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
return %2 : tensor<4x4xf32>
}


@ -30,9 +30,9 @@ func @while() -> tensor<1xf32>
}) : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>, tensor<1xf32>) loc("WhileOp")
return %0#1 : tensor<1xf32>
}
// CHECK-LABEL: func @WhileOp_cond(
// CHECK-LABEL: func private @WhileOp_cond(
// CHECK: tfl.greater
// CHECK-LABEL: func @WhileOp_body(
// CHECK-LABEL: func private @WhileOp_body(
// CHECK: tfl.sub
// CHECK: tfl.add
@ -63,21 +63,21 @@ func @while2(%cst : tensor<i32>) -> tensor<1xf32> attributes {tf.entry_function
return %0#1 : tensor<1xf32>
}
func @WhileOp_cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>, %arg2: tensor<i32>) -> tensor<i1> attributes {sym_visibility = "private"} {
func private @WhileOp_cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>, %arg2: tensor<i32>) -> tensor<i1> {
%cst = constant dense<0> : tensor<i32>
%0 = "tfl.greater"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
}
func @WhileOp_body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>, %arg2: tensor<i32>) -> (tensor<*xi32>, tensor<*xf32>, tensor<i32>) attributes {sym_visibility = "private"} {
func private @WhileOp_body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>, %arg2: tensor<i32>) -> (tensor<*xi32>, tensor<*xf32>, tensor<i32>) {
%0 = "tfl.sub"(%arg0, %arg2) {fused_activation_function = "NONE"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%1 = tfl.add %arg1, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32>
return %0, %1, %arg2 : tensor<*xi32>, tensor<*xf32>, tensor<i32>
}
// CHECK-LABEL: func @WhileOp_cond(
// CHECK-LABEL: func private @WhileOp_cond(
// CHECK: tfl.greater
// CHECK-LABEL: func @WhileOp_body(
// CHECK-LABEL: func private @WhileOp_body(
// CHECK: tfl.sub
// CHECK: tfl.add
@ -152,14 +152,14 @@ func @rnn(%arg0: tensor<4x4x3xf32> {tf.device = "/device:CPU:0"}) -> tensor<4x?x
// CHECK: tfl.yield
// CHECK-SAME: (tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<4x4x3xf32>) -> ()
// CHECK-LABEL: func @tfl.while_cond(
// CHECK-SAME: [[VAL_35:%.*]]: tensor<i32>, [[VAL_36:%.*]]: tensor<i32>, [[VAL_37:%.*]]: tensor<*xf32>, [[VAL_38:%.*]]: tensor<4x2xf32>, [[VAL_39:%.*]]: tensor<4x2xf32>, [[VAL_40:%.*]]: tensor<*xf32>, [[VAL_41:%.*]]: tensor<4x4x3xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
// CHECK-LABEL: func private @tfl.while_cond(
// CHECK-SAME: [[VAL_35:%.*]]: tensor<i32>, [[VAL_36:%.*]]: tensor<i32>, [[VAL_37:%.*]]: tensor<*xf32>, [[VAL_38:%.*]]: tensor<4x2xf32>, [[VAL_39:%.*]]: tensor<4x2xf32>, [[VAL_40:%.*]]: tensor<*xf32>, [[VAL_41:%.*]]: tensor<4x4x3xf32>) -> tensor<i1> {
// CHECK: return
// CHECK-SAME: tensor<i1>
// CHECK: }
// CHECK-LABEL: func @tfl.while_body(
// CHECK-SAME: [[VAL_46:%.*]]: tensor<i32>, [[VAL_47:%.*]]: tensor<i32>, [[VAL_48:%.*]]: tensor<*xf32>, [[VAL_49:%.*]]: tensor<4x2xf32>, [[VAL_50:%.*]]: tensor<4x2xf32>, [[VAL_51:%.*]]: tensor<*xf32>, [[VAL_52:%.*]]: tensor<4x4x3xf32>) -> (tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<4x4x3xf32>) attributes {sym_visibility = "private"} {
// CHECK-LABEL: func private @tfl.while_body(
// CHECK-SAME: [[VAL_46:%.*]]: tensor<i32>, [[VAL_47:%.*]]: tensor<i32>, [[VAL_48:%.*]]: tensor<*xf32>, [[VAL_49:%.*]]: tensor<4x2xf32>, [[VAL_50:%.*]]: tensor<4x2xf32>, [[VAL_51:%.*]]: tensor<*xf32>, [[VAL_52:%.*]]: tensor<4x4x3xf32>) -> (tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<4x4x3xf32>) {
// CHECK: [[VAL_91:%.*]] = "tfl.cast"
// CHECK: return
// CHECK-SAME: [[VAL_91]], [[VAL_52]] : tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<4x4x3xf32>


@ -234,6 +234,11 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
// tf.variable to model this.
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreateSplitMergedOperandsPass());
// Add CallOnceOp when there is a session initializer function in the
// tf_saved_model dialect.
pass_manager->addPass(
mlir::TFL::CreateInsertCallOnceOpFromSessionInitializerPass());
}
}


@ -0,0 +1,78 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
namespace mlir {
namespace TFL {
namespace {
// This pass inserts a TFL::CallOnce op when tf_saved_model's session
// initializer is given.
class InsertCallOnceOpFromSessionInitializerPass
: public mlir::PassWrapper<InsertCallOnceOpFromSessionInitializerPass,
OperationPass<ModuleOp>> {
private:
void runOnOperation() override;
};
void InsertCallOnceOpFromSessionInitializerPass::runOnOperation() {
ModuleOp module = getOperation();
tf_saved_model::SessionInitializerOp session_init_op =
tf_saved_model::GetSessionInitializerOp(module);
if (!session_init_op) return;
SymbolTable symbol_table(module);
for (auto sym_ref : session_init_op.initializers()) {
FuncOp init_func_op = symbol_table.lookup<mlir::FuncOp>(
sym_ref.cast<FlatSymbolRefAttr>().getValue());
if (!init_func_op) {
module.emitError("no session initializer function found");
return signalPassFailure();
}
for (auto func : module.getOps<FuncOp>()) {
auto dict_attr =
func.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
if (!dict_attr) continue;
OpBuilder builder(func.getContext());
builder.setInsertionPointToStart(&func.getBlocks().front());
builder.create<TFL::CallOnceOp>(func.getLoc(), init_func_op.getName());
}
}
}
} // namespace
// Inserts a TFL::CallOnce op when tf_saved_model's session initializer is
// given.
std::unique_ptr<OperationPass<ModuleOp>>
CreateInsertCallOnceOpFromSessionInitializerPass() {
return std::make_unique<InsertCallOnceOpFromSessionInitializerPass>();
}
static PassRegistration<InsertCallOnceOpFromSessionInitializerPass> pass(
"tfl-insert-call-once-op",
"Insert CallOnce op when tf_saved_model's session initializer is given");
} // namespace TFL
} // namespace mlir
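A minimal before/after sketch of what the pass does, with hypothetical function names; the session_init_function attribute is our reading of how TFL::CallOnceOp records the initializer symbol:
// Before: a session initializer plus an entry function (note the
// tf.entry_function attribute, which is what the pass keys on).
"tf_saved_model.session_initializer"() {initializers = [@init]} : () -> ()
func @init() attributes {tf_saved_model.exported_names = ["init"]} {
  return
}
func @main(%arg0: tensor<1xf32>) -> tensor<1xf32>
    attributes {tf.entry_function = {inputs = "input", outputs = "output"}} {
  return %arg0 : tensor<1xf32>
}
// After the pass, the entry function body starts with:
//   "tfl.call_once"() {session_init_function = "init"} : () -> ()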


@ -54,7 +54,7 @@ def ExtractSingleElementAsInt32 : NativeCodeCall<
"$_builder.getI32IntegerAttr(ExtractSingleElementAsInteger($_self.cast<ElementsAttr>()).getInt())">;
// Converts a tensor with int64 element type to int32.
def CreateTFLCastToInt32Op : NativeCodeCall<
def CreateTFCastToInt32Op : NativeCodeCall<
"CreateCastToInt32($0, $_loc, $_builder)">;
// Checks whether the given operation has static shapes and all of its inputs share the same shape.
@ -193,8 +193,8 @@ def LegalizeRound : Pat<(TF_RoundOp $arg), (TFL_RoundOp $arg)>;
def LegalizeRsqrt : Pat<(TF_RsqrtOp $arg), (TFL_RsqrtOp $arg)>;
def LegalizeSqrt : Pat<(TF_SqrtOp $arg), (TFL_SqrtOp $arg)>;
def LegalizeSquare : Pat<(TF_SquareOp $arg), (TFL_SquareOp $arg)>;
def LegalizeSegmentSum : Pat<(TF_SegmentSumOp $data, I32Tensor:$segment_ids),
(TFL_SegmentSumOp $data, $segment_ids)>;
def LegalizeSegmentSum : Pat<(TF_SegmentSumOp $data, $segment_ids),
(TFL_SegmentSumOp $data, (CreateTFCastToInt32Op $segment_ids))>;
def LegalizeSelect : Pat<(TF_SelectOp $cond, $x, $y),
(TFL_SelectOp $cond, $x, $y)>;
def LegalizeSelectV2SameStaticShape : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y),
@ -221,7 +221,7 @@ def LegalizeTanh : Pat<(TF_TanhOp $arg), (TFL_TanhOp $arg)>;
def LegalizeTranspose : Pat<(TF_TransposeOp $arg, $perm),
(TFL_TransposeOp $arg,
(CreateTFLCastToInt32Op $perm))>;
(CreateTFCastToInt32Op $perm))>;
def LegalizeWhere : Pat<(TF_WhereOp $arg), (TFL_WhereOp $arg)>;
def LegalizeZerosLike : Pat<(TF_ZerosLikeOp $arg), (TFL_ZerosLikeOp $arg)>;
@ -309,8 +309,9 @@ def LegalizeRank : Pat<(TF_RankOp $input), (TFL_RankOp $input)>;
def LegalizeSquaredDifference : Pat<(TF_SquaredDifferenceOp $l, $r),
(TFL_SquaredDifferenceOp $l, $r)>;
def LegalizeReverseV2 : Pat<(TF_ReverseV2Op $arg0, $arg1),
(TFL_ReverseV2Op $arg0, $arg1)>;
def LegalizeReverseV2 : Pat<
(TF_ReverseV2Op $arg0, $axis),
(TFL_ReverseV2Op $arg0, (CreateTFCastToInt32Op $axis))>;
def LegalizeEqual : Pat<(TF_EqualOp $arg0, $arg1,
/*incompatible_shape_error=*/ConstBoolAttrTrue),
@ -349,11 +350,13 @@ def LegalizeCast : Pat<(TF_CastOp $arg0, BoolAttr:$arg1), (TFL_CastOp $arg0)>;
def LegalizeBatchToSpaceND : Pat<
(TF_BatchToSpaceNDOp $input, $block_shape, $crops),
(TFL_BatchToSpaceNdOp $input, $block_shape, $crops)>;
(TFL_BatchToSpaceNdOp $input, (CreateTFCastToInt32Op $block_shape),
(CreateTFCastToInt32Op $crops))>;
def LegalizeSpaceToBatchND : Pat<
(TF_SpaceToBatchNDOp $input, $block_shape, $paddings),
(TFL_SpaceToBatchNdOp $input, $block_shape, $paddings)>;
(TFL_SpaceToBatchNdOp $input, (CreateTFCastToInt32Op $block_shape),
(CreateTFCastToInt32Op $paddings))>;
def LegalizeSpaceToDepth : Pat<
(TF_SpaceToDepthOp $input, $block_size, IsDataFormatNHWC:$data_format),
@ -437,14 +440,34 @@ def LegalizeConv2DBackpropInput : Pat<
/*stride_h=*/ ExtractI32At<1>:$strides,
/*stride_w=*/ ExtractI32At<2>:$strides)>;
def IsRankZeroAttr
: CPred<"$_self.cast<DenseElementsAttr>().getType().getRank() == 0">;
def HasValueZero
: CPred<"$_self.cast<DenseElementsAttr>().getSplatValue()."
"cast<::mlir::IntegerAttr>().getInt() == 0">;
// TFLite only supports MatrixSetDiag ops with a scalar zero k attribute.
def IsSupportedByTFLiteMatrixSetDiag
: ElementsAttrBase<And<[ElementsAttr.predicate,
IsRankZeroAttr, HasValueZero]>,
"MatrixSetDiag attribute verification">;
// The align attribute doesn't matter when k is zero.
def LegalizeMatrixSetDiag : Pat<
(TF_MatrixSetDiagOp $input, $diagonal),
(TF_MatrixSetDiagV3Op $input, $diagonal,
(ConstantLikeMatcher IsSupportedByTFLiteMatrixSetDiag:$k), $align),
(TFL_MatrixSetDiagOp $input, $diagonal)>;
def LegalizeScatterNd : Pat<
(TF_ScatterNdOp I32Tensor:$indices, $updates, $shape),
(TFL_ScatterNdOp I32Tensor:$indices, $updates, $shape)>;
(TF_ScatterNdOp $indices, $updates, $shape),
(TFL_ScatterNdOp (CreateTFCastToInt32Op $indices), $updates,
(CreateTFCastToInt32Op $shape))>;
def LegalizeCumsum : Pat<
(TF_CumsumOp $input, $axis, $exclusive, $reverse),
(TFL_CumsumOp $input, $axis, $exclusive, $reverse)>;
(TFL_CumsumOp $input, (CreateTFCastToInt32Op $axis), $exclusive, $reverse)>;
def LegalizeReshape : Pat<
(TF_ReshapeOp $input, $shape),
(TFL_ReshapeOp $input, (CreateTFCastToInt32Op $shape))>;

View File

@ -123,7 +123,8 @@ Value CreateCastToInt32(Value val, Location loc, PatternRewriter& rewriter) {
auto shape = val.getType().dyn_cast<RankedTensorType>().getShape();
IntegerType new_ele_type = rewriter.getIntegerType(32);
ShapedType new_type = RankedTensorType::get(shape, new_ele_type);
return rewriter.create<TFL::CastOp>(loc, new_type, val);
return rewriter.createOrFold<TF::CastOp>(loc, new_type, val,
rewriter.getBoolAttr(false));
}
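
The cast exists because TFL ops take i32 shape and index tensors while TF frequently produces i64, and createOrFold lets the cast disappear entirely when the operand is a constant. A standalone sketch of the narrowing itself (plain C++, not the MLIR helper above):

#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Narrows i64 shape values to i32, as the TFL ops expect. Asserts that each
// value actually fits; in the compiler this is a TF::CastOp that folds away
// when the operand is a constant.
std::vector<int32_t> CastShapeToInt32(const std::vector<int64_t>& shape) {
  std::vector<int32_t> out;
  out.reserve(shape.size());
  for (int64_t d : shape) {
    assert(d <= std::numeric_limits<int32_t>::max() &&
           d >= std::numeric_limits<int32_t>::min());
    out.push_back(static_cast<int32_t>(d));
  }
  return out;
}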
#include "tensorflow/compiler/mlir/lite/transforms/generated_legalize_tf.inc"
@ -145,7 +146,6 @@ DECL_CONVERT_OP(MatMul);
DECL_CONVERT_OP(MatrixDiagV2);
DECL_CONVERT_OP(MatrixDiagV3);
DECL_CONVERT_OP(Pack);
DECL_CONVERT_OP(Reshape);
DECL_CONVERT_OP(Split);
DECL_CONVERT_OP(SplitV);
DECL_CONVERT_OP(StridedSlice);
@ -299,30 +299,6 @@ LogicalResult ConvertTFPackOp::matchAndRewrite(
return success();
}
LogicalResult ConvertTFReshapeOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_reshape_op = cast<TF::ReshapeOp>(op);
auto input = tf_reshape_op.tensor();
auto shape = tf_reshape_op.shape();
ShapedType shape_type = shape.getType().cast<ShapedType>();
// The tfl reshape's #2 operand needs to be an i32 tensor type, so we have to cast.
if (!shape_type.getElementType().isSignlessInteger(32)) {
auto new_shape = shape_type.getShape();
IntegerType new_ele_type = rewriter.getIntegerType(32);
ShapedType new_type = RankedTensorType::get(new_shape, new_ele_type);
// Uses TF::CastOp to be folded if the shape input is a constant.
shape = rewriter
.create<TF::CastOp>(op->getLoc(), new_type, shape,
rewriter.getBoolAttr(false))
.y();
}
rewriter.replaceOpWithNewOp<ReshapeOp>(op, tf_reshape_op.output().getType(),
input, shape);
return success();
}
LogicalResult ConvertTFSplitOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_split_op = cast<TF::SplitOp>(op);
@ -792,10 +768,9 @@ void addPatterns(MLIRContext* context, OwningRewritePatternList& patterns) {
populateWithGenerated(context, patterns);
patterns
.insert<ConvertTFConcatV2Op, ConvertTFMatMulOp, ConvertTFMatrixDiagV2Op,
ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFReshapeOp,
ConvertTFSplitOp, ConvertTFSplitVOp, ConvertTFStridedSliceOp,
ConvertTFUnpackOp, ConvertTFAssertOp, ConvertTFRandomUniformOp>(
context);
ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFSplitOp,
ConvertTFSplitVOp, ConvertTFStridedSliceOp, ConvertTFUnpackOp,
ConvertTFAssertOp, ConvertTFRandomUniformOp>(context);
// Ophint python converter converted tf node pattern.
patterns.insert<LegalizeUnidirectionalSequenceLstm,

View File

@ -62,7 +62,7 @@ void RunOnWhile(TF::WhileOp while_op) {
auto call = builder.create<CallOp>(while_op.getLoc(), func, new_operands);
builder.create<YieldOp>(while_op.getLoc(), call.getResults());
// Mark old function as private so that it can be DCE'd if not called.
func.setVisibility(SymbolTable::Visibility::Private);
func.setPrivate();
};
create_region_with_call(while_op.cond_function(), new_op.cond());
create_region_with_call(while_op.body_function(), new_op.body());

View File

@ -27,11 +27,14 @@ limitations under the License.
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
@ -286,6 +289,18 @@ static bool ShapeMatchesReduceWithKeepAxes(Value input,
return true;
}
static bool FloatValueEquals(const Attribute &attr, double value) {
auto fp_attr = attr.dyn_cast_or_null<DenseFPElementsAttr>();
if (!fp_attr) return false;
if (fp_attr.isSplat()) {
return fp_attr.getSplatValue<APFloat>().isExactlyValue(value);
}
return llvm::all_of(fp_attr.getFloatValues(), [value](const APFloat &f) {
return f.isExactlyValue(value);
});
}
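
The splat fast path matters because a splat DenseFPElementsAttr stores a single element; a standalone analogue over a flat buffer (illustrative only, not the MLIR API):

#include <algorithm>
#include <vector>

// Mirrors FloatValueEquals on a flat buffer: answer from one element when the
// tensor is a splat (all values identical), full scan otherwise.
bool AllElementsEqual(const std::vector<float>& values, float target,
                      bool is_splat) {
  if (values.empty()) return false;
  if (is_splat) return values.front() == target;
  return std::all_of(values.begin(), values.end(),
                     [target](float f) { return f == target; });
}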
#include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc"
// Fuse Add with proceeding FullyConnected.
@ -729,6 +744,144 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
}
};
// If the operand to a broadcastable op is a splat constant, try to replace it
// with a 0-d constant, e.g. before this optimization,
// %cst = constant dense<1.0> : tensor<16x16x4xf32>
// %0 = "tfl.conv_2d"...
// %1 = "tfl.add"(%0, %cst) : (tensor<16x16x4xf32>, tensor<16x16x4xf32>)
// After this optimization:
// %cst = constant dense<1.0> : tensor<f32>
// %0 = "tfl.conv_2d"...
// %1 = "tfl.add"(%0, %cst) : (tensor<16x16x4xf32>, tensor<f32>)
// This pattern can enable more fusing opportunities when the binary op
// follows conv ops.
template <typename BinaryOpType>
struct ScalarizeSplatConstantForBroadcastableOps
: public OpRewritePattern<BinaryOpType> {
using OpRewritePattern<BinaryOpType>::OpRewritePattern;
LogicalResult matchAndRewrite(BinaryOpType binary_op,
PatternRewriter &rewriter) const override {
DenseElementsAttr splat_elements_attr;
if (!IsScalarizableSplatConstant(binary_op.rhs(), &splat_elements_attr)) {
return failure();
}
constexpr int kSplatOperandIndex = 1;
auto result_type =
binary_op.getResult().getType().template cast<ShapedType>();
mlir::Value non_splat_operand =
binary_op.getOperand(1 - kSplatOperandIndex);
auto non_splat_operand_type =
non_splat_operand.getType().cast<ShapedType>();
// If the other operand's shape is not equal to the result shape, then we
// cannot scalarize the splat constant because the result shape relies on
// the splat constant op's shape for broadcasting.
if (!non_splat_operand_type.hasStaticShape() ||
non_splat_operand_type.getShape() != result_type.getShape() ||
non_splat_operand_type.getRank() > 4) {
return failure();
}
// If the non-splat operand is not produced by a fusable affine op, there is
// no need to apply this transformation.
if (!CanFuseAffineOp(non_splat_operand.getDefiningOp(), binary_op)) {
return failure();
}
// Creates a new scalar constant op using the splat value.
mlir::Value splat_operand = binary_op.getOperand(kSplatOperandIndex);
auto scalar_elements_attr = DenseElementsAttr::get(
RankedTensorType::get({},
splat_elements_attr.getType().getElementType()),
splat_elements_attr.getSplatValue());
auto scalar_constant_op = rewriter.create<ConstantOp>(
splat_operand.getLoc(), scalar_elements_attr.getType(),
scalar_elements_attr);
binary_op.setOperand(kSplatOperandIndex, scalar_constant_op);
return success();
}
private:
// Returns true if this value is a splat constant op which can be scalarized.
// Also returns the elements attr if this value is indeed a splat constant.
bool IsScalarizableSplatConstant(mlir::Value value,
DenseElementsAttr *elements_attr) const {
if (!matchPattern(value, m_Constant(elements_attr))) {
return false;
}
auto element_type = value.getType().cast<ShapedType>().getElementType();
// Ignore per-axis quantized constants because after converting to scalar,
// we would lose the per-axis quantization parameters.
if (element_type.isa<quant::UniformQuantizedPerAxisType>()) {
return false;
}
if (IsScalar(value)) {
return false;
}
return elements_attr->isSplat();
}
// Returns true if the value has a scalar shaped type.
bool IsScalar(mlir::Value value) const {
auto type = value.getType().dyn_cast<ShapedType>();
if (!type) {
return false;
}
if (!type.hasStaticShape()) {
return false;
}
return type.getNumElements() == 1;
}
// Returns true if we can fuse an affine op with the consuming binary op.
bool CanFuseAffineOp(Operation *affine_op, Operation *binary_op) const {
if (!isa_and_nonnull<TFL::Conv2DOp, TFL::DepthwiseConv2DOp,
TFL::FullyConnectedOp>(affine_op)) {
return false;
}
DenseElementsAttr value;
// Check that the bias is a constant if it is not none.
Value bias = affine_op->getOperand(2);
if (!bias.getType().isa<NoneType>() &&
!matchPattern(bias, m_Constant(&value))) {
return false;
}
// If the binary op is mul/div, also check that the filter is constant.
if (isa<TFL::MulOp, TFL::DivOp>(binary_op) &&
!matchPattern(affine_op->getOperand(1), m_Constant(&value))) {
return false;
}
// We can only fuse F32/BF16.
auto is_fusable_type = [](Type t) {
Type element_type = t;
if (auto shaped_type = t.dyn_cast<ShapedType>()) {
element_type = shaped_type.getElementType();
}
return element_type.isBF16() || element_type.isF32();
};
for (Type t : binary_op->getOperandTypes()) {
if (!is_fusable_type(t)) {
return false;
}
}
return true;
}
};
using ScalarizeSplatConstantForSub =
ScalarizeSplatConstantForBroadcastableOps<TFL::SubOp>;
using ScalarizeSplatConstantForAdd =
ScalarizeSplatConstantForBroadcastableOps<TFL::AddOp>;
using ScalarizeSplatConstantForMul =
ScalarizeSplatConstantForBroadcastableOps<TFL::MulOp>;
using ScalarizeSplatConstantForDiv =
ScalarizeSplatConstantForBroadcastableOps<TFL::DivOp>;
struct ConvertTrivialTransposeOpToReshapeOp
: public OpRewritePattern<TFL::TransposeOp> {
using OpRewritePattern<TFL::TransposeOp>::OpRewritePattern;
@ -818,6 +971,8 @@ void Optimize::runOnFunction() {
OwningRewritePatternList phase_2_patterns;
TFL::populateWithGenerated(ctx, phase_2_patterns);
phase_2_patterns.insert<
ScalarizeSplatConstantForAdd, ScalarizeSplatConstantForSub,
ScalarizeSplatConstantForMul, ScalarizeSplatConstantForDiv,
FuseFullyConnectedAndAdd, FuseFullyConnectedAndReluX<TFL::ReluOp, kRelu>,
FuseFullyConnectedAndReluX<TFL::Relu6Op, kRelu6>,
FuseFullyConnectedAndReluX<TFL::Relu1Op, kRelu1>,

View File

@ -376,13 +376,17 @@ multiclass FuseTileBroadcastIntoFollowingBinary<dag BinaryOp> {
(BinaryOp:$result (TFL_TileOp $input, (ConstantOp $tile)),
$operand, $act_func),
(BinaryOp $input, $operand, $act_func),
[(OperandsBroadcastToOutputType $input, $operand, $result)]>;
[(OperandsBroadcastToOutputType $input, $operand, $result),
(HasRankAtMost<4> $input),
(HasRankAtMost<4> $operand)]>;
def FuseTileBroadcastToBinaryOp2#BinaryOp : Pat<
(BinaryOp:$result $operand,
(TFL_TileOp $input, (ConstantOp $tile)), $act_func),
(BinaryOp $operand, $input, $act_func),
[(OperandsBroadcastToOutputType $operand, $input, $result)]>;
[(OperandsBroadcastToOutputType $operand, $input, $result),
(HasRankAtMost<4> $operand),
(HasRankAtMost<4> $input)]>;
}
// Multi-pattern consisting of matching stand-alone op or op followed by relu.
@ -427,8 +431,9 @@ foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in {
// `input`. In other words, the shape of the `Reshape` op is not
// changed after the transformation.
(IsTailOfShape $rhs, $input),
(HasRankAtMost<5> $input),
(HasRankAtMost<5> $rhs)]>;
(HasRankAtMost<4> $input),
(HasRankAtMost<4> $lhs),
(HasRankAtMost<4> $rhs)]>;
}
foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
@ -457,7 +462,10 @@ foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
// The result of the new "BinaryOp" will have the same shape as
// `input`. In other words, the shape of the `Reshape` op is not
// changed after the transformation.
(IsTailOfShape $rhs, $input)]>;
(IsTailOfShape $rhs, $input),
(HasRankAtMost<4> $input),
(HasRankAtMost<4> $lhs),
(HasRankAtMost<4> $rhs)]>;
}
// Reorder the element-wise value operations and the element move operations,
@ -495,9 +503,7 @@ def ConvertExpandDimsToReshape : Pat<
[(AnyStaticShapeTensor $expand_dims_op)]>;
class FloatValueEquals<string val> : Constraint<CPred<
"$0.isa<DenseFPElementsAttr>() && "
"llvm::all_of($0.cast<DenseElementsAttr>().getFloatValues(), "
"[](const APFloat& f) { return f.isExactlyValue(" # val # "); })">>;
"FloatValueEquals($0, " # val # ")">>;
// ReLU patterns
def MatchReluPattern : Pat<
@ -570,7 +576,10 @@ foreach ActFun = [TFL_AF_Relu, TFL_AF_Relu6, TFL_AF_Relu1, TFL_AF_None] in {
(TFL_AddOp $input,
(TFL_AddOp (ConstantOp $a), (ConstantOp $b), TFL_AF_None),
ActFun),
[(HasOneUse $first_output)]>;
[(HasOneUse $first_output),
(HasRankAtMost<4> $input),
(HasRankAtMost<4> $a),
(HasRankAtMost<4> $b)]>;
}
// We can eliminate Relu from Relu(SquaredDifference(x, y)),

View File

@ -94,6 +94,10 @@ std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass();
// Creates raise custom ops pass, which legalize custom ops to TFL::CustomOp
std::unique_ptr<OperationPass<FuncOp>> CreateRaiseCustomOpsPass();
// Inserts a TFL::CallOnce op when the tf_saved_model's session initializer is
// given.
std::unique_ptr<OperationPass<ModuleOp>>
CreateInsertCallOnceOpFromSessionInitializerPass();
} // namespace TFL
} // namespace mlir

View File

@ -139,6 +139,30 @@ struct RemoveVolatileOps : public OpRewritePattern<DequantizeOp> {
}
};
// Removes LSTMs that have dangling outputs.
// LSTMs are not removed automatically because they are stateful ops.
template <typename LstmOpTy>
struct PruneUnusedLstm : public OpRewritePattern<LstmOpTy> {
public:
explicit PruneUnusedLstm(MLIRContext* context)
: OpRewritePattern<LstmOpTy>(context) {}
LogicalResult matchAndRewrite(LstmOpTy lstm_op,
PatternRewriter& rewriter) const override {
Operation* op = lstm_op.getOperation();
if (op->isKnownTerminator()) {
return failure();
}
for (auto result : op->getOpResults()) {
if (!result.use_empty()) {
return failure();
}
}
rewriter.eraseOp(op);
return success();
}
};
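
The match condition reduces to "no result has a use"; a toy standalone sketch of that check (ToyOp is an illustrative stand-in, not an MLIR type):

#include <vector>

// Toy model of an op whose results carry use counts. A stateful op like an
// LSTM is only safe to erase when every result is unused.
struct ToyOp {
  std::vector<int> result_use_counts;
};

bool CanPruneStatefulOp(const ToyOp& op) {
  for (int uses : op.result_use_counts) {
    if (uses > 0) return false;  // Some result still feeds another op.
  }
  return true;  // All results dangle; the op can be erased.
}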
#include "tensorflow/compiler/mlir/lite/transforms/generated_post_quantize.inc"
void PostQuantizePass::runOnFunction() {
@ -147,6 +171,7 @@ void PostQuantizePass::runOnFunction() {
auto* ctx = func.getContext();
TFL::populateWithGenerated(ctx, patterns);
patterns.insert<quant::FoldTrivalRequantizeOp<QuantizeOp>>(ctx);
patterns.insert<PruneUnusedLstm<TFL::UnidirectionalSequenceLSTMOp>>(ctx);
applyPatternsAndFoldGreedily(func, std::move(patterns));
if (!emit_quant_adaptor_ops_) {

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
// This transformation pass applies quantization propagation on TFLite dialect.
#include <cmath>
#include <iterator>
#include <string>
@ -21,10 +22,13 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
@ -305,6 +309,52 @@ bool PrepareQuantizePass::ContainsQuantizeOps(FuncOp func) {
using PrepareQuantStats =
quant::ConvertStatsToQDQs<quant::QuantizeCastOp, quant::DequantizeCastOp>;
// Calculates the minimum power of two that is not less than the value.
double power_of_two_bound(double value) {
return std::pow(2, std::ceil(std::log2(value)));
}
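
For example, power_of_two_bound(5.3) = 8 and power_of_two_bound(0.4) = 0.5. A standalone check (plain C++, independent of the pass):

#include <cassert>
#include <cmath>

double PowerOfTwoBound(double value) {
  return std::pow(2, std::ceil(std::log2(value)));
}

int main() {
  assert(PowerOfTwoBound(5.3) == 8.0);  // next power of two above 5.3
  assert(PowerOfTwoBound(0.4) == 0.5);  // powers of two below 1 also count
  assert(PowerOfTwoBound(8.0) == 8.0);  // exact powers map to themselves
}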
// Quantize recurrent input of LSTM with 16 bits.
template <typename SourceOp, typename Q, typename DQ>
struct ConvertLstmStatsToQDQs : public OpRewritePattern<SourceOp> {
public:
explicit ConvertLstmStatsToQDQs(MLIRContext* context)
: OpRewritePattern<SourceOp>(context, /*benefit=*/2) {}
LogicalResult matchAndRewrite(SourceOp op,
PatternRewriter& rewriter) const override {
quant::StatisticsOp stats_op = llvm::dyn_cast_or_null<quant::StatisticsOp>(
op.input_cell_state().getDefiningOp());
// The recurrent input is used within an LSTM, and thus should have one use.
if (!stats_op || !stats_op.getResult().hasOneUse()) {
return failure();
}
auto stats = stats_op.layerStats().dyn_cast<DenseFPElementsAttr>();
if (!stats) {
return failure();
}
double max = std::max(
std::abs(FloatAttr::getValueAsDouble(stats.getValue<APFloat>({0}))),
std::abs(FloatAttr::getValueAsDouble(stats.getValue<APFloat>({1}))));
double bound = power_of_two_bound(max);
Type expressed = stats_op.getType().cast<ShapedType>().getElementType();
// The maximum value is adjusted to get a scale of power_of_two_bound(max)/32768.
quant::QuantizedType quant_type = quant::fakeQuantAttrsToType(
stats_op.getLoc(), 16, -bound, bound * 32767.0 / 32768.0,
/*narrow_range*/ false, expressed, /*is_signed*/ true);
rewriter.setInsertionPointAfter(stats_op);
Type result_type = quant_type.castFromExpressedType(stats_op.getType());
auto q = rewriter.create<Q>(stats_op.getLoc(), result_type, stats_op.arg());
rewriter.replaceOpWithNewOp<DQ>(stats_op, stats_op.getType(), q);
return success();
}
};
using PrepareLstmQuantStats =
ConvertLstmStatsToQDQs<TFL::UnidirectionalSequenceLSTMOp,
quant::QuantizeCastOp, quant::DequantizeCastOp>;
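
Worked numbers for the range arithmetic above, assuming layer statistics of [-5.3, 4.1]: max = 5.3, bound = power_of_two_bound(5.3) = 8, so the int16 range becomes [-8, 8 * 32767/32768] and the scale is 8/32768 = 2^-12. A standalone sketch reproducing it:

#include <cmath>
#include <cstdio>

// Reproduces the ConvertLstmStatsToQDQs range arithmetic for one example.
int main() {
  double lo = -5.3, hi = 4.1;  // assumed layer statistics
  double max = std::fmax(std::fabs(lo), std::fabs(hi));
  double bound = std::pow(2, std::ceil(std::log2(max)));  // -> 8
  double qmin = -bound;                                   // -> -8
  double qmax = bound * 32767.0 / 32768.0;                // -> ~7.99976
  double scale = bound / 32768.0;                         // -> 2^-12
  std::printf("range [%g, %g], scale %g\n", qmin, qmax, scale);
}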
void PrepareQuantizePass::runOnFunction() {
FuncOp func = getFunction();
MLIRContext* ctx = func.getContext();
@ -326,7 +376,14 @@ void PrepareQuantizePass::runOnFunction() {
OwningRewritePatternList patterns;
bool is_signed = quant_specs_.IsSignedInferenceType();
int bit_width = quant_specs_.GetQuantizationTypeWidth();
bool enforce_fixed_output_range = ContainsQuantizeOps(func);
bool quantization_aware_training_mode = ContainsQuantizeOps(func);
// Enforce fixed output range for post-training quantization and
// when the model has quantization emulation ops, unless it was disabled
// explicitly by the flag.
bool enforced_output_range =
(quant_specs_.post_training_quantization ||
quantization_aware_training_mode) &&
!quant_specs_.disable_enforced_fixed_output_range;
if (is_signed) {
patterns.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(ctx);
// Convert quant stats to int8 quantization parameters.
@ -337,6 +394,7 @@ void PrepareQuantizePass::runOnFunction() {
// Currently, only activation stats are imported, so narrow_range = false.
patterns.insert<PrepareQuantStats>(bit_width, false, false, ctx);
}
patterns.insert<PrepareLstmQuantStats>(ctx);
applyPatternsAndFoldGreedily(func, std::move(patterns));
SanityCheckAndAdjustment(func);
@ -345,8 +403,7 @@ void PrepareQuantizePass::runOnFunction() {
// values (tensors).
ApplyQuantizationParamsPropagation(
func, is_signed, disable_per_channel || quant_specs_.disable_per_channel,
GetOpQuantSpec,
enforce_fixed_output_range || quant_specs_.post_training_quantization);
GetOpQuantSpec, enforced_output_range);
ConvertMlirQuantOpsToTFLQuantOps(func);
}

View File

@ -64,6 +64,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#define DEBUG_TYPE "tf-tfl-legalization"
@ -518,9 +519,10 @@ struct ConvertTFStridedSlice : public RewritePattern {
explicit ConvertTFStridedSlice(MLIRContext *context)
: RewritePattern(TF::StridedSliceOp::getOperationName(), 2, context) {}
LogicalResult RewriteNewAxisMask(Operation *op, uint64_t new_axis_mask,
LogicalResult RewriteNewAxisMask(Operation *op,
PatternRewriter &rewriter) const {
TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
uint64_t new_axis_mask = strided_slice_op.new_axis_mask();
// Insert a new reshape op.
Value original_input = strided_slice_op.input();
@ -528,48 +530,51 @@ struct ConvertTFStridedSlice : public RewritePattern {
original_input.getType().cast<RankedTensorType>();
const ArrayRef<int64_t> &original_input_shape =
original_input_type.getShape();
SmallVector<int64_t, 4> new_shape;
SmallVector<int64_t, 4> revised_shape;
int index = 0;
const int original_input_rank = original_input_shape.size();
while (index < original_input_rank || new_axis_mask) {
if (new_axis_mask & 1) {
new_shape.emplace_back(1);
revised_shape.emplace_back(1);
} else {
new_shape.emplace_back(original_input_shape[index++]);
revised_shape.emplace_back(original_input_shape[index++]);
}
new_axis_mask >>= 1;
}
const int dim_size = new_shape.size();
if (failed(TF::VerifyShapeOfReshapeOp(revised_shape))) return failure();
const int dim_size = revised_shape.size();
Location loc = strided_slice_op.getLoc();
auto shape_type =
RankedTensorType::get({dim_size}, rewriter.getIntegerType(32));
SmallVector<Attribute, 4> result_shape_data(dim_size);
for (int i = 0; i < dim_size; ++i) {
result_shape_data[i] =
rewriter.getI32IntegerAttr(static_cast<int32_t>(new_shape[i]));
rewriter.getI32IntegerAttr(static_cast<int32_t>(revised_shape[i]));
}
auto shape_attr = DenseElementsAttr::get(shape_type, result_shape_data);
auto shape = rewriter.create<ConstantOp>(loc, shape_type, shape_attr);
auto new_output_type =
RankedTensorType::get(new_shape, original_input_type.getElementType());
auto revised_output_type = RankedTensorType::get(
revised_shape, original_input_type.getElementType());
TF::ReshapeOp reshape = rewriter.create<TF::ReshapeOp>(
loc, new_output_type, original_input, shape);
loc, revised_output_type, original_input, shape);
// Replace the original strided_slice.
uint64_t new_begin_mask = strided_slice_op.begin_mask();
uint64_t new_end_mask = strided_slice_op.end_mask();
uint64_t revised_begin_mask = strided_slice_op.begin_mask();
uint64_t revised_end_mask = strided_slice_op.end_mask();
// Since we expand the dims, we need to apply them to the begin_mask &
// end_mask.
new_begin_mask |= strided_slice_op.new_axis_mask();
new_end_mask |= strided_slice_op.new_axis_mask();
revised_begin_mask |= strided_slice_op.new_axis_mask();
revised_end_mask |= strided_slice_op.new_axis_mask();
auto attribute_type = rewriter.getIntegerType(64);
rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
op, strided_slice_op.getType(), reshape, strided_slice_op.begin(),
strided_slice_op.end(), strided_slice_op.strides(),
rewriter.getIntegerAttr(attribute_type, new_begin_mask),
rewriter.getIntegerAttr(attribute_type, new_end_mask),
rewriter.getIntegerAttr(attribute_type, revised_begin_mask),
rewriter.getIntegerAttr(attribute_type, revised_end_mask),
rewriter.getIntegerAttr(attribute_type,
strided_slice_op.ellipsis_mask()),
rewriter.getI64IntegerAttr(0),
@ -578,10 +583,16 @@ struct ConvertTFStridedSlice : public RewritePattern {
return success();
}
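
The reshape computation above can be checked in isolation: each set bit of new_axis_mask inserts a 1-sized dim at that position. A standalone sketch mirroring the while loop (it assumes, as the pass does, that clear mask bits line up with existing dims):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Expands `shape` by inserting a 1 wherever new_axis_mask has a set bit.
std::vector<int64_t> ExpandShape(std::vector<int64_t> shape,
                                 uint64_t new_axis_mask) {
  std::vector<int64_t> revised;
  std::size_t index = 0;
  while (index < shape.size() || new_axis_mask) {
    if (new_axis_mask & 1) {
      revised.push_back(1);  // New axis at this position.
    } else {
      revised.push_back(shape[index++]);  // Keep the original dim.
    }
    new_axis_mask >>= 1;
  }
  return revised;
}

int main() {
  // Mask 0b010 inserts a new axis at position 1: [2, 3] -> [2, 1, 3].
  assert((ExpandShape({2, 3}, 0b010) == std::vector<int64_t>{2, 1, 3}));
}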
LogicalResult RewriteEllipsisMask(Operation *op, uint64_t ellipsis_mask,
LogicalResult RewriteEllipsisMask(Operation *op,
PatternRewriter &rewriter) const {
TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
uint64_t ellipsis_mask = strided_slice_op.ellipsis_mask();
uint64_t shrink_axis_mask = strided_slice_op.shrink_axis_mask();
// Enforce operator precedence: ellipsis_mask wins, so clear any shrink bits covered by the ellipsis.
shrink_axis_mask &= ~ellipsis_mask;
DenseIntElementsAttr begin_dense_elem_attr;
Value begin = strided_slice_op.begin();
auto begin_ranked_attr_type = begin.getType().dyn_cast<RankedTensorType>();
@ -623,8 +634,9 @@ struct ConvertTFStridedSlice : public RewritePattern {
int64_t begin_mask = strided_slice_op.begin_mask();
int64_t end_mask = strided_slice_op.end_mask();
int64_t new_begin_mask = 0;
int64_t new_end_mask = 0;
int64_t revised_begin_mask = 0;
int64_t revised_end_mask = 0;
int64_t revised_shrink_axis_mask = 0;
SmallVector<int32_t, 4> padded_begin;
SmallVector<int32_t, 4> padded_end;
@ -637,16 +649,18 @@ struct ConvertTFStridedSlice : public RewritePattern {
padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(index));
padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(index));
padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(index));
if ((begin_mask >> index) & 1) new_begin_mask |= (1 << new_index);
if ((end_mask >> index) & 1) new_end_mask |= (1 << new_index);
if ((begin_mask >> index) & 1) revised_begin_mask |= (1 << new_index);
if ((end_mask >> index) & 1) revised_end_mask |= (1 << new_index);
if ((shrink_axis_mask >> index) & 1)
revised_shrink_axis_mask |= (1 << new_index);
++index;
++new_index;
}
// Ellipsis.
for (; new_index < index + ellipsis_filled_dim_size; ++new_index) {
new_begin_mask |= (1 << new_index);
new_end_mask |= (1 << new_index);
revised_begin_mask |= (1 << new_index);
revised_end_mask |= (1 << new_index);
// Mimic the begin/end/strides mask behavior.
padded_begin.push_back(0);
@ -663,8 +677,10 @@ struct ConvertTFStridedSlice : public RewritePattern {
padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(index));
padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(index));
if ((begin_mask >> index) & 1) new_begin_mask |= (1 << new_index);
if ((end_mask >> index) & 1) new_end_mask |= (1 << new_index);
if ((begin_mask >> index) & 1) revised_begin_mask |= (1 << new_index);
if ((end_mask >> index) & 1) revised_end_mask |= (1 << new_index);
if ((shrink_axis_mask >> index) & 1)
revised_shrink_axis_mask |= (1 << new_index);
++index;
++new_index;
@ -687,13 +703,12 @@ struct ConvertTFStridedSlice : public RewritePattern {
rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
op, strided_slice_op.getType(), input, begin_op.getResult(),
end_op.getResult(), stride_op.getResult(),
rewriter.getIntegerAttr(attribute_type, new_begin_mask),
rewriter.getIntegerAttr(attribute_type, new_end_mask),
/*ellipsis_maks=*/rewriter.getI64IntegerAttr(0),
rewriter.getIntegerAttr(attribute_type, revised_begin_mask),
rewriter.getIntegerAttr(attribute_type, revised_end_mask),
/*ellipsis_mask=*/rewriter.getI64IntegerAttr(0),
rewriter.getIntegerAttr(attribute_type,
strided_slice_op.new_axis_mask()),
rewriter.getIntegerAttr(attribute_type,
strided_slice_op.shrink_axis_mask()));
rewriter.getIntegerAttr(attribute_type, revised_shrink_axis_mask));
return success();
}
@ -701,20 +716,18 @@ struct ConvertTFStridedSlice : public RewritePattern {
PatternRewriter &rewriter) const override {
TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
// TODO(renjieliu): Consider expanding the transformation for the shrink mask
// as well.
if (strided_slice_op.shrink_axis_mask()) return failure();
// Handle new axis mask.
uint64_t new_axis_mask = strided_slice_op.new_axis_mask();
if (new_axis_mask != 0) {
return RewriteNewAxisMask(strided_slice_op, new_axis_mask, rewriter);
if (strided_slice_op.new_axis_mask() != 0) {
// We currently don't handle simultaneous shrink_axis and new_axis masks.
if (strided_slice_op.shrink_axis_mask()) {
return failure();
}
return RewriteNewAxisMask(strided_slice_op, rewriter);
}
// Handle ellipsis mask.
uint64_t ellipsis_mask = strided_slice_op.ellipsis_mask();
if (ellipsis_mask != 0) {
return RewriteEllipsisMask(strided_slice_op, ellipsis_mask, rewriter);
if (strided_slice_op.ellipsis_mask() != 0) {
return RewriteEllipsisMask(strided_slice_op, rewriter);
}
return failure();
}

View File

@ -182,7 +182,7 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
b.create<ReturnOp>(yield_op->getLoc(), args);
yield_op->erase();
symbol_table.insert(outlined_func);
outlined_func.setVisibility(FuncOp::Visibility::Private);
outlined_func.setPrivate();
return outlined_func;
};

View File

@ -57,6 +57,8 @@ mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) {
return mlir::ComplexType::get(builder.getF64Type());
case tflite::TensorType_INT8:
return builder.getIntegerType(8);
case tflite::TensorType_UINT64:
return builder.getIntegerType(64, /*isSigned=*/false);
}
}
@ -86,6 +88,8 @@ tensorflow::DataType TflTypeToTfType(tflite::TensorType type) {
return tensorflow::DT_STRING;
case tflite::TensorType_UINT8:
return tensorflow::DT_UINT8;
case tflite::TensorType_UINT64:
return tensorflow::DT_UINT64;
}
}

View File

@ -51,7 +51,7 @@ static ConfigProto::Experimental::MlirBridgeRollout GetUserRequest(
}
MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy(
absl::optional<ConfigProto> config_proto) {
const tensorflow::Graph& graph, absl::optional<ConfigProto> config_proto) {
switch (GetUserRequest(config_proto)) {
case ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED:
return MlirBridgeRolloutPolicy::kEnabledByUser;

View File

@ -17,6 +17,7 @@ limitations under the License.
#define THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_MLIR_BRIDGE_ROLLOUT_POLICY_H_
#include "absl/types/optional.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/protobuf/config.pb.h"
namespace tensorflow {
@ -46,6 +47,7 @@ enum class MlirBridgeRolloutPolicy {
// The config_proto param is a required input for all TF1 graphs but it is
// redundant for TF2 graphs.
MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy(
const tensorflow::Graph& graph,
absl::optional<tensorflow::ConfigProto> config_proto);
} // namespace tensorflow

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"
#include <memory>
#include <string>
#include "absl/container/flat_hash_set.h"
@ -32,10 +33,20 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
auto* shadow_run_success =
monitoring::Counter<0>::New("/tensorflow/core/mlir_shadow_run_success",
"Success count of MLIR shadow runs");
auto* shadow_run_failure = monitoring::Counter<2>::New(
"/tensorflow/core/mlir_shadow_run_failure",
"Failure count of MLIR shadow runs", "kind", "name");
static inline absl::string_view StringRefToView(llvm::StringRef ref) {
return {ref.data(), ref.size()};
}
@ -109,7 +120,7 @@ Status MlirFunctionOptimizationPass::Run(
// Skip conversion from Graph to MLIR if none of the passes are enabled.
const bool is_enabled =
llvm::any_of(registry_->passes(), [&](auto& pass_registration) -> bool {
return pass_registration.pass->IsEnabled(config_proto);
return pass_registration.pass->IsEnabled(config_proto, **graph);
});
if (!is_enabled) {
@ -123,6 +134,17 @@ Status MlirFunctionOptimizationPass::Run(
<< "(registered " << registry_->passes().size()
<< " passes)";
// For scenarios where the new bridge is enabled by analysis, we need to make
// sure that MLIR transformations are executed in a shadow mode.
// In this case, no changes should be done to the original `graph`
// and no failures propagated to the user.
bool enabled_by_analysis =
mlir_rollout_policy_(**graph, config_proto) ==
MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysis;
if (enabled_by_analysis) {
LOG_FIRST_N(INFO, 1) << "Shadow run of MLIR enabled after graph analysis";
}
GraphDebugInfo debug_info;
mlir::MLIRContext context;
RegisterDialects(context.getDialectRegistry());
@ -130,10 +152,21 @@ Status MlirFunctionOptimizationPass::Run(
import_config.graph_as_function = true;
import_config.control_outputs = *control_ret_node_names;
import_config.upgrade_legacy = true;
TF_ASSIGN_OR_RETURN(auto module_ref,
ConvertGraphToMlir(**graph, debug_info, *flib_def,
import_config, &context));
auto module_ref_status = ConvertGraphToMlir(**graph, debug_info, *flib_def,
import_config, &context);
if (!module_ref_status.ok()) {
if (enabled_by_analysis) {
shadow_run_failure->GetCell("graph_to_mlir", "")->IncrementBy(1);
// Do not fail; let the old bridge run on the original `graph`.
return Status::OK();
}
return module_ref_status.status();
}
auto module_ref = std::move(module_ref_status.ValueOrDie());
AddDevicesToOp(*module_ref, &device_set);
for (auto& pass_registration : registry_->passes()) {
@ -144,7 +177,17 @@ Status MlirFunctionOptimizationPass::Run(
DumpModule(*module_ref, llvm::formatv("mlir_{0}_before_", name));
}
TF_RETURN_IF_ERROR(pass_registration.pass->Run(config_proto, *module_ref));
auto pass_status =
pass_registration.pass->Run(config_proto, *module_ref, **graph);
if (!pass_status.ok()) {
if (enabled_by_analysis) {
shadow_run_failure->GetCell("pass", name.str())->IncrementBy(1);
// Do not fail; let the old bridge run on the original `graph`.
return Status::OK();
}
return pass_status;
}
if (VLOG_IS_ON(1)) {
DumpModule(*module_ref, llvm::formatv("mlir_{0}_after_", name));
@ -153,6 +196,25 @@ Status MlirFunctionOptimizationPass::Run(
GraphExportConfig export_config;
absl::flat_hash_set<Node*> control_ret_nodes;
// In case MLIR is enabled by analysis, verify that MLIR can be converted
// back to a TF graph. The original `graph` must stay the same.
if (enabled_by_analysis) {
auto empty_graph = std::make_unique<Graph>(OpRegistry::Global());
FunctionLibraryDefinition empty_flib = empty_graph->flib_def();
auto mlir_to_graph_status =
ConvertMlirToGraph(*module_ref, export_config, &empty_graph,
&empty_flib, &control_ret_nodes);
if (mlir_to_graph_status.ok()) {
shadow_run_success->GetCell()->IncrementBy(1);
} else {
shadow_run_failure->GetCell("mlir_to_graph", "")->IncrementBy(1);
}
return Status::OK();
}
TF_RETURN_WITH_CONTEXT_IF_ERROR(
ConvertMlirToGraph(*module_ref, export_config, graph, flib_def,
&control_ret_nodes),
@ -183,7 +245,7 @@ Status MlirV1CompatGraphOptimizationPass::Run(
const bool is_enabled =
absl::c_any_of(registry_->passes(), [&](auto& pass_registration) -> bool {
return pass_registration.pass->IsEnabled(
options.session_options->config);
options.session_options->config, **options.graph);
});
if (!is_enabled) {
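
The shadow-run pattern used throughout this pass — try the new path, record the outcome, and never propagate a failure to the caller — can be sketched independently (Counters and ShadowRun are illustrative names, not TensorFlow APIs):

#include <functional>
#include <map>
#include <string>

// Illustrative shadow run: execute the experimental path for telemetry only.
// Failures increment a per-stage counter and are swallowed; the caller always
// proceeds with the legacy path on the untouched input.
struct Counters {
  std::map<std::string, int> failures;
  int successes = 0;
};

void ShadowRun(const std::function<bool()>& experimental_path,
               const std::string& stage, Counters& counters) {
  if (experimental_path()) {
    ++counters.successes;
  } else {
    ++counters.failures[stage];  // Recorded, never propagated.
  }
}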

Some files were not shown because too many files have changed in this diff.