Merge remote-tracking branch 'upstream/master' into detection_postprocess
commit 6714af7331

4 .bazelrc
@ -602,6 +602,10 @@ build:release_windows_common --config=release_common
build:release_windows_common --define=no_tensorflow_py_deps=true
build:release_windows_common --announce_rc

# First available in VS 16.4. Speeds Windows compile times by a lot. See
# https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
build:release_windows_common --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions

build:release_cpu_windows --config=release_windows_common

build:release_gpu_windows --config=release_windows_common
@ -5,7 +5,6 @@
[](https://badge.fury.io/py/tensorflow)
[](https://badge.fury.io/py/tensorflow)


**`Documentation`** |
------------------- |
[](https://www.tensorflow.org/api_docs/) |

@ -61,6 +60,7 @@ commands.
*Nightly binaries are available for testing using the
[tf-nightly](https://pypi.python.org/pypi/tf-nightly) and
[tf-nightly-cpu](https://pypi.python.org/pypi/tf-nightly-cpu) packages on PyPI.*

#### *Try your first TensorFlow program*

```shell

@ -159,8 +159,6 @@ Container Type | Status | Art
* [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190)
* [Machine Learning with TensorFlow on GCP](https://www.coursera.org/specializations/machine-learning-tensorflow-gcp)
* [TensorFlow Codelabs](https://codelabs.developers.google.com/?cat=TensorFlow)
* [TensorFlow Chat Room on StackOverflow (not actively monitored by the
  TensorFlow team)](https://chat.stackoverflow.com/rooms/216694/tensorflow)
* [TensorFlow Blog](https://blog.tensorflow.org)
* [Learn ML with TensorFlow](https://www.tensorflow.org/resources/learn-ml)
* [TensorFlow Twitter](https://twitter.com/tensorflow)
15 RELEASE.md

@ -45,11 +45,26 @@
* Removed deprecated `Interpreter::UseNNAPI(bool)` C++ API.
    * Use `NnApiDelegate()` and related delegate configuration methods
      directly.
* 16-bit quantization
    * Added int16x8 support for ABS, REDUCE_MAX and REDUCE_MIN operators.
* Added support for a saved model's session initializer through
  `TFLiteConverter.from_saved_model` (a usage sketch follows below).
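
The following is a minimal, hedged sketch of the converter entry point named in the note above; the SavedModel directory `./my_saved_model` is an assumed example path, not something defined by this release.

```python
import tensorflow as tf

# Hypothetical SavedModel directory; any TF2 SavedModel works here.
saved_model_dir = "./my_saved_model"

# With this release, a session initializer stored in the SavedModel is
# honored by the converter rather than being dropped.
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()

with open("model.tflite", "wb") as f:
    f.write(tflite_model)
```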

* TF Core:
    * Corrected higher-order gradients of control flow constructs (`tf.cond`,
      `tf.while_loop`, and compositions like `tf.foldl`) computed with
      `tf.GradientTape` inside a `tf.function` (see the sketch below).
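
The note above concerns the correctness of nested gradients through control flow; a minimal, hedged sketch of that pattern (a second derivative of a `tf.cond` branch, taken inside `tf.function`) looks like this. The function and values are illustrative only.

```python
import tensorflow as tf

@tf.function
def second_derivative(x):
    # Nested tapes: the inner tape gives dy/dx, the outer tape differentiates
    # that result again, all through a tf.cond branch.
    with tf.GradientTape() as outer:
        outer.watch(x)
        with tf.GradientTape() as inner:
            inner.watch(x)
            y = tf.cond(x > 0.0, lambda: x * x * x, lambda: -x)
        dy = inner.gradient(y, x)
    return outer.gradient(dy, x)

# For y = x**3 the second derivative is 6*x, so this should print 12.0.
print(second_derivative(tf.constant(2.0)).numpy())
```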

* `tf.summary`:
    * New `tf.summary.graph` allows manually writing a TensorFlow graph
      (`tf.Graph` or `tf.compat.v1.GraphDef`) as a summary. This is not a
      replacement for the trace-based API (see the sketch below).
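
As a hedged illustration of the new API named above, the snippet below writes the graph of a traced function as a summary; the log directory `./logs` is an arbitrary example path.

```python
import tensorflow as tf

@tf.function
def my_fn(x):
    return tf.nn.relu(tf.matmul(x, x))

# Trace the function once to obtain a tf.Graph, then write it manually.
concrete = my_fn.get_concrete_function(tf.TensorSpec([2, 2], tf.float32))

writer = tf.summary.create_file_writer("./logs")
with writer.as_default():
    tf.summary.graph(concrete.graph)
```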

* Set `/d2ReducedOptimizeHugeFunctions` by default for Windows builds. This
  provides a big compile-time speedup, and effectively raises the minimum
  supported MSVC version to 16.4 (current: 16.8).
    * See: https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion

## Thanks to our Contributors

This release contains contributions from many people at Google, as well as:
38 configure.py

@ -1168,41 +1168,13 @@ def set_system_libs_flag(environ_cp):
    write_to_bazelrc('build --define=%s=%s' % (varname, environ_cp[varname]))


def is_reduced_optimize_huge_functions_available(environ_cp):
  """Check to see if the system supports /d2ReducedOptimizeHugeFunctions.

  The above compiler flag is a new compiler flag introduced to the Visual Studio
  compiler in version 16.4 (available in Visual Studio 2019, Preview edition
  only, as of 2019-11-19). TensorFlow needs this flag to massively reduce
  compile times, but until 16.4 is officially released, we can't depend on it.

  See also
  https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion

  Because it's very annoying to check this manually (to check the MSVC installed
  versions, you need to use the registry, and it's not clear if Bazel will be
  using that install version anyway), we expect environments that know they may
  use this flag to export TF_VC_VERSION=16.4.

  TODO(angerson, gunan): Remove this function when TensorFlow's minimum VS
  version is upgraded to 16.4.

  Arguments:
    environ_cp: Environment of the current execution

  Returns:
    boolean, whether or not /d2ReducedOptimizeHugeFunctions is available on this
    machine.
  """
  return float(environ_cp.get('TF_VC_VERSION', '0')) >= 16.4


def set_windows_build_flags(environ_cp):
  """Set Windows specific build options."""
  if is_reduced_optimize_huge_functions_available(environ_cp):
    write_to_bazelrc(
        'build --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions'
    )

  # First available in VS 16.4. Speeds up Windows compile times by a lot. See
  # https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
  # pylint: disable=line-too-long
  write_to_bazelrc('build --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions')

  if get_var(
      environ_cp, 'TF_OVERRIDE_EIGEN_STRONG_INLINE', 'Eigen strong inline',
@ -588,9 +588,11 @@ config_setting(
|
||||
# DO NOT ADD ANY NEW EXCEPTIONS TO THIS LIST!
|
||||
# Instead, please use public APIs or public build rules TF provides.
|
||||
# If you need functionality that is not exposed, we will work with you to expand our public APIs.
|
||||
# TODO(b/173549186): Move Google-internal TF code out of learning/brain
|
||||
package_group(
|
||||
name = "internal",
|
||||
packages = [
|
||||
"//learning/brain/mlir/...",
|
||||
"//learning/lib/ami/simple_ml/...",
|
||||
"//tensorflow/...",
|
||||
],
|
||||
|
@ -199,6 +199,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":logging",
|
||||
":tf_status",
|
||||
":tf_tensor",
|
||||
"@com_google_absl//absl/strings",
|
||||
|
@ -51,6 +51,7 @@ tf_cuda_library(
|
||||
":immediate_execution_context",
|
||||
":immediate_execution_operation",
|
||||
":immediate_execution_tensor_handle",
|
||||
":immediate_execution_distributed_manager",
|
||||
":abstract_tensor_handle",
|
||||
":tfe_context_internal",
|
||||
":tfe_cancellation_manager_internal",
|
||||
@ -70,6 +71,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core/common_runtime/eager:attr_builder",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
"//tensorflow/core/common_runtime/eager:context_distributed_manager",
|
||||
"//tensorflow/core/common_runtime/eager:core",
|
||||
"//tensorflow/core/common_runtime/eager:eager_executor",
|
||||
"//tensorflow/core/common_runtime/eager:execute",
|
||||
@ -119,6 +121,7 @@ filegroup(
|
||||
"gradients.h",
|
||||
"gradients_internal.h",
|
||||
"immediate_execution_context.h",
|
||||
"immediate_execution_distributed_manager.h",
|
||||
"immediate_execution_operation.h",
|
||||
"immediate_execution_tensor_handle.h",
|
||||
"tape.h",
|
||||
@ -176,6 +179,7 @@ cc_library(
|
||||
"//tensorflow/c:c_api_internal",
|
||||
"//tensorflow/c:conversion_macros",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core/platform:casts",
|
||||
"//tensorflow/core/platform:types",
|
||||
],
|
||||
@ -224,6 +228,34 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "unified_api_testutil",
|
||||
testonly = 1,
|
||||
srcs = [
|
||||
"unified_api_testutil.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"unified_api_testutil.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":abstract_context",
|
||||
":abstract_tensor_handle",
|
||||
":c_api_experimental",
|
||||
":c_api_test_util",
|
||||
":c_api_unified_internal",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"//tensorflow/core/platform:status",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "gradients_test",
|
||||
size = "small",
|
||||
@ -240,6 +272,7 @@ tf_cuda_cc_test(
|
||||
":c_api_test_util",
|
||||
":c_api_unified_internal",
|
||||
":gradients_internal",
|
||||
":unified_api_testutil",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
@ -260,6 +293,29 @@ tf_cuda_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "unified_api_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"unified_api_test.cc",
|
||||
],
|
||||
args = ["--heap_check=local"],
|
||||
linkstatic = tf_kernel_tests_linkstatic(),
|
||||
tags = tf_cuda_tests_tags(),
|
||||
deps = [
|
||||
":c_api_experimental",
|
||||
":c_api_unified_internal",
|
||||
":unified_api_testutil",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
"//tensorflow/core/platform:errors",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gradients_util",
|
||||
srcs = [
|
||||
@ -449,8 +505,10 @@ cc_library(
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:refcount",
|
||||
"//tensorflow/core/platform:status",
|
||||
],
|
||||
)
|
||||
|
||||
@ -529,6 +587,19 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "immediate_execution_distributed_manager",
|
||||
hdrs = ["immediate_execution_distributed_manager.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "immediate_execution_context",
|
||||
hdrs = ["immediate_execution_context.h"],
|
||||
@ -537,12 +608,14 @@ cc_library(
|
||||
],
|
||||
deps = [
|
||||
":abstract_context",
|
||||
":immediate_execution_distributed_manager",
|
||||
":immediate_execution_operation",
|
||||
":immediate_execution_tensor_handle",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
|
@ -17,8 +17,10 @@ limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/refcount.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
namespace tensorflow {
|
||||
|
||||
// Abstract interface to a Tensor handle in either tracing or immediate
|
||||
@ -32,6 +34,9 @@ class AbstractTensorHandle : public core::RefCounted {
|
||||
public:
|
||||
// Returns tensor dtype.
|
||||
virtual tensorflow::DataType DataType() const = 0;
|
||||
// Returns tensor shape. If tensor has unknown rank, shape remains untouched.
|
||||
virtual tensorflow::Status Shape(
|
||||
tensorflow::PartialTensorShape* shape) const = 0;
|
||||
|
||||
AbstractTensorHandleKind getKind() const { return kind_; }
|
||||
|
||||
|
@ -21,16 +21,11 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
|
||||
// clang-format off
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
// clang-format on
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
@ -39,59 +34,39 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
|
||||
#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
|
||||
#endif
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/protobuf/device_filters.pb.h"
|
||||
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
#include "tensorflow/core/common_runtime/copy_tensor.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/device_set.h"
|
||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/common_runtime/eager/execute.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
|
||||
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/remote_device.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/server_lib.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_interface.h"
|
||||
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/platform/blocking_counter.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/notification.h"
|
||||
#include "tensorflow/core/platform/random.h"
|
||||
#include "tensorflow/core/platform/refcount.h"
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
// "tensorflow/core/platform/platform.h" must be included first before using
|
||||
// PLATFORM_GOOGLE, IS_MOBILE_PLATFORM, etc.
|
||||
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
|
||||
#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
|
||||
#endif // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
#include "tensorflow/core/common_runtime/eager/context_distributed_manager.h"
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
|
||||
using tensorflow::string;
|
||||
|
||||
namespace {
|
||||
@ -100,611 +75,6 @@ string DeviceName(const tensorflow::Device* d) {
|
||||
return (d == nullptr) ? "cpu:0" : d->name();
|
||||
}
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
bool AreLocalDevicesCompatible(const tensorflow::EagerContext* context,
|
||||
const tensorflow::ServerDef& server_def) {
|
||||
if (server_def.job_name() != context->HostCPU()->parsed_name().job) {
|
||||
return false;
|
||||
}
|
||||
return server_def.default_session_config().SerializeAsString() ==
|
||||
context->session_options().config.SerializeAsString();
|
||||
}
|
||||
|
||||
tensorflow::Status AddRemoteDevicesToMgr(
|
||||
const std::vector<string>& added_remote_workers,
|
||||
tensorflow::WorkerCacheInterface* worker_cache,
|
||||
tensorflow::DynamicDeviceMgr* remote_device_mgr) {
|
||||
std::vector<std::unique_ptr<tensorflow::Device>> remote_devices;
|
||||
tensorflow::mutex remote_devices_mu;
|
||||
int num_added_workers = added_remote_workers.size();
|
||||
tensorflow::BlockingCounter counter(num_added_workers);
|
||||
std::vector<tensorflow::Status> statuses(num_added_workers);
|
||||
for (int i = 0; i < num_added_workers; i++) {
|
||||
tensorflow::NewRemoteDevices(
|
||||
tensorflow::Env::Default(), worker_cache, added_remote_workers[i],
|
||||
[i, &statuses, &counter, &remote_devices, &remote_devices_mu](
|
||||
const tensorflow::Status& s,
|
||||
std::vector<tensorflow::Device*>* devices) {
|
||||
statuses[i] = s;
|
||||
if (s.ok()) {
|
||||
tensorflow::mutex_lock l(remote_devices_mu);
|
||||
for (tensorflow::Device* d : *devices) {
|
||||
remote_devices.emplace_back(d);
|
||||
}
|
||||
}
|
||||
counter.DecrementCount();
|
||||
});
|
||||
}
|
||||
counter.Wait();
|
||||
for (int i = 0; i < num_added_workers; i++) {
|
||||
TF_RETURN_IF_ERROR(statuses[i]);
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(remote_device_mgr->AddDevices(std::move(remote_devices)));
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Status GetAllRemoteDevices(
|
||||
const std::vector<string>& remote_workers,
|
||||
tensorflow::WorkerCacheInterface* worker_cache,
|
||||
std::unique_ptr<tensorflow::DynamicDeviceMgr>* device_mgr) {
|
||||
auto remote_device_mgr = absl::make_unique<tensorflow::DynamicDeviceMgr>();
|
||||
TF_RETURN_IF_ERROR(AddRemoteDevicesToMgr(remote_workers, worker_cache,
|
||||
remote_device_mgr.get()));
|
||||
*device_mgr = std::move(remote_device_mgr);
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Status RemoveRemoteDevicesFromMgr(
|
||||
const std::vector<string>& removed_remote_workers,
|
||||
tensorflow::DynamicDeviceMgr* remote_device_mgr) {
|
||||
const std::vector<tensorflow::Device*> remote_devices =
|
||||
(remote_device_mgr->ListDevices());
|
||||
std::vector<tensorflow::Device*> devices_to_remove;
|
||||
for (tensorflow::Device* d : remote_devices) {
|
||||
for (const string& remote_worker : removed_remote_workers) {
|
||||
if (tensorflow::DeviceNameUtils::IsSameAddressSpace(remote_worker,
|
||||
d->name())) {
|
||||
devices_to_remove.emplace_back(d);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
TF_RETURN_IF_ERROR(remote_device_mgr->RemoveDevices(devices_to_remove));
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Status ListRemoteWorkers(tensorflow::ServerInterface* server,
|
||||
const string& local_worker,
|
||||
std::vector<string>* remote_workers) {
|
||||
tensorflow::GrpcServer* grpc_server =
|
||||
dynamic_cast<tensorflow::GrpcServer*>(server);
|
||||
if (grpc_server == nullptr) {
|
||||
return tensorflow::errors::Internal(
|
||||
"Currently, TFE_NewContext only supports tensorflow::GrpcServer.");
|
||||
}
|
||||
grpc_server->master_env()->worker_cache->ListWorkers(remote_workers);
|
||||
remote_workers->erase(
|
||||
std::remove(remote_workers->begin(), remote_workers->end(), local_worker),
|
||||
remote_workers->end());
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
void DifferentiateWorkerLists(const std::vector<string>* current_list,
|
||||
const std::vector<string>* new_list,
|
||||
std::vector<string>* added,
|
||||
std::vector<string>* removed,
|
||||
std::vector<string>* existing) {
|
||||
// Get STL set_difference and set_intersection with one list traversal.
|
||||
// Similar to the set_difference library function, the input lists
|
||||
// (`current_list` and `new_list`) must be sorted before calling the function.
|
||||
added->resize(new_list->size());
|
||||
removed->resize(current_list->size());
|
||||
existing->resize(current_list->size());
|
||||
std::vector<string>::const_iterator curr_it = current_list->begin();
|
||||
std::vector<string>::const_iterator new_it = new_list->begin();
|
||||
std::vector<string>::iterator added_it = added->begin();
|
||||
std::vector<string>::iterator removed_it = removed->begin();
|
||||
std::vector<string>::iterator existing_it = existing->begin();
|
||||
while (curr_it != current_list->end() && new_it != new_list->end()) {
|
||||
if (*curr_it < *new_it) {
|
||||
*removed_it++ = *curr_it++;
|
||||
} else if (*curr_it > *new_it) {
|
||||
*added_it++ = *new_it++;
|
||||
} else {
|
||||
*existing_it++ = *curr_it++;
|
||||
new_it++;
|
||||
}
|
||||
}
|
||||
removed_it = std::copy(curr_it, current_list->end(), removed_it);
|
||||
added_it = std::copy(new_it, new_list->end(), added_it);
|
||||
added->resize(added_it - added->begin());
|
||||
removed->resize(removed_it - removed->begin());
|
||||
existing->resize(existing_it - existing->begin());
|
||||
}
|
||||
|
||||
tensorflow::Status GetReplacedFromExistingWorkers(
|
||||
const std::vector<string>* existing_workers, tensorflow::uint64 context_id,
|
||||
tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
|
||||
tensorflow::eager::EagerClientCache* client_cache,
|
||||
std::vector<string>* replaced_workers) {
|
||||
tensorflow::BlockingCounter counter(existing_workers->size());
|
||||
std::vector<tensorflow::Status> statuses(existing_workers->size());
|
||||
tensorflow::eager::KeepAliveRequest request;
|
||||
request.set_context_id(context_id);
|
||||
std::vector<tensorflow::eager::KeepAliveResponse> responses(
|
||||
existing_workers->size());
|
||||
for (int i = 0; i < existing_workers->size(); i++) {
|
||||
tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
|
||||
statuses[i] =
|
||||
client_cache->GetClient(existing_workers->at(i), &eager_client);
|
||||
if (!statuses[i].ok()) {
|
||||
counter.DecrementCount();
|
||||
continue;
|
||||
}
|
||||
eager_client->KeepAliveAsync(
|
||||
&request, &responses[i],
|
||||
[i, &statuses, &counter](const tensorflow::Status& s) {
|
||||
statuses[i] = s;
|
||||
counter.DecrementCount();
|
||||
});
|
||||
}
|
||||
counter.Wait();
|
||||
for (int i = 0; i < existing_workers->size(); i++) {
|
||||
// If the RPC fails (indicating that the requested ID doesn't exist on the
// remote), or the returned view ID is not equal to the local one (indicating
// that the remote worker has a stale view of the cluster), treat the worker
// as replaced.
if (!statuses[i].ok() ||
|
||||
responses[i].context_view_id() != context_view_id) {
|
||||
replaced_workers->emplace_back(existing_workers->at(i));
|
||||
}
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Status CreateRemoteContexts(
|
||||
TFE_Context* ctx, const std::vector<string>& remote_workers,
|
||||
tensorflow::uint64 context_id, tensorflow::uint64 context_view_id,
|
||||
int keep_alive_secs, const tensorflow::ServerDef& server_def,
|
||||
tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
|
||||
const bool lazy_copy_remote_function_inputs,
|
||||
const tensorflow::eager::CreateContextRequest& base_request) {
|
||||
int num_remote_workers = remote_workers.size();
|
||||
tensorflow::BlockingCounter counter(num_remote_workers);
|
||||
std::vector<tensorflow::Status> statuses(num_remote_workers);
|
||||
for (int i = 0; i < num_remote_workers; i++) {
|
||||
const string& remote_worker = remote_workers[i];
|
||||
tensorflow::DeviceNameUtils::ParsedName parsed_name;
|
||||
if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
|
||||
&parsed_name)) {
|
||||
statuses[i] = tensorflow::errors::InvalidArgument(
|
||||
"Unable to parse ", remote_worker, " as a device name");
|
||||
counter.DecrementCount();
|
||||
continue;
|
||||
}
|
||||
|
||||
tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
|
||||
statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
|
||||
if (eager_client == nullptr) {
|
||||
statuses[i] = tensorflow::errors::Internal(
|
||||
"Cannot find a client for the given target:", remote_worker);
|
||||
}
|
||||
if (!statuses[i].ok()) {
|
||||
counter.DecrementCount();
|
||||
continue;
|
||||
}
|
||||
|
||||
tensorflow::eager::CreateContextRequest request;
|
||||
tensorflow::eager::CreateContextResponse* response =
|
||||
new tensorflow::eager::CreateContextResponse();
|
||||
request.set_context_id(context_id);
|
||||
request.set_context_view_id(context_view_id);
|
||||
*request.mutable_server_def() = server_def;
|
||||
request.mutable_server_def()->set_job_name(parsed_name.job);
|
||||
request.mutable_server_def()->set_task_index(parsed_name.task);
|
||||
request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
|
||||
server_def.default_session_config());
|
||||
|
||||
std::vector<bool> filtered_device_mask;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->FilterDevicesForRemoteWorkers(
|
||||
remote_worker, base_request.cluster_device_attributes(),
|
||||
&filtered_device_mask);
|
||||
DCHECK_EQ(filtered_device_mask.size(),
|
||||
base_request.cluster_device_attributes_size());
|
||||
for (int i = 0; i < filtered_device_mask.size(); i++) {
|
||||
if (filtered_device_mask[i]) {
|
||||
const auto& da = base_request.cluster_device_attributes(i);
|
||||
*request.add_cluster_device_attributes() = da;
|
||||
}
|
||||
}
|
||||
request.set_async(async);
|
||||
request.set_keep_alive_secs(keep_alive_secs);
|
||||
request.set_lazy_copy_remote_function_inputs(
|
||||
lazy_copy_remote_function_inputs);
|
||||
|
||||
eager_client->CreateContextAsync(
|
||||
&request, response,
|
||||
[i, &statuses, &counter, response](const tensorflow::Status& s) {
|
||||
statuses[i] = s;
|
||||
delete response;
|
||||
counter.DecrementCount();
|
||||
});
|
||||
}
|
||||
counter.Wait();
|
||||
tensorflow::StatusGroup sg;
|
||||
for (int i = 0; i < num_remote_workers; i++) {
|
||||
if (TF_PREDICT_FALSE(!statuses[i].ok())) {
|
||||
sg.Update(statuses[i]);
|
||||
}
|
||||
}
|
||||
return sg.as_summary_status();
|
||||
}
|
||||
|
||||
tensorflow::Status UpdateRemoteContexts(
|
||||
TFE_Context* ctx, const std::vector<string>& remote_workers,
|
||||
const std::vector<string>& added_workers,
|
||||
const std::vector<string>& removed_workers, tensorflow::uint64 context_id,
|
||||
tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
|
||||
tensorflow::eager::EagerClientCache* remote_eager_workers,
|
||||
const tensorflow::eager::CreateContextRequest& base_request) {
|
||||
int num_remote_workers = remote_workers.size();
|
||||
tensorflow::BlockingCounter counter(num_remote_workers);
|
||||
std::vector<tensorflow::Status> statuses(num_remote_workers);
|
||||
|
||||
int cluster_device_count = base_request.cluster_device_attributes_size();
|
||||
std::unordered_set<string> added_or_removed(added_workers.begin(),
|
||||
added_workers.end());
|
||||
std::copy(removed_workers.begin(), removed_workers.end(),
|
||||
std::inserter(added_or_removed, added_or_removed.end()));
|
||||
// Whether each device is in the updated (added or removed) workers
|
||||
std::vector<bool> device_added_or_removed(cluster_device_count);
|
||||
for (int i = 0; i < base_request.cluster_device_attributes_size(); i++) {
|
||||
const auto& da = base_request.cluster_device_attributes().at(i);
|
||||
tensorflow::DeviceNameUtils::ParsedName pn;
|
||||
tensorflow::DeviceNameUtils::ParseFullName(da.name(), &pn);
|
||||
string task_name;
|
||||
tensorflow::DeviceNameUtils::GetTaskName(pn, &task_name);
|
||||
if (added_or_removed.find(task_name) != added_or_removed.end()) {
|
||||
device_added_or_removed[i] = true;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < num_remote_workers; i++) {
|
||||
const string& remote_worker = remote_workers[i];
|
||||
tensorflow::DeviceNameUtils::ParsedName parsed_name;
|
||||
if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
|
||||
&parsed_name)) {
|
||||
statuses[i] = tensorflow::errors::InvalidArgument(
|
||||
"Unable to parse ", remote_worker, " as a device name");
|
||||
counter.DecrementCount();
|
||||
continue;
|
||||
}
|
||||
|
||||
tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
|
||||
statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
|
||||
if (eager_client == nullptr) {
|
||||
statuses[i] = tensorflow::errors::Internal(
|
||||
"Cannot find a client for the given target:", remote_worker);
|
||||
}
|
||||
if (!statuses[i].ok()) {
|
||||
counter.DecrementCount();
|
||||
continue;
|
||||
}
|
||||
|
||||
std::vector<bool> filtered_device_mask;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->FilterDevicesForRemoteWorkers(
|
||||
remote_worker, base_request.cluster_device_attributes(),
|
||||
&filtered_device_mask);
|
||||
DCHECK_EQ(filtered_device_mask.size(), cluster_device_count);
|
||||
|
||||
// If any of the devices that match the device filters are in the set of
|
||||
// added or removed workers, we must send a complete UpdateContextRequest.
|
||||
// Otherwise, only send a simple request to increment context view ID.
|
||||
std::vector<bool> added_or_removed_filtered_devices(cluster_device_count);
|
||||
std::transform(device_added_or_removed.begin(),
|
||||
device_added_or_removed.end(), filtered_device_mask.begin(),
|
||||
added_or_removed_filtered_devices.begin(),
|
||||
std::logical_and<bool>());
|
||||
const bool full_update_request =
|
||||
std::accumulate(added_or_removed_filtered_devices.begin(),
|
||||
added_or_removed_filtered_devices.end(), false,
|
||||
std::logical_or<bool>());
|
||||
|
||||
tensorflow::eager::UpdateContextRequest request;
|
||||
auto* response = new tensorflow::eager::UpdateContextResponse();
|
||||
request.set_context_id(context_id);
|
||||
request.set_context_view_id(context_view_id);
|
||||
if (full_update_request) {
|
||||
*request.mutable_server_def() = server_def;
|
||||
request.mutable_server_def()->set_job_name(parsed_name.job);
|
||||
request.mutable_server_def()->set_task_index(parsed_name.task);
|
||||
request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
|
||||
server_def.default_session_config());
|
||||
for (int i = 0; i < cluster_device_count; i++) {
|
||||
if (filtered_device_mask[i]) {
|
||||
const auto& da = base_request.cluster_device_attributes(i);
|
||||
*request.add_cluster_device_attributes() = da;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
eager_client->UpdateContextAsync(
|
||||
&request, response,
|
||||
[i, &statuses, &counter, response](const tensorflow::Status& s) {
|
||||
statuses[i] = s;
|
||||
delete response;
|
||||
counter.DecrementCount();
|
||||
});
|
||||
}
|
||||
counter.Wait();
|
||||
for (int i = 0; i < num_remote_workers; i++) {
|
||||
TF_RETURN_IF_ERROR(statuses[i]);
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
int keep_alive_secs, const tensorflow::ServerDef& server_def,
|
||||
TFE_Context* ctx, bool reset_context) {
|
||||
// We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
// server object (which currently CHECK-fails) and we would miss the error.
// Instead, we log the error and then return, to allow the user to see the
// error message.
#define LOG_AND_RETURN_IF_ERROR(...) \
|
||||
do { \
|
||||
const ::tensorflow::Status _status = (__VA_ARGS__); \
|
||||
if (TF_PREDICT_FALSE(!_status.ok())) { \
|
||||
LOG(ERROR) << _status.error_message(); \
|
||||
return _status; \
|
||||
} \
|
||||
} while (0);
|
||||
|
||||
string worker_name =
|
||||
tensorflow::strings::StrCat("/job:", server_def.job_name(),
|
||||
"/replica:0/task:", server_def.task_index());
|
||||
|
||||
// List of current remote workers before updating server_def. Unused if
|
||||
// resetting the server_def.
|
||||
std::vector<string> curr_remote_workers;
|
||||
// List of updated remote workers.
|
||||
std::vector<string> remote_workers;
|
||||
|
||||
// New server created for new server_def. Unused if updating server_def.
|
||||
std::unique_ptr<tensorflow::ServerInterface> new_server;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
tensorflow::GrpcServer* grpc_server;
|
||||
if (reset_context) {
|
||||
const tensorflow::DeviceMgr* device_mgr =
|
||||
AreLocalDevicesCompatible(context, server_def)
|
||||
? context->local_device_mgr()
|
||||
: nullptr;
|
||||
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServerWithOptions(
|
||||
server_def, {device_mgr}, &new_server));
|
||||
grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
ListRemoteWorkers(new_server.get(), worker_name, &remote_workers));
|
||||
} else {
|
||||
LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name,
|
||||
&curr_remote_workers));
|
||||
// No need to check the cast here, since `ListRemoteWorkers` already checks
|
||||
// if the server is a GRPC server or not.
|
||||
grpc_server = dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
|
||||
LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
|
||||
}
|
||||
|
||||
tensorflow::uint64 context_id = context->GetContextId();
|
||||
tensorflow::uint64 context_view_id = context->GetContextViewId();
|
||||
if (reset_context) {
|
||||
context_id = tensorflow::EagerContext::NewContextId();
|
||||
context_view_id = 0;
|
||||
// Make master eager context accessible by local eager service, which might
|
||||
// receive send tensor requests from remote workers.
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
grpc_server->AddMasterEagerContextToEagerService(context_id, context));
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
grpc_server->master_env()->worker_cache->GetEagerClientCache(
|
||||
&remote_eager_workers));
|
||||
|
||||
// For cluster update, use a status group to aggregate statuses from
|
||||
// * adding and removing remote devices
|
||||
// * creating remote contexts on newly added workers
|
||||
// * updating remote contexts on existing workers
|
||||
// * updating the master context
|
||||
// Note that we should not return immediately on errors in the middle of these
// updates, to prevent the cluster from having inconsistent context views.
//
|
||||
// Unused if `reset_context` is True.
|
||||
tensorflow::StatusGroup sg;
|
||||
|
||||
// When updating an existing context, populate the following lists with:
|
||||
// * added_workers: set(remote_workers) - set(curr_remote_workers)
|
||||
// * removed_workers: set(curr_remote_workers) - set(remote_workers)
|
||||
// * existing_workers: set(curr_remote_workers) intersect set(remote_workers)
|
||||
// * replaced_workers: workers with the same task names and potentially the
|
||||
// same `hostname:port`s, but replaced by different processes
|
||||
std::vector<string> added_workers;
|
||||
std::vector<string> removed_workers;
|
||||
std::vector<string> existing_workers;
|
||||
std::vector<string> replaced_workers;
|
||||
|
||||
// New remote device manager created for new server_def. Unused if updating
|
||||
// server_def.
|
||||
std::unique_ptr<tensorflow::DynamicDeviceMgr> new_remote_device_mgr;
|
||||
tensorflow::DynamicDeviceMgr* remote_device_mgr = nullptr;
|
||||
if (reset_context) {
|
||||
LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices(
|
||||
remote_workers, grpc_server->master_env()->worker_cache,
|
||||
&new_remote_device_mgr));
|
||||
remote_device_mgr = new_remote_device_mgr.get();
|
||||
} else {
|
||||
context->ClearCachesAndDefaultExecutor();
|
||||
// TODO(b/143914772): Potential memory leak if rendezvous has pending
|
||||
// tensors for removed / replaced workers.
|
||||
|
||||
remote_device_mgr = context->GetOwnedRemoteDeviceMgr();
|
||||
if (remote_device_mgr == nullptr) {
|
||||
LOG_AND_RETURN_IF_ERROR(tensorflow::errors::InvalidArgument(
|
||||
"Updating context with an invalid set of remote devices."));
|
||||
}
|
||||
std::sort(curr_remote_workers.begin(), curr_remote_workers.end());
|
||||
std::sort(remote_workers.begin(), remote_workers.end());
|
||||
DifferentiateWorkerLists(&curr_remote_workers, &remote_workers,
|
||||
&added_workers, &removed_workers,
|
||||
&existing_workers);
|
||||
sg.Update(GetReplacedFromExistingWorkers(
|
||||
&existing_workers, context_id, context->GetContextViewId(), server_def,
|
||||
remote_eager_workers.get(), &replaced_workers));
|
||||
if (VLOG_IS_ON(1)) {
|
||||
VLOG(1) << "Updating cluster with following changes";
|
||||
for (const string& w : added_workers) VLOG(1) << " Added worker " << w;
|
||||
for (const string& w : removed_workers)
|
||||
VLOG(1) << " Removed worker " << w;
|
||||
for (const string& w : replaced_workers)
|
||||
VLOG(1) << " Replaced worker " << w;
|
||||
}
|
||||
if (!replaced_workers.empty()) {
|
||||
// Treat replaced workers as removed then added back, so that we recreate
|
||||
// remote devices and contexts, and re-register functions on those workers
|
||||
removed_workers.insert(removed_workers.end(), replaced_workers.begin(),
|
||||
replaced_workers.end());
|
||||
added_workers.insert(added_workers.end(), replaced_workers.begin(),
|
||||
replaced_workers.end());
|
||||
for (const string& w : replaced_workers) {
|
||||
existing_workers.erase(
|
||||
std::remove(existing_workers.begin(), existing_workers.end(), w),
|
||||
existing_workers.end());
|
||||
}
|
||||
}
|
||||
sg.Update(RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr));
|
||||
sg.Update(AddRemoteDevicesToMgr(added_workers,
|
||||
grpc_server->master_env()->worker_cache,
|
||||
remote_device_mgr));
|
||||
}
|
||||
|
||||
std::vector<tensorflow::DeviceAttributes> cluster_device_attributes;
|
||||
remote_device_mgr->ListDeviceAttributes(&cluster_device_attributes);
|
||||
|
||||
std::vector<tensorflow::DeviceAttributes> local_device_attributes;
|
||||
grpc_server->worker_env()->device_mgr->ListDeviceAttributes(
|
||||
&local_device_attributes);
|
||||
|
||||
// This request makes sure that we can create Rendezvous properly between
// local and remote contexts.
tensorflow::eager::CreateContextRequest base_request;
|
||||
for (const auto& da : cluster_device_attributes) {
|
||||
*base_request.add_cluster_device_attributes() = da;
|
||||
}
|
||||
for (const auto& da : local_device_attributes) {
|
||||
*base_request.add_cluster_device_attributes() = da;
|
||||
}
|
||||
|
||||
// Initialize remote eager workers.
|
||||
if (reset_context) {
|
||||
const tensorflow::Status s = CreateRemoteContexts(
|
||||
ctx, remote_workers, context_id, context_view_id, keep_alive_secs,
|
||||
server_def, remote_eager_workers.get(), context->Executor().Async(),
|
||||
context->LazyCopyFunctionRemoteInputs(), base_request);
|
||||
// NOTE: the remote tasks could fail after `GetAllRemoteDevices` and cause
|
||||
// the CreateRemoteContexts to fail. We currently only log instead of
|
||||
// directly returning the error, since returning here will cause the server
|
||||
// object to be destroyed (which currently CHECK-fails). The client will
|
||||
// see additional errors if ops are subsequently sent to the failed workers.
|
||||
if (TF_PREDICT_FALSE(!s.ok())) {
|
||||
LOG(ERROR) << "Error when creating contexts on remote targets: "
|
||||
<< s.error_message()
|
||||
<< "\nExecuting remote ops or functions on these remote "
|
||||
"targets will fail.";
|
||||
}
|
||||
} else {
|
||||
if (sg.ok()) {
|
||||
// Create remote contexts on the newly added workers only if the master
// has collected all device information from them (i.e., the
// GetAllRemoteDevices call returns successfully). Note that in rare cases
// GetAllRemoteDevices can still fail even with RPCs configured to wait
// until the remote workers become alive. If the master creates remote
// contexts on the workers whose devices are still not collected, those
// workers will be treated as existing workers subsequently, so the master
// will never get devices from them even with retrying UpdateServerDef.
sg.Update(CreateRemoteContexts(
|
||||
ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs,
|
||||
server_def, remote_eager_workers.get(), context->Executor().Async(),
|
||||
context->LazyCopyFunctionRemoteInputs(), base_request));
|
||||
}
|
||||
if (!existing_workers.empty()) {
|
||||
if (VLOG_IS_ON(1)) {
|
||||
for (const string& w : existing_workers) {
|
||||
VLOG(1) << "Updating cluster with existing worker " << w;
|
||||
}
|
||||
}
|
||||
// The master's context_view_id will be incremented by one in the
|
||||
// UpdateRemoteMaster call later. We want existing workers to also have
|
||||
// the updated context_view_id, so we must set their context_view_id to
|
||||
// the master's current context_view_id + 1.
|
||||
sg.Update(UpdateRemoteContexts(ctx, existing_workers, added_workers,
|
||||
removed_workers, context_id,
|
||||
context_view_id + 1, server_def,
|
||||
remote_eager_workers.get(), base_request));
|
||||
}
|
||||
}
|
||||
|
||||
auto session_name = tensorflow::strings::StrCat("eager_", context_id);
|
||||
if (reset_context) {
|
||||
tensorflow::RemoteRendezvous* r =
|
||||
grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
|
||||
auto* device_mgr = grpc_server->worker_env()->device_mgr;
|
||||
std::shared_ptr<tensorflow::WorkerSession> worker_session;
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
grpc_server->worker_env()->session_mgr->CreateSession(
|
||||
session_name, server_def, base_request.cluster_device_attributes(),
|
||||
true));
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
|
||||
session_name, &worker_session));
|
||||
|
||||
// Initialize remote tensor communication based on worker session.
|
||||
LOG_AND_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
|
||||
|
||||
tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
|
||||
tensorflow::eager::CreateClusterFLR(context_id, context,
|
||||
worker_session.get());
|
||||
auto remote_mgr = absl::make_unique<tensorflow::eager::RemoteMgr>(
|
||||
/*is_master=*/true, context);
|
||||
|
||||
LOG_AND_RETURN_IF_ERROR(context->InitializeRemoteMaster(
|
||||
std::move(new_server), grpc_server->worker_env(), worker_session,
|
||||
std::move(remote_eager_workers), std::move(new_remote_device_mgr),
|
||||
remote_workers, context_id, r, device_mgr, keep_alive_secs, cluster_flr,
|
||||
std::move(remote_mgr)));
|
||||
|
||||
// NOTE: We start the server after all other initialization, because the
|
||||
// GrpcServer cannot be destroyed after it is started.
|
||||
LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
|
||||
} else {
|
||||
sg.Update(grpc_server->worker_env()->session_mgr->UpdateSession(
|
||||
session_name, server_def, base_request.cluster_device_attributes(),
|
||||
/*isolate_session_state=*/true));
|
||||
sg.Update(context->UpdateRemoteMaster(context_id,
|
||||
std::move(remote_eager_workers),
|
||||
added_workers, removed_workers));
|
||||
LOG_AND_RETURN_IF_ERROR(sg.as_summary_status());
|
||||
}
|
||||
#undef LOG_AND_RETURN_IF_ERROR
|
||||
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
@ -735,7 +105,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
||||
#else
|
||||
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
|
||||
return nullptr;
|
||||
#endif
|
||||
#endif // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
|
||||
}
|
||||
std::vector<std::unique_ptr<tensorflow::Device>> devices;
|
||||
status->status = tensorflow::DeviceFactory::AddDevices(
|
||||
@ -747,13 +117,18 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
||||
|
||||
tensorflow::Rendezvous* r =
|
||||
new tensorflow::IntraProcessRendezvous(device_mgr.get());
|
||||
|
||||
return tensorflow::wrap(new tensorflow::EagerContext(
|
||||
tensorflow::EagerContext* eager_context = new tensorflow::EagerContext(
|
||||
opts->session_options.options,
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
||||
opts->device_placement_policy),
|
||||
opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
|
||||
/*device_mgr_owned*/ true, r));
|
||||
/*device_mgr_owned*/ true, r);
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
eager_context->SetDistributedManager(
|
||||
std::make_unique<tensorflow::EagerContextDistributedManager>(
|
||||
eager_context));
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
return tensorflow::wrap(eager_context);
|
||||
}
|
||||
|
||||
void TFE_DeleteContext(TFE_Context* ctx) {
|
||||
@ -791,26 +166,9 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
|
||||
"Invalid tensorflow.ServerDef protocol buffer");
|
||||
return;
|
||||
}
|
||||
if (server_def.has_cluster_device_filters()) {
|
||||
const auto& cdf = server_def.cluster_device_filters();
|
||||
for (const auto& jdf : cdf.jobs()) {
|
||||
const string remote_prefix = "/job:" + jdf.name() + "/task:";
|
||||
for (const auto& tdf : jdf.tasks()) {
|
||||
const int32_t task_index = tdf.first;
|
||||
std::vector<string> device_filters(tdf.second.device_filters_size());
|
||||
for (int i = 0; i < tdf.second.device_filters_size(); i++) {
|
||||
device_filters[i] = tdf.second.device_filters(i);
|
||||
}
|
||||
const string remote_worker = remote_prefix + std::to_string(task_index);
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status =
|
||||
context->SetRemoteDeviceFilters(remote_worker, device_filters);
|
||||
}
|
||||
}
|
||||
}
|
||||
status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
|
||||
ctx, /*reset_context=*/true);
|
||||
status->status =
|
||||
tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef(
|
||||
server_def, /*reset_context=*/true, keep_alive_secs);
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
|
||||
@ -835,14 +193,9 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Trying to update a context with invalid context id.");
|
||||
}
|
||||
if (server_def.has_cluster_device_filters()) {
|
||||
LOG(WARNING) << "Device filters can only be specified when initializing "
|
||||
"the cluster. Any changes in device filters are ignored "
|
||||
"when updating the server def.";
|
||||
}
|
||||
// TODO(haoyuzhang): Check server_def compatibility before the update
|
||||
status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
|
||||
ctx, /*reset_context=*/false);
|
||||
status->status =
|
||||
tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef(
|
||||
server_def, /*reset_context=*/false, keep_alive_secs);
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
|
||||
@ -854,44 +207,11 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
|
||||
"TFE_ContextSetServerDef not supported on mobile");
|
||||
return false;
|
||||
#else // !defined(IS_MOBILE_PLATFORM)
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
tensorflow::GrpcServer* grpc_server =
|
||||
dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
|
||||
if (grpc_server == nullptr) {
|
||||
status->status =
|
||||
tensorflow::errors::Internal("Failed to get tensorflow::GrpcServer.");
|
||||
return false;
|
||||
}
|
||||
tensorflow::WorkerInterface* wi =
|
||||
grpc_server->master_env()->worker_cache->GetOrCreateWorker(worker_name);
|
||||
if (wi == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Unable to find worker interface corresponding to task ", worker_name);
|
||||
return false;
|
||||
}
|
||||
|
||||
tensorflow::GetStatusRequest request;
|
||||
tensorflow::GetStatusResponse response;
|
||||
tensorflow::Status remote_status;
|
||||
tensorflow::Notification done;
|
||||
wi->GetStatusAsync(/*opts_=*/nullptr, &request, &response, /*fail_fast=*/true,
|
||||
[&remote_status, &done](const tensorflow::Status& s) {
|
||||
remote_status = s;
|
||||
done.Notify();
|
||||
});
|
||||
done.WaitForNotification();
|
||||
|
||||
// We set OK status so the call does not raise any exceptions. Instead, the
// caller uses the return value to tell if the remote worker is alive.
status->status = tensorflow::Status::OK();
|
||||
|
||||
if (remote_status.ok()) {
|
||||
return true;
|
||||
}
|
||||
LOG(INFO) << "Remote worker " << worker_name
|
||||
<< " is not alive: " << remote_status.error_message();
|
||||
return false;
|
||||
bool is_alive;
|
||||
status->status =
|
||||
tensorflow::unwrap(ctx)->GetDistributedManager()->CheckRemoteAlive(
|
||||
worker_name, &is_alive);
|
||||
return is_alive;
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
|
||||
|
@ -134,7 +134,9 @@ TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
|
||||
}
|
||||
|
||||
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
|
||||
TF_DataType dtype, TF_Status* s) {
|
||||
TF_DataType dtype, TF_Shape shape,
|
||||
TF_Status* s) {
|
||||
DCHECK_GE(shape.num_dims, -1);
|
||||
TracingTensorHandle* t;
|
||||
TracingContext* tracing_ctx = dyn_cast<TracingContext>(unwrap(func));
|
||||
if (!tracing_ctx) {
|
||||
@ -143,8 +145,20 @@ TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
|
||||
"TF_AddFunctionParameter must be called on a TracingContext."));
|
||||
return nullptr;
|
||||
}
|
||||
tensorflow::PartialTensorShape partial_shape;
|
||||
if (shape.num_dims != -1) {
|
||||
DCHECK(shape.dim_sizes != nullptr);
|
||||
Status status = tensorflow::PartialTensorShape::MakePartialShape(
|
||||
reinterpret_cast<tensorflow::int64*>(shape.dim_sizes), shape.num_dims,
|
||||
&partial_shape);
|
||||
if (!status.ok()) {
|
||||
Set_TF_Status_from_Status(s, status);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
Set_TF_Status_from_Status(
|
||||
s, tracing_ctx->AddParameter(static_cast<DataType>(dtype), &t));
|
||||
s, tracing_ctx->AddParameter(static_cast<DataType>(dtype), partial_shape,
|
||||
&t));
|
||||
return wrap(t);
|
||||
}
|
||||
|
||||
|
@ -64,10 +64,16 @@ TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions*,
                                                 TF_Status* s);
void TF_DeleteExecutionContext(TF_ExecutionContext*);

// Represents a (partially-defined) shape.
typedef struct TF_Shape {
  int num_dims;  // Must be >= -1; -1 represents unknown rank.
  int64_t* dim_sizes;
} TF_Shape;

// Add a new parameter to a TensorFlow Function.
// TODO(aminim): what about shape?
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
                                           TF_DataType dtype, TF_Status* s);
                                           TF_DataType dtype, TF_Shape shape,
                                           TF_Status* s);

// Create an operation suitable to use with the provided context. The operation
// requires its type (e.g. "AddV2") to be set independently.
@ -25,6 +25,8 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
@ -43,22 +45,50 @@ class GraphContext;
|
||||
class GraphOperation;
|
||||
class GraphTensor;
|
||||
|
||||
auto& kUnknownDim = shape_inference::InferenceContext::kUnknownDim;
|
||||
auto& kUnknownRank = shape_inference::InferenceContext::kUnknownRank;
|
||||
|
||||
// GraphTensor wraps a `TF_Output`, i.e. a pointer to TF_Operation and the index
|
||||
// into the list of outputs for the operation.
|
||||
class GraphTensor : public TracingTensorHandle {
|
||||
public:
|
||||
explicit GraphTensor(TF_Output output)
|
||||
: TracingTensorHandle(kGraph), output_(output) {}
|
||||
explicit GraphTensor(TF_Output output, TF_Graph* graph)
|
||||
: TracingTensorHandle(kGraph), output_(output), graph_(graph) {}
|
||||
|
||||
tensorflow::DataType DataType() const override {
|
||||
return static_cast<tensorflow::DataType>(TF_OperationOutputType(output_));
|
||||
}
|
||||
|
||||
tensorflow::Status Shape(
|
||||
tensorflow::PartialTensorShape* shape) const override {
|
||||
DCHECK(shape != nullptr);
|
||||
TF_Status status;
|
||||
int num_dims = TF_GraphGetTensorNumDims(graph_, output_, &status);
|
||||
DCHECK_GE(num_dims, -1);
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(&status));
|
||||
if (num_dims == kUnknownRank) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::vector<int64> dims(num_dims, kUnknownDim);
|
||||
TF_GraphGetTensorShape(graph_, output_,
|
||||
reinterpret_cast<int64_t*>(dims.data()), num_dims,
|
||||
&status);
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(&status));
|
||||
TF_RETURN_IF_ERROR(tensorflow::TensorShapeUtils::MakeShape(dims, shape));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TF_Output output_;
|
||||
|
||||
// For LLVM style RTTI.
|
||||
static bool classof(const AbstractTensorHandle* ptr) {
|
||||
return ptr->getKind() == kGraph;
|
||||
}
|
||||
|
||||
private:
|
||||
TF_Graph* graph_; // For shape inference.
|
||||
};
|
||||
|
||||
// GraphOperation wraps and populates a TF_OperationDescription.
|
||||
@ -135,7 +165,7 @@ class GraphOperation : public TracingOperation {
|
||||
TF_DeleteStatus(s);
|
||||
*num_retvals = TF_OperationNumOutputs(operation);
|
||||
for (int i = 0; i < *num_retvals; ++i) {
|
||||
retvals[i] = new GraphTensor({operation, i});
|
||||
retvals[i] = new GraphTensor({operation, i}, g_);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@@ -326,12 +356,18 @@ class GraphContext : public TracingContext {
    return new GraphOperation(graph_.get());
  }

  Status AddParameter(DataType dtype, TracingTensorHandle** output) override {
  Status AddParameter(DataType dtype, const PartialTensorShape& shape,
                      TracingTensorHandle** output) override {
    TracingOperationPtr operation(CreateOperation());
    TF_RETURN_IF_ERROR(operation->Reset("Placeholder", nullptr));
    TF_RETURN_IF_ERROR(
        operation->SetOpName(absl::StrCat("_input_", inputs_.size()).c_str()));
    TF_RETURN_IF_ERROR(operation->SetAttrType("dtype", dtype));
    if (!shape.unknown_rank()) {
      TF_RETURN_IF_ERROR(operation->SetAttrShape(
          "shape", reinterpret_cast<int64_t*>(shape.dim_sizes().data()),
          shape.dims()));
    }
    int num_outputs = 1;
    std::vector<AbstractTensorHandle*> outputs(num_outputs);
    TF_RETURN_IF_ERROR(operation->Execute(
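As a caller-side illustration (a minimal sketch, not part of the change): a partially known shape is built with `PartialTensorShape::MakePartialShape` and handed to the new `AddParameter` overload, which records it as the placeholder's `shape` attr; unknown dimensions stay -1. The helper name `AddMatrixParam` is invented for this sketch.

// Sketch: add a float32 parameter of shape [2, ?] to a function under
// tracing. `ctx` is assumed to point at a TracingContext like the one above.
tensorflow::Status AddMatrixParam(
    tensorflow::tracing::TracingContext* ctx,
    tensorflow::tracing::TracingTensorHandle** param) {
  tensorflow::PartialTensorShape shape;
  tensorflow::int64 dims[] = {2, -1};  // -1 marks an unknown dimension.
  TF_RETURN_IF_ERROR(
      tensorflow::PartialTensorShape::MakePartialShape(dims, 2, &shape));
  return ctx->AddParameter(tensorflow::DT_FLOAT, shape, param);
}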
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/types.h"

@@ -107,7 +108,8 @@ class TracingContext : public AbstractContext {

 public:
  // Add a function parameter and return the corresponding tensor.
  virtual Status AddParameter(DataType dtype, TracingTensorHandle**) = 0;
  virtual Status AddParameter(DataType dtype, const PartialTensorShape& shape,
                              TracingTensorHandle**) = 0;

  // Finalize this context and make a function out of it. The context is in an
  // invalid state after this call and must be destroyed.
@@ -359,7 +359,7 @@ TEST_P(UnifiedCAPI, TestBasicGraph) {
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  auto* placeholder_t =
      TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
      TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Build an abstract operation.
@@ -450,7 +450,7 @@ TEST_P(UnifiedCAPI, TestBasicGraphMatMul) {
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  auto* placeholder_t =
      TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
      TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Build an abstract operation.
@@ -553,9 +553,9 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
  TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
  auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
  auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Create a first "Add" computing `arg0 + arg1`.
@@ -709,9 +709,9 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraphMatMul) {
  TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
  auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
  auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Create a first "Add" computing `arg0 + arg1`.
@@ -975,7 +975,7 @@ TEST_P(UnifiedCAPI, TF_AbstractTensorGetEagerTensorOnGraphTensorRaises) {

  // Add a placeholder to the graph.
  auto placeholder_t =
      TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
      TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, status.get());
  TF_AbstractTensorGetEagerTensor(placeholder_t, status.get());
  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
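The `{-1, nullptr}` literal in the updated calls is the new shape argument of `TF_AddFunctionParameter`; judging from the tracing tests later in this change it reads as a (num_dims, dim_sizes) pair with -1 meaning unknown rank, but the field names below are assumptions, not taken from the diff. A hedged sketch of passing a fully known 2x4 shape instead:

  // Sketch only: supply a concrete [2, 4] shape for the traced parameter.
  int64_t dims[] = {2, 4};
  auto* placeholder_2x4 = TF_AddFunctionParameter(
      graph_ctx, TF_FLOAT, {/*num_dims=*/2, /*dim_sizes=*/dims}, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());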
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/unified_api_testutil.h"
#include "tensorflow/c/experimental/gradients/array_grad.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
@@ -65,6 +66,8 @@ Status RegisterGradients(GradientRegistry* registry) {
  TF_RETURN_IF_ERROR(registry->Register("Neg", NegRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("Sub", SubRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("Mul", MulRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("Log1p", Log1pRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("DivNoNan", DivNoNanRegisterer));
  return Status::OK();
}

@@ -73,8 +76,10 @@ Status RegisterGradients(GradientRegistry* registry) {
// return grad(y, {inputs[0], inputs[1]})
Status AddGradModel(AbstractContext* ctx,
                    absl::Span<AbstractTensorHandle* const> inputs,
                    absl::Span<AbstractTensorHandle*> outputs,
                    const GradientRegistry& registry) {
                    absl::Span<AbstractTensorHandle*> outputs) {
  GradientRegistry registry;
  TF_RETURN_IF_ERROR(RegisterGradients(&registry));

  TapeVSpace vspace(ctx);
  auto tape = std::make_unique<Tape>(/*persistent=*/false);
  tape->Watch(ToId(inputs[0]));  // Watch x.
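The remaining *GradModel helpers below follow the same refactor: the `GradientRegistry` parameter is dropped and each model builds its own registry. A consolidated sketch of the shared skeleton (the name `SomeGradModel` is invented; the forward-op call and gradient computation are elided):

// Sketch of the common shape of these helpers after the refactor.
Status SomeGradModel(AbstractContext* ctx,
                     absl::Span<AbstractTensorHandle* const> inputs,
                     absl::Span<AbstractTensorHandle*> outputs) {
  GradientRegistry registry;
  TF_RETURN_IF_ERROR(RegisterGradients(&registry));  // Registry is now local.

  TapeVSpace vspace(ctx);
  auto tape = std::make_unique<Tape>(/*persistent=*/false);
  tape->Watch(ToId(inputs[0]));  // Watch the differentiated input.
  // ... run the forward op through a TapeContext wrapping `ctx`, then call
  // tape->ComputeGradient(...) and place the result(s) into `outputs` ...
  return Status::OK();
}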
@ -107,8 +112,10 @@ Status AddGradModel(AbstractContext* ctx,
|
||||
// return grad(y, {inputs[0]})
|
||||
Status ExpGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
absl::Span<AbstractTensorHandle*> outputs) {
|
||||
GradientRegistry registry;
|
||||
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
|
||||
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = std::make_unique<Tape>(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
@ -137,8 +144,10 @@ Status ExpGradModel(AbstractContext* ctx,
|
||||
// return grad(y, {inputs[0]})
|
||||
Status SqrtGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
absl::Span<AbstractTensorHandle*> outputs) {
|
||||
GradientRegistry registry;
|
||||
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
|
||||
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = std::make_unique<Tape>(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
@ -168,8 +177,10 @@ Status SqrtGradModel(AbstractContext* ctx,
|
||||
// This should return [nullptr, 1].
|
||||
Status IdentityNGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
absl::Span<AbstractTensorHandle*> outputs) {
|
||||
GradientRegistry registry;
|
||||
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
|
||||
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = std::make_unique<Tape>(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0]));
|
||||
@ -202,8 +213,10 @@ Status IdentityNGradModel(AbstractContext* ctx,
|
||||
// return grad(y, {inputs[0]})
|
||||
Status NegGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
absl::Span<AbstractTensorHandle*> outputs) {
|
||||
GradientRegistry registry;
|
||||
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
|
||||
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = std::make_unique<Tape>(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0]));
|
||||
@ -233,8 +246,10 @@ Status NegGradModel(AbstractContext* ctx,
|
||||
// return grad(y, {inputs[0], inputs[1]})
|
||||
Status SubGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
absl::Span<AbstractTensorHandle*> outputs) {
|
||||
GradientRegistry registry;
|
||||
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
|
||||
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = std::make_unique<Tape>(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
@ -267,8 +282,10 @@ Status SubGradModel(AbstractContext* ctx,
|
||||
// return grad(y, {inputs[0], inputs[1]})
|
||||
Status MulGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
absl::Span<AbstractTensorHandle*> outputs) {
|
||||
GradientRegistry registry;
|
||||
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
|
||||
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
@ -297,122 +314,72 @@ Status MulGradModel(AbstractContext* ctx,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
AbstractContext* BuildFunction(const char* fn_name) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get());
|
||||
return unwrap(graph_ctx);
|
||||
}
|
||||
// Computes
|
||||
// y = log(1 + inputs[0])
|
||||
// return grad(y, {inputs[0]})
|
||||
Status Log1pGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs) {
|
||||
GradientRegistry registry;
|
||||
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
std::vector<AbstractTensorHandle*> log1p_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::Log1p(tape_ctx.get(), inputs,
|
||||
absl::MakeSpan(log1p_outputs),
|
||||
"Log1p")); // Compute log(1 + x).
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
Status CreateParamsForInputs(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
std::vector<AbstractTensorHandle*>* params) {
|
||||
tracing::TracingTensorHandle* handle = nullptr;
|
||||
for (auto input : inputs) {
|
||||
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
|
||||
input->DataType(), &handle));
|
||||
params->emplace_back(handle);
|
||||
std::vector<AbstractTensorHandle*> out_grads;
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||
vspace, /*target_tensor_ids=*/{ToId(log1p_outputs[0])},
|
||||
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads,
|
||||
/*build_default_zeros_grads=*/false));
|
||||
for (auto log1p_output : log1p_outputs) {
|
||||
log1p_output->Unref();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
using Model = std::function<Status(
|
||||
AbstractContext*, absl::Span<AbstractTensorHandle* const>,
|
||||
absl::Span<AbstractTensorHandle*>, const GradientRegistry&)>;
|
||||
// Computes
|
||||
// y = inputs[0] / inputs[1]
|
||||
// return grad(y, {inputs[0], inputs[1]})
|
||||
Status DivNoNanGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs) {
|
||||
GradientRegistry registry;
|
||||
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
tape->Watch(ToId(inputs[1])); // Watch y.
|
||||
std::vector<AbstractTensorHandle*> div_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::DivNoNan(tape_ctx.get(), inputs,
|
||||
absl::MakeSpan(div_outputs),
|
||||
"DivNoNan")); // Compute x / y.
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
// Runs `model` maybe wrapped in a function.
|
||||
Status RunModel(Model model, AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, bool use_function,
|
||||
const GradientRegistry& registry) {
|
||||
if (use_function) {
|
||||
const char* fn_name = "test_fn";
|
||||
std::unique_ptr<AbstractFunction> scoped_func;
|
||||
// Returning null tensors from a tf.function is not supported, so we keep
|
||||
// track of which indices in the model's outputs are nullptr in this set.
|
||||
// The FunctionDef only outputs the non-null tensors. We later pad the
|
||||
// function op outputs to have nullptrs at the `null_indices`.
|
||||
absl::flat_hash_set<int> null_indices;
|
||||
{
|
||||
AbstractContextPtr func_ctx(BuildFunction(fn_name));
|
||||
std::vector<AbstractTensorHandle*> func_inputs;
|
||||
func_inputs.reserve(inputs.size());
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs));
|
||||
vector<AbstractTensorHandle*> model_outputs;
|
||||
model_outputs.resize(outputs.size());
|
||||
TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
|
||||
absl::MakeSpan(model_outputs), registry));
|
||||
for (auto func_input : func_inputs) {
|
||||
func_input->Unref();
|
||||
}
|
||||
AbstractFunction* func = nullptr;
|
||||
OutputList output_list;
|
||||
output_list.expected_num_outputs = 0;
|
||||
output_list.outputs.reserve(outputs.size());
|
||||
for (int i = 0; i < model_outputs.size(); i++) {
|
||||
if (model_outputs[i]) {
|
||||
output_list.outputs.emplace_back(model_outputs[i]);
|
||||
output_list.expected_num_outputs += 1;
|
||||
} else {
|
||||
null_indices.insert(i);
|
||||
}
|
||||
}
|
||||
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
|
||||
->Finalize(&output_list, &func));
|
||||
scoped_func.reset(func);
|
||||
for (auto output : output_list.outputs) {
|
||||
output->Unref();
|
||||
}
|
||||
TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
|
||||
}
|
||||
|
||||
AbstractOperationPtr fn_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr));
|
||||
for (auto input : inputs) {
|
||||
TF_RETURN_IF_ERROR(fn_op->AddInput(input));
|
||||
}
|
||||
int retvals = outputs.size() - null_indices.size();
|
||||
vector<AbstractTensorHandle*> fn_outputs(retvals);
|
||||
TF_RETURN_IF_ERROR(fn_op->Execute(
|
||||
absl::Span<AbstractTensorHandle*>(fn_outputs.data(), fn_outputs.size()),
|
||||
&retvals));
|
||||
int skipped_indices = 0;
|
||||
for (int i = 0; i < outputs.size(); i++) {
|
||||
if (!null_indices.contains(i)) {
|
||||
outputs[i] = fn_outputs[i - skipped_indices];
|
||||
} else {
|
||||
skipped_indices += 1;
|
||||
}
|
||||
}
|
||||
TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name));
|
||||
return Status::OK();
|
||||
} else {
|
||||
return model(ctx, inputs, outputs, registry);
|
||||
std::vector<AbstractTensorHandle*> out_grads;
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||
vspace, /*target_tensor_ids=*/{ToId(div_outputs[0])},
|
||||
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
|
||||
source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads,
|
||||
/*build_default_zeros_grads=*/false));
|
||||
for (auto div_output : div_outputs) {
|
||||
div_output->Unref();
|
||||
}
|
||||
}
|
||||
|
||||
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
|
||||
*ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get()));
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
TFE_DeleteContextOptions(opts);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
|
||||
AbstractTensorHandle** tensor) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_Context* eager_ctx =
|
||||
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value);
|
||||
*tensor =
|
||||
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
|
||||
outputs[0] = out_grads[0];
|
||||
outputs[1] = out_grads[1];
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -467,7 +434,7 @@ TEST_P(CppGradients, TestAddGrad) {
|
||||
std::vector<AbstractTensorHandle*> outputs(2);
|
||||
s = RunModel(AddGradModel, ctx.get(), {x.get(), y.get()},
|
||||
absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()), registry);
|
||||
/*use_function=*/!std::get<2>(GetParam()));
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
TF_Tensor* result_tensor;
|
||||
@ -507,18 +474,15 @@ TEST_P(CppGradients, TestExpGrad) {
|
||||
x.reset(x_raw);
|
||||
}
|
||||
|
||||
GradientRegistry registry;
|
||||
Status s = RegisterGradients(®istry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
// Pseudo-code:
|
||||
//
|
||||
// tape.watch(x)
|
||||
// y = exp(x)
|
||||
// outputs = tape.gradient(y, x)
|
||||
std::vector<AbstractTensorHandle*> outputs(1);
|
||||
s = RunModel(ExpGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()), registry);
|
||||
Status s =
|
||||
RunModel(ExpGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()));
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
TF_Tensor* result_tensor;
|
||||
@ -551,18 +515,15 @@ TEST_P(CppGradients, TestSqrtGrad) {
|
||||
x.reset(x_raw);
|
||||
}
|
||||
|
||||
GradientRegistry registry;
|
||||
Status s = RegisterGradients(®istry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
// Pseudo-code:
|
||||
//
|
||||
// tape.watch(x)
|
||||
// y = sqrt(x)
|
||||
// outputs = tape.gradient(y, x)
|
||||
std::vector<AbstractTensorHandle*> outputs(1);
|
||||
s = RunModel(SqrtGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()), registry);
|
||||
Status s =
|
||||
RunModel(SqrtGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()));
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
TF_Tensor* result_tensor;
|
||||
@ -620,7 +581,7 @@ TEST_P(CppGradients, TestIdentityNGrad) {
|
||||
std::vector<AbstractTensorHandle*> outputs(2);
|
||||
s = RunModel(IdentityNGradModel, ctx.get(), {x1.get(), x2.get()},
|
||||
absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()), registry);
|
||||
/*use_function=*/!std::get<2>(GetParam()));
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
EXPECT_EQ(outputs[0], nullptr);
|
||||
@ -665,7 +626,7 @@ TEST_P(CppGradients, TestNegGrad) {
|
||||
// outputs = tape.gradient(y, x)
|
||||
std::vector<AbstractTensorHandle*> outputs(1);
|
||||
s = RunModel(NegGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()), registry);
|
||||
/*use_function=*/!std::get<2>(GetParam()));
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
TF_Tensor* result_tensor;
|
||||
@ -706,10 +667,6 @@ TEST_P(CppGradients, TestSubGrad) {
|
||||
y.reset(y_raw);
|
||||
}
|
||||
|
||||
GradientRegistry registry;
|
||||
Status s = RegisterGradients(®istry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
// Pseudo-code:
|
||||
//
|
||||
// tape.watch(x)
|
||||
@ -717,9 +674,9 @@ TEST_P(CppGradients, TestSubGrad) {
|
||||
// y = x - y
|
||||
// outputs = tape.gradient(y, [x, y])
|
||||
std::vector<AbstractTensorHandle*> outputs(2);
|
||||
s = RunModel(SubGradModel, ctx.get(), {x.get(), y.get()},
|
||||
absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()), registry);
|
||||
Status s = RunModel(SubGradModel, ctx.get(), {x.get(), y.get()},
|
||||
absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()));
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
TF_Tensor* result_tensor;
|
||||
@ -767,10 +724,6 @@ TEST_P(CppGradients, TestMulGrad) {
|
||||
y.reset(y_raw);
|
||||
}
|
||||
|
||||
GradientRegistry registry;
|
||||
Status s = RegisterGradients(®istry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
// Pseudo-code:
|
||||
//
|
||||
// tape.watch(x)
|
||||
@ -778,9 +731,9 @@ TEST_P(CppGradients, TestMulGrad) {
|
||||
// y = x * y
|
||||
// outputs = tape.gradient(y, [x, y])
|
||||
std::vector<AbstractTensorHandle*> outputs(2);
|
||||
s = RunModel(MulGradModel, ctx.get(), {x.get(), y.get()},
|
||||
absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()), registry);
|
||||
Status s = RunModel(MulGradModel, ctx.get(), {x.get(), y.get()},
|
||||
absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()));
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
TF_Tensor* result_tensor;
|
||||
@ -800,6 +753,104 @@ TEST_P(CppGradients, TestMulGrad) {
|
||||
TF_DeleteTensor(result_tensor);
|
||||
}
|
||||
|
||||
TEST_P(CppGradients, TestLog1pGrad) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
AbstractContextPtr ctx;
|
||||
{
|
||||
AbstractContext* ctx_raw = nullptr;
|
||||
Status s =
|
||||
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
ctx.reset(ctx_raw);
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
AbstractTensorHandle* x_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
|
||||
// Pseudo-code:
|
||||
//
|
||||
// tape.watch(x)
|
||||
// y = log(1 + x)
|
||||
// outputs = tape.gradient(y, x)
|
||||
std::vector<AbstractTensorHandle*> outputs(1);
|
||||
Status s =
|
||||
RunModel(Log1pGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()));
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
TF_Tensor* result_tensor;
|
||||
s = getValue(outputs[0], &result_tensor);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
||||
EXPECT_NEAR(*result_value, 0.5, 0.001);
|
||||
outputs[0]->Unref();
|
||||
TF_DeleteTensor(result_tensor);
|
||||
result_tensor = nullptr;
|
||||
}
|
||||
|
||||
TEST_P(CppGradients, TestDivNoNanGrad) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
AbstractContextPtr ctx;
|
||||
{
|
||||
AbstractContext* ctx_raw = nullptr;
|
||||
Status s =
|
||||
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
ctx.reset(ctx_raw);
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
AbstractTensorHandle* x_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr y;
|
||||
{
|
||||
AbstractTensorHandle* y_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &y_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
y.reset(y_raw);
|
||||
}
|
||||
|
||||
// Pseudo-code:
|
||||
//
|
||||
// tape.watch(x)
|
||||
// tape.watch(y)
|
||||
// y = x / y
|
||||
// outputs = tape.gradient(y, [x, y])
|
||||
std::vector<AbstractTensorHandle*> outputs(2);
|
||||
Status s = RunModel(DivNoNanGradModel, ctx.get(), {x.get(), y.get()},
|
||||
absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()));
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
TF_Tensor* result_tensor;
|
||||
s = getValue(outputs[0], &result_tensor);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
||||
EXPECT_NEAR(*result_value, 0.5, 0.001);
|
||||
outputs[0]->Unref();
|
||||
TF_DeleteTensor(result_tensor);
|
||||
result_tensor = nullptr;
|
||||
|
||||
s = getValue(outputs[1], &result_tensor);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
||||
EXPECT_NEAR(*result_value, -0.25, 0.001);
|
||||
outputs[1]->Unref();
|
||||
TF_DeleteTensor(result_tensor);
|
||||
}
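For reference, the expected values in the two new tests above follow directly from the analytic gradients (a quick check, not part of the change):

// With z = x / y at x = 1, y = 2:  dz/dx =  1 / y       =  0.5
//                                  dz/dy = -x / y^2     = -0.25
// With w = log(1 + x) at x = 1:    dw/dx =  1 / (1 + x) =  0.5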
|
||||
|
||||
TEST_P(CppGradients, TestSetAttrString) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/experimental/ops/nn_ops.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
@ -224,8 +225,10 @@ Status CreateParamsForInputs(AbstractContext* ctx,
|
||||
vector<AbstractTensorHandle*>* params) {
|
||||
tracing::TracingTensorHandle* handle = nullptr;
|
||||
for (auto input : inputs) {
|
||||
PartialTensorShape shape;
|
||||
TF_RETURN_IF_ERROR(input->Shape(&shape));
|
||||
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
|
||||
input->DataType(), &handle));
|
||||
input->DataType(), shape, &handle));
|
||||
params->emplace_back(handle);
|
||||
}
|
||||
return Status::OK();
|
||||
@ -314,4 +317,4 @@ Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
|
||||
}
|
||||
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
} // namespace tensorflow
|
||||
|
@ -21,12 +21,15 @@ limitations under the License.
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/abstract_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_distributed_manager.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/numeric_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/tstring.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
@ -138,8 +141,8 @@ class ImmediateExecutionContext : public AbstractContext {
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Following are legacy features in TF Eager Runtime.
|
||||
// TODO(tf-runtime): Figure out a way to deprecate following features after
|
||||
// Following are features in current TF Eager Runtime.
|
||||
// TODO(tfrt-devs): Figure out a way to deprecate following features after
|
||||
// migrated to TFRT.
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Clear pending nodes in thread executors and kernel caches.
|
||||
@ -157,6 +160,34 @@ class ImmediateExecutionContext : public AbstractContext {
|
||||
// Update the Eager Executor for current thread.
|
||||
virtual void SetExecutorForThread(EagerExecutor* executor) = 0;
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Following are helper functions to assist integrating TFRT with current
|
||||
// TF eager runtime.
|
||||
// TODO(b/172877902): These helper functions are currently used to support
|
||||
// PyFuncOp on TFRT, and might be useful for ops that directly use low
|
||||
// level TF APIs. Remove/replace the following functions when TFRT native
|
||||
// ops are implemented.
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Create an abstract tensor handle from tensorflow::Tensor.
|
||||
virtual ImmediateExecutionTensorHandle* CreateLocalHandleFromTFTensor(
|
||||
tensorflow::Tensor& t, const char* d_name) = 0;
|
||||
|
||||
// Convert a TFRT TensorHandle to tensorflow::TensorHandle.
|
||||
virtual ImmediateExecutionTensorHandle* TFTensorHandleFromInterface(
|
||||
ImmediateExecutionTensorHandle* handle) = 0;
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Distributed runtime related functions.
|
||||
//===--------------------------------------------------------------------===//
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
// Set a distributed manager that helps set up, update, and check liveness
|
||||
// of member tasks in the cluster.
|
||||
virtual void SetDistributedManager(
|
||||
std::unique_ptr<ImmediateExecutionDistributedManager> distributed) = 0;
|
||||
|
||||
virtual ImmediateExecutionDistributedManager* GetDistributedManager() = 0;
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
|
||||
protected:
|
||||
explicit ImmediateExecutionContext(AbstractContextKind kind)
|
||||
: AbstractContext(kind) {}
|
||||
|
45 tensorflow/c/eager/immediate_execution_distributed_manager.h Normal file
@@ -0,0 +1,45 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EAGER_immediate_execution_distributed_manager_H_
|
||||
#define TENSORFLOW_C_EAGER_immediate_execution_distributed_manager_H_
|
||||
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
class ImmediateExecutionContext;
|
||||
class ServerDef;
|
||||
|
||||
class ImmediateExecutionDistributedManager {
|
||||
public:
|
||||
virtual ~ImmediateExecutionDistributedManager() {}
|
||||
|
||||
// Set up distributed execution environment on local and remote tasks.
|
||||
// When `reset_context` is true, initialize new cluster context state based on
|
||||
// cluster configurations provided in `server_def`; otherwise, update existing
|
||||
// context state with the provided `server_def`.
|
||||
// Contexts created on remote tasks will be considered stale and garbage
|
||||
// collected after `keep_alive_secs` of inactivity.
|
||||
virtual Status SetOrUpdateServerDef(const ServerDef& server_def,
|
||||
bool reset_context,
|
||||
int keep_alive_secs) = 0;
|
||||
|
||||
// Check if the remote task is alive.
|
||||
virtual Status CheckRemoteAlive(const std::string& remote_task_name,
|
||||
bool* is_alive) = 0;
|
||||
};
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_immediate_execution_distributed_manager_H_
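A rough usage sketch of this interface (the `ServerDef` fields shown are standard proto fields, but the surrounding wiring and values are assumptions, not taken from this change; `ctx` is assumed to be an `ImmediateExecutionContext*` on a non-mobile build):

// Sketch: point the eager context's distributed manager at a cluster.
tensorflow::ServerDef server_def;
server_def.set_job_name("worker");
server_def.set_task_index(0);
// ... populate server_def.mutable_cluster() with job and task addresses ...
TF_RETURN_IF_ERROR(ctx->GetDistributedManager()->SetOrUpdateServerDef(
    server_def, /*reset_context=*/true, /*keep_alive_secs=*/600));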
|
@ -76,6 +76,7 @@ cc_library(
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
|
@ -328,6 +328,17 @@ ParallelDevice::Execute(TFE_Context* context,
|
||||
const char* operation_name,
|
||||
const TFE_OpAttrs* attributes, int expected_max_outputs,
|
||||
TF_Status* status) const {
|
||||
std::vector<PartialTensorShape> expected_output_shapes(expected_max_outputs);
|
||||
return Execute(context, inputs, operation_name, attributes,
|
||||
expected_output_shapes, status);
|
||||
}
|
||||
|
||||
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
|
||||
ParallelDevice::Execute(
|
||||
TFE_Context* context, const std::vector<ParallelTensor*>& inputs,
|
||||
const char* operation_name, const TFE_OpAttrs* attributes,
|
||||
const std::vector<PartialTensorShape>& expected_output_shapes,
|
||||
TF_Status* status) const {
|
||||
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
|
||||
// Compute per-device per-output tensors
|
||||
std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
|
||||
@ -344,7 +355,7 @@ ParallelDevice::Execute(TFE_Context* context,
|
||||
}
|
||||
device_thread->StartExecute(context, operation_name,
|
||||
std::move(device_inputs), attributes,
|
||||
expected_max_outputs);
|
||||
expected_output_shapes.size());
|
||||
}
|
||||
StatusPtr first_bad_status(nullptr);
|
||||
for (int device_index = 0; device_index < underlying_devices_.size();
|
||||
@ -386,8 +397,15 @@ ParallelDevice::Execute(TFE_Context* context,
|
||||
for (int j = 0; j < underlying_devices_.size(); ++j) {
|
||||
components.push_back(std::move(per_device_output_tensors[j][i]));
|
||||
}
|
||||
per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
|
||||
*this, std::move(components), status));
|
||||
if (expected_output_shapes[i].IsFullyDefined()) {
|
||||
per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
|
||||
*this, std::move(components),
|
||||
absl::Span<const int64>(expected_output_shapes[i].dim_sizes()),
|
||||
status));
|
||||
} else {
|
||||
per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
|
||||
*this, std::move(components), status));
|
||||
}
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
}
|
||||
result.emplace(std::move(per_device_outputs));
|
||||
@ -396,9 +414,27 @@ ParallelDevice::Execute(TFE_Context* context,
|
||||
|
||||
std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
|
||||
const ParallelDevice& parallel_device,
|
||||
std::vector<TensorHandlePtr> components, TF_Status* status) {
|
||||
std::vector<TensorHandlePtr> components, absl::Span<const int64> shape,
|
||||
TF_Status* status) {
|
||||
TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
|
||||
std::vector<int64_t> shape(
|
||||
// Verify that the TensorHandle's shape and dtype match all of the component
|
||||
// shapes and dtypes.
|
||||
for (TensorHandlePtr& component : components) {
|
||||
if (TFE_TensorHandleDataType(component.get()) != dtype) {
|
||||
TF_SetStatus(status, TF_INTERNAL,
|
||||
"Components of a ParallelTensor must all have "
|
||||
"the same dtype");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
return std::unique_ptr<ParallelTensor>(
|
||||
new ParallelTensor(parallel_device, std::move(components), shape, dtype));
|
||||
}
|
||||
|
||||
std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
|
||||
const ParallelDevice& parallel_device,
|
||||
std::vector<TensorHandlePtr> components, TF_Status* status) {
|
||||
std::vector<int64> shape(
|
||||
TFE_TensorHandleNumDims(components[0].get(), status));
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
@ -406,11 +442,10 @@ std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
}
|
||||
|
||||
// Verify that the TensorHandle's shape and dtype match all of the component
|
||||
// shapes and dtypes.
|
||||
// Verify that the TensorHandle's shape matches all of the component shapes.
|
||||
for (TensorHandlePtr& component : components) {
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
|
||||
int64 tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
if (tensor_dim != shape[i]) {
|
||||
// TODO(allenl): Allow shapes to differ.
|
||||
@ -419,17 +454,10 @@ std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
|
||||
"the same shape");
|
||||
return nullptr;
|
||||
}
|
||||
if (TFE_TensorHandleDataType(component.get()) != dtype) {
|
||||
TF_SetStatus(status, TF_INTERNAL,
|
||||
"Components of a ParallelTensor must all have "
|
||||
"the same dtype");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return std::unique_ptr<ParallelTensor>(new ParallelTensor(
|
||||
parallel_device, std::move(components), std::move(shape), dtype));
|
||||
return FromTensorHandles(parallel_device, std::move(components),
|
||||
absl::Span<const int64>(shape), status);
|
||||
}
|
||||
|
||||
} // namespace parallel_device
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace parallel_device {
|
||||
@ -93,6 +94,15 @@ class ParallelDevice {
|
||||
const char* operation_name, const TFE_OpAttrs* attributes,
|
||||
int expected_max_outputs, TF_Status* status) const;
|
||||
|
||||
// Accepts inferred shapes for outputs, which if fully defined will avoid
|
||||
// querying the shapes of the underlying TensorHandles. This allows async
|
||||
// computation to continue without blocking.
|
||||
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Execute(
|
||||
TFE_Context* context, const std::vector<ParallelTensor*>& inputs,
|
||||
const char* operation_name, const TFE_OpAttrs* attributes,
|
||||
const std::vector<PartialTensorShape>& expected_output_shapes,
|
||||
TF_Status* status) const;
|
||||
|
||||
private:
|
||||
// A sequence of device names, indicating which devices replicated operations
|
||||
// are forwarded to.
|
||||
@ -117,10 +127,15 @@ class ParallelDevice {
|
||||
class ParallelTensor {
|
||||
public:
|
||||
// Construct a ParallelTensor from TensorHandles placed on the component
|
||||
// devices of a ParallelDevice.
|
||||
// devices of a ParallelDevice. Inspects `components` to determine a shape.
|
||||
static std::unique_ptr<ParallelTensor> FromTensorHandles(
|
||||
const ParallelDevice& parallel_device,
|
||||
std::vector<TensorHandlePtr> components, TF_Status* status);
|
||||
// Uses the provided shape without additional checks, which avoids blocking.
|
||||
static std::unique_ptr<ParallelTensor> FromTensorHandles(
|
||||
const ParallelDevice& parallel_device,
|
||||
std::vector<TensorHandlePtr> components, absl::Span<const int64> shape,
|
||||
TF_Status* status);
|
||||
|
||||
size_t num_tensors() const { return tensors_.size(); }
|
||||
TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }
|
||||
@ -132,10 +147,10 @@ class ParallelTensor {
|
||||
private:
|
||||
ParallelTensor(const ParallelDevice& device,
|
||||
std::vector<TensorHandlePtr> tensors,
|
||||
std::vector<int64_t> shape, const TF_DataType dtype)
|
||||
absl::Span<const int64> shape, const TF_DataType dtype)
|
||||
: device_(device),
|
||||
tensors_(std::move(tensors)),
|
||||
shape_(std::move(shape)),
|
||||
shape_(shape.begin(), shape.end()),
|
||||
dtype_(dtype) {}
|
||||
|
||||
const ParallelDevice& device_;
|
||||
|
@ -80,5 +80,41 @@ TEST(PARALLEL_DEVICE_LIB, TestOpWithError) {
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
}
|
||||
|
||||
TEST(PARALLEL_DEVICE_LIB, TestExplicitOutputShape) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
|
||||
TF_CreateConfig(
|
||||
/*xla*/ false,
|
||||
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
|
||||
2),
|
||||
TF_DeleteBuffer);
|
||||
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
|
||||
status.get());
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
std::vector<std::string> devices{
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
"/job:localhost/replica:0/task:0/device:CPU:1"};
|
||||
ParallelDevice parallel_device(std::move(devices));
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> handle_op(
|
||||
TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_OpSetAttrType(handle_op.get(), "dtype", TF_FLOAT);
|
||||
TFE_OpSetAttrShape(handle_op.get(), "shape", /*dims=*/nullptr, /*num_dims=*/0,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
auto outputs = parallel_device.Execute(
|
||||
context.get(), std::vector<ParallelTensor*>(), "VarHandleOp",
|
||||
TFE_OpGetAttrs(handle_op.get()), {PartialTensorShape({})}, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
const std::vector<std::unique_ptr<ParallelTensor>>& handles = *outputs;
|
||||
EXPECT_EQ(0, handles[0]->shape().size());
|
||||
}
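A small sketch of the other branch (an expected shape that is not fully defined), under the same assumptions as the test above:

  // Sketch: a partially known expected shape. IsFullyDefined() is false, so
  // ParallelDevice::Execute falls back to querying the per-device
  // TensorHandles for this output (and may block on async execution).
  PartialTensorShape partial;
  int64 dims[] = {-1, 4};
  ASSERT_TRUE(PartialTensorShape::MakePartialShape(dims, 2, &partial).ok());
  std::vector<PartialTensorShape> expected_shapes{partial};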
|
||||
|
||||
} // namespace parallel_device
|
||||
} // namespace tensorflow
|
||||
|
205 tensorflow/c/eager/unified_api_test.cc Normal file
@@ -0,0 +1,205 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/eager/unified_api_testutil.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
class UnifiedAPI
|
||||
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
TF_StatusPtr status(TF_NewStatus());
|
||||
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
|
||||
Status s = StatusFromTF_Status(status.get());
|
||||
CHECK_EQ(errors::OK, s.code()) << s.error_message();
|
||||
}
|
||||
|
||||
public:
|
||||
bool UseMlir() const { return strcmp(std::get<0>(GetParam()), "mlir") == 0; }
|
||||
bool UseFunction() const { return std::get<2>(GetParam()); }
|
||||
};
|
||||
|
||||
// Checks that inputs[0] is a scalar.
|
||||
Status TestScalarShape(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs) {
|
||||
PartialTensorShape shape;
|
||||
TF_RETURN_IF_ERROR(inputs[0]->Shape(&shape));
|
||||
if (shape.dims() != 0) {
|
||||
return errors::InvalidArgument(
|
||||
"Tensor expected to have scalar shape found rank: ", shape.dims());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TEST_P(UnifiedAPI, TestTensorShapeScalar) {
|
||||
if (UseFunction() && UseMlir()) {
|
||||
// TODO(b/173074167): Remove this.
|
||||
GTEST_SKIP() << "MlirTensor::Shape is not implemented yet.";
|
||||
}
|
||||
AbstractContextPtr ctx;
|
||||
{
|
||||
AbstractContext* ctx_raw = nullptr;
|
||||
Status s =
|
||||
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
ctx.reset(ctx_raw);
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
AbstractTensorHandle* x_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
|
||||
Status s = RunModel(TestScalarShape, ctx.get(),
|
||||
/*inputs=*/{x.get()},
|
||||
/*outputs=*/{},
|
||||
/*use_function=*/UseFunction());
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
}
|
||||
|
||||
// Checks that inputs[0] is a matrix with shape 2x4.
|
||||
Status TestTensorShape2x4(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs) {
|
||||
PartialTensorShape shape;
|
||||
TF_RETURN_IF_ERROR(inputs[0]->Shape(&shape));
|
||||
if (shape.dims() != 2) {
|
||||
return errors::InvalidArgument(
|
||||
"Tensor expected to have rank 2 found rank: ", shape.dims());
|
||||
}
|
||||
int64 dim_sizes[] = {2, 4};
|
||||
for (int i = 0; i < shape.dims(); i++) {
|
||||
if (shape.dim_size(i) != dim_sizes[i]) {
|
||||
return errors::InvalidArgument("Dim ", i, " expected to be of size ",
|
||||
dim_sizes[i],
|
||||
" found: ", shape.dim_size(i));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TEST_P(UnifiedAPI, TestTensorShape2x4) {
|
||||
if (UseFunction() && UseMlir()) {
|
||||
// TODO(b/173074167): Remove this.
|
||||
GTEST_SKIP() << "MlirTensor::Shape is not implemented yet.";
|
||||
}
|
||||
AbstractContextPtr ctx;
|
||||
{
|
||||
AbstractContext* ctx_raw = nullptr;
|
||||
Status s =
|
||||
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
ctx.reset(ctx_raw);
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
AbstractTensorHandle* x_raw = nullptr;
|
||||
float data[] = {0., 0., 0., 0., 0., 0., 0., 0};
|
||||
int64 dim_sizes[] = {2, 4};
|
||||
Status s =
|
||||
TestTensorHandleWithDimsFloat(ctx.get(), data, dim_sizes, 2, &x_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
|
||||
Status s = RunModel(TestTensorShape2x4, ctx.get(),
|
||||
/*inputs=*/{x.get()},
|
||||
/*outputs=*/{},
|
||||
/*use_function=*/UseFunction());
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
}
|
||||
|
||||
TEST_P(UnifiedAPI, TestUnknownShapeTracing) {
|
||||
if (!UseFunction()) {
|
||||
GTEST_SKIP() << "Tracing only test.";
|
||||
}
|
||||
if (UseMlir()) {
|
||||
// TODO(b/173074167): Remove this.
|
||||
GTEST_SKIP() << "MlirTensor::Shape is not implemented yet.";
|
||||
}
|
||||
AbstractContextPtr ctx(BuildFunction("test_fn"));
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
tracing::TracingTensorHandle* x_raw = nullptr;
|
||||
PartialTensorShape shape;
|
||||
Status s = dyn_cast<tracing::TracingContext>(ctx.get())->AddParameter(
|
||||
DT_FLOAT, shape, &x_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
|
||||
PartialTensorShape shape;
|
||||
Status s = x->Shape(&shape);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
ASSERT_TRUE(shape.unknown_rank());
|
||||
}
|
||||
|
||||
TEST_P(UnifiedAPI, TestPartialShapeTracing) {
|
||||
if (!UseFunction()) {
|
||||
GTEST_SKIP() << "Tracing only test.";
|
||||
}
|
||||
if (UseMlir()) {
|
||||
GTEST_SKIP() << "MlirTensor::Shape is not implemented yet.";
|
||||
}
|
||||
AbstractContextPtr ctx(BuildFunction("test_fn"));
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
tracing::TracingTensorHandle* x_raw = nullptr;
|
||||
PartialTensorShape shape;
|
||||
int64 dim_sizes[] = {2, -1};
|
||||
Status s = PartialTensorShape::MakePartialShape(dim_sizes, 2, &shape);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
s = dyn_cast<tracing::TracingContext>(ctx.get())->AddParameter(
|
||||
DT_FLOAT, shape, &x_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
|
||||
PartialTensorShape shape;
|
||||
Status s = x->Shape(&shape);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
ASSERT_FALSE(shape.unknown_rank());
|
||||
|
||||
ASSERT_EQ(2, shape.dim_size(0));
|
||||
ASSERT_EQ(-1, shape.dim_size(1));
|
||||
}
|
||||
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
UnifiedCppAPI, UnifiedAPI,
|
||||
::testing::Combine(::testing::Values("graphdef", "mlir"),
|
||||
/*tfrt*/ ::testing::Values(true, false),
|
||||
/*use_function*/ ::testing::Values(true, false)));
|
||||
#else
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
UnifiedCppAPI, UnifiedAPI,
|
||||
::testing::Combine(::testing::Values("graphdef", "mlir"),
|
||||
/*tfrt*/ ::testing::Values(false),
|
||||
/*use_function*/ ::testing::Values(true, false)));
|
||||
#endif
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
161 tensorflow/c/eager/unified_api_testutil.cc Normal file
@@ -0,0 +1,161 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/eager/unified_api_testutil.h"
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
AbstractContext* BuildFunction(const char* fn_name) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get());
|
||||
return unwrap(graph_ctx);
|
||||
}
|
||||
|
||||
Status CreateParamsForInputs(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
std::vector<AbstractTensorHandle*>* params) {
|
||||
tracing::TracingTensorHandle* handle = nullptr;
|
||||
for (auto input : inputs) {
|
||||
PartialTensorShape shape;
|
||||
TF_RETURN_IF_ERROR(input->Shape(&shape));
|
||||
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
|
||||
input->DataType(), shape, &handle));
|
||||
params->emplace_back(handle);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Runs `model` maybe wrapped in a function.
|
||||
Status RunModel(Model model, AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, bool use_function) {
|
||||
if (use_function) {
|
||||
const char* fn_name = "test_fn";
|
||||
std::unique_ptr<AbstractFunction> scoped_func;
|
||||
// Returning null tensors from a tf.function is not supported, so we keep
|
||||
// track of which indices in the model's outputs are nullptr in this set.
|
||||
// The FunctionDef only outputs the non-null tensors. We later pad the
|
||||
// function op outputs to have nullptrs at the `null_indices`.
|
||||
absl::flat_hash_set<int> null_indices;
|
||||
{
|
||||
AbstractContextPtr func_ctx(BuildFunction(fn_name));
|
||||
std::vector<AbstractTensorHandle*> func_inputs;
|
||||
func_inputs.reserve(inputs.size());
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs));
|
||||
std::vector<AbstractTensorHandle*> model_outputs;
|
||||
model_outputs.resize(outputs.size());
|
||||
TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
|
||||
absl::MakeSpan(model_outputs)));
|
||||
for (auto func_input : func_inputs) {
|
||||
func_input->Unref();
|
||||
}
|
||||
AbstractFunction* func = nullptr;
|
||||
OutputList output_list;
|
||||
output_list.expected_num_outputs = 0;
|
||||
output_list.outputs.reserve(outputs.size());
|
||||
for (int i = 0; i < model_outputs.size(); i++) {
|
||||
if (model_outputs[i]) {
|
||||
output_list.outputs.emplace_back(model_outputs[i]);
|
||||
output_list.expected_num_outputs += 1;
|
||||
} else {
|
||||
null_indices.insert(i);
|
||||
}
|
||||
}
|
||||
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
|
||||
->Finalize(&output_list, &func));
|
||||
scoped_func.reset(func);
|
||||
for (auto output : output_list.outputs) {
|
||||
output->Unref();
|
||||
}
|
||||
TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
|
||||
}
|
||||
|
||||
AbstractOperationPtr fn_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr));
|
||||
for (auto input : inputs) {
|
||||
TF_RETURN_IF_ERROR(fn_op->AddInput(input));
|
||||
}
|
||||
int retvals = outputs.size() - null_indices.size();
|
||||
std::vector<AbstractTensorHandle*> fn_outputs(retvals);
|
||||
TF_RETURN_IF_ERROR(fn_op->Execute(
|
||||
absl::Span<AbstractTensorHandle*>(fn_outputs.data(), fn_outputs.size()),
|
||||
&retvals));
|
||||
int skipped_indices = 0;
|
||||
for (int i = 0; i < outputs.size(); i++) {
|
||||
if (!null_indices.contains(i)) {
|
||||
outputs[i] = fn_outputs[i - skipped_indices];
|
||||
} else {
|
||||
skipped_indices += 1;
|
||||
}
|
||||
}
|
||||
TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name));
|
||||
return Status::OK();
|
||||
} else {
|
||||
return model(ctx, inputs, outputs);
|
||||
}
|
||||
}
|
||||
|
||||
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
|
||||
*ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get()));
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
TFE_DeleteContextOptions(opts);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
|
||||
AbstractTensorHandle** tensor) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_Context* eager_ctx =
|
||||
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value);
|
||||
*tensor =
|
||||
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float* data,
|
||||
int64* dims, int num_dims,
|
||||
AbstractTensorHandle** tensor) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_Context* eager_ctx =
|
||||
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
TFE_TensorHandle* input_eager = TestTensorHandleWithDimsFloat(
|
||||
eager_ctx, data, reinterpret_cast<int64_t*>(dims), num_dims);
|
||||
*tensor =
|
||||
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
61
tensorflow/c/eager/unified_api_testutil.h
Normal file
@ -0,0 +1,61 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_UNIFIED_API_TESTUTIL_H_
|
||||
#define TENSORFLOW_C_EAGER_UNIFIED_API_TESTUTIL_H_
|
||||
|
||||
#include "tensorflow/c/eager/abstract_context.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Builds and returns a `TracingContext` using the default tracing impl.
|
||||
AbstractContext* BuildFunction(const char* fn_name);
|
||||
|
||||
// Creates parameters (placeholders) in the tracing `ctx` using the shape and
|
||||
// dtype of `inputs`.
|
||||
Status CreateParamsForInputs(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
std::vector<AbstractTensorHandle*>* params);
|
||||
|
||||
// A callable that takes tensor inputs and returns zero or more tensor outputs.
|
||||
using Model = std::function<Status(AbstractContext*,
|
||||
absl::Span<AbstractTensorHandle* const>,
|
||||
absl::Span<AbstractTensorHandle*>)>;
|
||||
|
||||
// Runs `model` maybe wrapped in a function call op. This can be thought of as
// being equivalent to the following Python code.
|
||||
//
|
||||
// if use_function:
|
||||
// outputs = tf.function(model)(inputs)
|
||||
// else:
|
||||
// outputs = model(inputs)
|
||||
Status RunModel(Model model, AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, bool use_function);
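To make the wrapping above concrete, here is a minimal sketch of how these helpers could be combined in a test. The names `AddModel` and `RunAddExample` are hypothetical, and the sketch assumes the `ops::Add` wrapper from this change (plus the usual status and span headers) is available to the translation unit.

// Hypothetical model: adds its two inputs via the ops::Add wrapper.
Status AddModel(AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs) {
  return ops::Add(ctx, inputs, outputs, "Add");
}

// Runs the model traced into a FunctionDef (use_function=true) or eagerly.
Status RunAddExample(bool use_function) {
  AbstractContext* raw_ctx = nullptr;
  TF_RETURN_IF_ERROR(
      BuildImmediateExecutionContext(/*use_tfrt=*/false, &raw_ctx));
  AbstractContextPtr ctx(raw_ctx);

  AbstractTensorHandle* x = nullptr;
  AbstractTensorHandle* y = nullptr;
  TF_RETURN_IF_ERROR(TestScalarTensorHandle(ctx.get(), 2.0f, &x));
  TF_RETURN_IF_ERROR(TestScalarTensorHandle(ctx.get(), 3.0f, &y));

  std::vector<AbstractTensorHandle*> outputs(1);
  TF_RETURN_IF_ERROR(RunModel(AddModel, ctx.get(), {x, y},
                              absl::MakeSpan(outputs), use_function));

  // outputs[0] now holds the 5.0f result; the caller owns all handles.
  x->Unref();
  y->Unref();
  outputs[0]->Unref();
  return Status::OK();
}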
|
||||
|
||||
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);
|
||||
|
||||
// Get a Scalar TensorHandle with given float value.
|
||||
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
|
||||
AbstractTensorHandle** tensor);
|
||||
|
||||
// Get a Matrix TensorHandle with given float values and dimensions.
|
||||
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float* data,
|
||||
int64* dims, int num_dims,
|
||||
AbstractTensorHandle** tensor);
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_UNIFIED_API_TESTUTIL_H_
|
@ -21,10 +21,14 @@ limitations under the License.
|
||||
#include "tensorflow/c/experimental/ops/nn_ops.h"
|
||||
|
||||
using std::vector;
|
||||
using tensorflow::ops::Add;
|
||||
using tensorflow::ops::Conj;
|
||||
using tensorflow::ops::Div;
|
||||
using tensorflow::ops::DivNoNan;
|
||||
using tensorflow::ops::MatMul;
|
||||
using tensorflow::ops::Mul;
|
||||
using tensorflow::ops::Neg;
|
||||
using tensorflow::ops::OnesLike;
|
||||
using tensorflow::ops::SqrtGrad;
|
||||
|
||||
namespace tensorflow {
|
||||
@ -289,6 +293,117 @@ class MulGradientFunction : public GradientFunction {
|
||||
vector<AbstractTensorHandle*> forward_inputs;
|
||||
};
|
||||
|
||||
class Log1pGradientFunction : public GradientFunction {
|
||||
public:
|
||||
explicit Log1pGradientFunction(vector<AbstractTensorHandle*> f_inputs)
|
||||
: forward_inputs(f_inputs) {}
|
||||
|
||||
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||
vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
// TODO(vnvo2409): Add control dependency
|
||||
/* Given upstream grad U and a Log1p op: Y = log(1 + X), the gradients are:
|
||||
*
|
||||
* dX = U / (1 + X)
|
||||
*
|
||||
*/
|
||||
|
||||
AbstractTensorHandle* upstream_grad = grad_inputs[0];
|
||||
AbstractTensorHandle* X = forward_inputs[0];
|
||||
|
||||
grad_outputs->resize(1);
|
||||
vector<AbstractTensorHandle*> temp_outputs(1);
|
||||
|
||||
// Calculate conjugate of X
|
||||
std::string name = "Conj_Log1p_Grad_X";
|
||||
TF_RETURN_IF_ERROR(
|
||||
Conj(ctx->ctx, {X}, absl::MakeSpan(temp_outputs), name.c_str()));
|
||||
|
||||
AbstractTensorHandle* Conj_X = temp_outputs[0];
|
||||
|
||||
// Creates Ones
|
||||
name = "OnesLike_Log1p_Grad_X";
|
||||
TF_RETURN_IF_ERROR(OnesLike(ctx->ctx, {Conj_X},
|
||||
absl::MakeSpan(temp_outputs), name.c_str()));
|
||||
|
||||
AbstractTensorHandle* Ones_X = temp_outputs[0];
|
||||
|
||||
name = "Add_Log1p_Grad_X";
|
||||
// Calculate 1 + Conj(X)
|
||||
TF_RETURN_IF_ERROR(Add(ctx->ctx, {Ones_X, Conj_X},
|
||||
absl::MakeSpan(temp_outputs), name.c_str()));
|
||||
|
||||
AbstractTensorHandle* Conj_XP1 = temp_outputs[0];
|
||||
|
||||
name = "Div_Log1p_Grad_X";
|
||||
// Calculate U / (1 + Conj(X))
|
||||
TF_RETURN_IF_ERROR(Div(ctx->ctx, {upstream_grad, Conj_XP1},
|
||||
absl::MakeSpan(temp_outputs), name.c_str()));
|
||||
|
||||
(*grad_outputs)[0] = temp_outputs[0];
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
~Log1pGradientFunction() override {}
|
||||
|
||||
private:
|
||||
vector<AbstractTensorHandle*> forward_inputs;
|
||||
};
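As an aside, the closed form in the derivation above can be sanity-checked against a finite difference of std::log1p with a small standalone program (illustrative only, not part of this change):

#include <cassert>
#include <cmath>
#include <cstdio>

// For Y = log(1 + X), dY/dX = 1 / (1 + X), so with upstream grad U the
// backward value is U / (1 + X). Compare against a central difference.
int main() {
  const double x = 0.5, u = 2.0, eps = 1e-6;
  const double analytic = u / (1.0 + x);
  const double numeric =
      u * (std::log1p(x + eps) - std::log1p(x - eps)) / (2.0 * eps);
  std::printf("analytic=%.6f numeric=%.6f\n", analytic, numeric);
  assert(std::fabs(analytic - numeric) < 1e-4);
  return 0;
}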
|
||||
|
||||
class DivNoNanGradientFunction : public GradientFunction {
|
||||
public:
|
||||
explicit DivNoNanGradientFunction(vector<AbstractTensorHandle*> f_inputs,
|
||||
vector<AbstractTensorHandle*> f_outputs)
|
||||
: forward_inputs(f_inputs), forward_outputs(f_outputs) {}
|
||||
|
||||
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||
vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
// TODO(vnvo2409): Add shape broadcasting
|
||||
/* Given upstream grad U and a Div op: Z = X/Y, the gradients are:
|
||||
*
|
||||
* dX = U / Y
|
||||
* dY = -U*X / Y^2 = (X/Y) * -U / Y = -U*Z / Y
|
||||
*
|
||||
*/
|
||||
|
||||
AbstractTensorHandle* upstream_grad = grad_inputs[0];
|
||||
AbstractTensorHandle* Y = forward_inputs[1];
|
||||
AbstractTensorHandle* Z = forward_outputs[0];
|
||||
|
||||
grad_outputs->resize(2);
|
||||
vector<AbstractTensorHandle*> temp_outputs(1);
|
||||
|
||||
// Calculate dX = U / Y
|
||||
std::string name = "Div_Grad_X";
|
||||
TF_RETURN_IF_ERROR(DivNoNan(ctx->ctx, {upstream_grad, Y},
|
||||
absl::MakeSpan(temp_outputs), name.c_str()));
|
||||
|
||||
(*grad_outputs)[0] = temp_outputs[0];
|
||||
|
||||
// Calculate dY = -U*Z / Y
|
||||
name = "Neg_Div_Grad_Y";
|
||||
TF_RETURN_IF_ERROR(Neg(ctx->ctx, {upstream_grad},
|
||||
absl::MakeSpan(temp_outputs), name.c_str())); // -U
|
||||
AbstractTensorHandle* MinusU = temp_outputs[0];
|
||||
|
||||
name = "Mul_Div_Grad_Y";
|
||||
TF_RETURN_IF_ERROR(Mul(ctx->ctx, {MinusU, Z}, absl::MakeSpan(temp_outputs),
|
||||
name.c_str())); // -U*Z
|
||||
AbstractTensorHandle* UZ = temp_outputs[0];
|
||||
|
||||
name = "Div_Grad_Y";
|
||||
TF_RETURN_IF_ERROR(DivNoNan(ctx->ctx, {UZ, Y}, absl::MakeSpan(temp_outputs),
|
||||
name.c_str())); // -U*Z / Y
|
||||
|
||||
(*grad_outputs)[1] = temp_outputs[0];
|
||||
return Status::OK();
|
||||
}
|
||||
~DivNoNanGradientFunction() override {}
|
||||
|
||||
private:
|
||||
vector<AbstractTensorHandle*> forward_inputs;
|
||||
vector<AbstractTensorHandle*> forward_outputs;
|
||||
};
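The two DivNoNan gradient expressions above, and the rewrite of dY in terms of Z, can be checked the same way with a standalone snippet (illustrative only):

#include <cassert>
#include <cmath>
#include <cstdio>

// For Z = X / Y: dX = U / Y and dY = -U * X / Y^2, which equals -U * Z / Y
// as used by the gradient function.
int main() {
  const double x = 3.0, y = 4.0, u = 2.0;
  const double z = x / y;
  const double dx = u / y;
  const double dy_direct = -u * x / (y * y);
  const double dy_via_z = -u * z / y;
  std::printf("dx=%.6f dy_direct=%.6f dy_via_z=%.6f\n", dx, dy_direct,
              dy_via_z);
  assert(std::fabs(dy_direct - dy_via_z) < 1e-12);
  return 0;
}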
|
||||
|
||||
} // namespace
|
||||
|
||||
BackwardFunction* AddRegisterer(const ForwardOperation& op) {
|
||||
@ -354,5 +469,23 @@ BackwardFunction* MulRegisterer(const ForwardOperation& op) {
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
}
|
||||
|
||||
BackwardFunction* Log1pRegisterer(const ForwardOperation& op) {
|
||||
// For ops with a single output, the gradient function is not called if there
|
||||
// is no incoming gradient. So we do not need to worry about creating zero
// grads in this case.
|
||||
auto gradient_function = new Log1pGradientFunction(op.inputs);
|
||||
auto default_gradients = new PassThroughDefaultGradients(op);
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
}
|
||||
|
||||
BackwardFunction* DivNoNanRegisterer(const ForwardOperation& op) {
|
||||
// For ops with a single output, the gradient function is not called if there
|
||||
// is no incoming gradient. So we do not need to worry about creating zero
// grads in this case.
|
||||
auto gradient_function = new DivNoNanGradientFunction(op.inputs, op.outputs);
|
||||
auto default_gradients = new PassThroughDefaultGradients(op);
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
}
|
||||
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
|
@ -27,6 +27,8 @@ BackwardFunction* SqrtRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* NegRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* SubRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* MulRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* Log1pRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* DivNoNanRegisterer(const ForwardOperation& op);
|
||||
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
|
@ -81,5 +81,17 @@ Status ExpandDims(AbstractContext* ctx,
|
||||
return op->Execute(outputs, &num_retvals);
|
||||
}
|
||||
|
||||
Status OnesLike(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
AbstractOperationPtr op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(op->Reset("OnesLike", /*raw_device_name=*/nullptr));
|
||||
TF_RETURN_IF_ERROR(MaybeSetOpName(op.get(), name));
|
||||
TF_RETURN_IF_ERROR(op->AddInput(inputs[0]));
|
||||
|
||||
int num_retvals = 1;
|
||||
return op->Execute(outputs, &num_retvals);
|
||||
}
|
||||
|
||||
} // namespace ops
|
||||
} // namespace tensorflow
|
||||
|
@ -42,6 +42,10 @@ Status ExpandDims(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
|
||||
Status OnesLike(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
|
||||
} // namespace ops
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -44,8 +44,18 @@ Status Conj(AbstractContext* ctx,
|
||||
if (DataTypeIsFloating(BaseType(dtype)) ||
|
||||
DataTypeIsInteger(BaseType(dtype))) {
|
||||
TF_RETURN_IF_ERROR(Identity(ctx, inputs, outputs, name));
|
||||
} else if (DataTypeIsComplex(BaseType(dtype)) ||
|
||||
BaseType(dtype) == DT_VARIANT) {
|
||||
AbstractOperationPtr conj_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(conj_op->Reset("Conj", /*raw_device_name=*/nullptr));
|
||||
TF_RETURN_IF_ERROR(MaybeSetOpName(conj_op.get(), name));
|
||||
TF_RETURN_IF_ERROR(conj_op->AddInput(inputs[0]));
|
||||
|
||||
int num_retvals = 1;
|
||||
TF_RETURN_IF_ERROR(conj_op->Execute(outputs, &num_retvals));
|
||||
} else {
|
||||
return errors::Unimplemented("Conj does not support complex types yet.");
|
||||
return errors::InvalidArgument(
|
||||
"Expected numeric or variant tensor, got dtype ", dtype);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -118,6 +128,19 @@ Status Sum(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Div(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
AbstractOperationPtr div_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(div_op->Reset("Div", /*raw_device_name=*/nullptr));
|
||||
TF_RETURN_IF_ERROR(MaybeSetOpName(div_op.get(), name));
|
||||
TF_RETURN_IF_ERROR(div_op->AddInput(inputs[0])); // x
|
||||
TF_RETURN_IF_ERROR(div_op->AddInput(inputs[1])); // y
|
||||
|
||||
int num_retvals = 1;
|
||||
TF_RETURN_IF_ERROR(div_op->Execute(outputs, &num_retvals)); // z = x / y
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DivNoNan(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
@ -172,5 +195,18 @@ Status SqrtGrad(AbstractContext* ctx,
|
||||
return s;
|
||||
}
|
||||
|
||||
Status Log1p(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
AbstractOperationPtr log1p_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(log1p_op->Reset("Log1p", /*raw_device_name=*/nullptr));
|
||||
TF_RETURN_IF_ERROR(MaybeSetOpName(log1p_op.get(), name));
|
||||
TF_RETURN_IF_ERROR(log1p_op->AddInput(inputs[0]));
|
||||
|
||||
int num_retvals = 1;
|
||||
Status s = log1p_op->Execute(outputs, &num_retvals);
|
||||
return s;
|
||||
}
|
||||
|
||||
} // namespace ops
|
||||
} // namespace tensorflow
|
||||
|
@ -44,6 +44,9 @@ Status Sum(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
Status Sub(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
|
||||
Status Div(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
|
||||
Status DivNoNan(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
@ -59,6 +62,10 @@ Status SqrtGrad(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
|
||||
Status Log1p(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
|
||||
} // namespace ops
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -115,8 +115,8 @@ genrule(
|
||||
# have control of the full GPU.
|
||||
cmd = "CUDA_VISIBLE_DEVICES='' " +
|
||||
"$(location :make_test_graphs) --out_dir $(@D)",
|
||||
exec_tools = [":make_test_graphs"],
|
||||
tags = ["manual"],
|
||||
tools = [":make_test_graphs"],
|
||||
)
|
||||
|
||||
tf_library(
|
||||
|
@ -127,7 +127,7 @@ def tf_library(
|
||||
"$(location " + tfcompile_tool + ")" +
|
||||
" --config=$(location " + config + ")" +
|
||||
" --dump_fetch_nodes > $@"),
|
||||
exec_tools = [tfcompile_tool],
|
||||
tools = [tfcompile_tool],
|
||||
# Run tfcompile on the build host, rather than forge, since it's
|
||||
# typically way faster on the local machine.
|
||||
local = 1,
|
||||
@ -162,7 +162,7 @@ def tf_library(
|
||||
"//tensorflow/python/tools:freeze_graph)" +
|
||||
freeze_args
|
||||
),
|
||||
exec_tools = ["//tensorflow/python/tools:freeze_graph"],
|
||||
tools = ["//tensorflow/python/tools:freeze_graph"],
|
||||
tags = tags,
|
||||
)
|
||||
tfcompile_graph = freeze_file
|
||||
@ -242,7 +242,7 @@ def tf_library(
|
||||
" --out_function_object=$(@D)/" + function_object_file +
|
||||
" " + flags + " " + profiling_flag + " " + mlir_flag + " " + traceme_flag
|
||||
),
|
||||
exec_tools = [tfcompile_tool],
|
||||
tools = [tfcompile_tool],
|
||||
visibility = visibility,
|
||||
testonly = testonly,
|
||||
# Run tfcompile on the build host since it's typically faster on the
|
||||
@ -281,7 +281,7 @@ def tf_library(
|
||||
" --out_session_module=$(@D)/" + session_module_pb +
|
||||
" " + flags
|
||||
),
|
||||
exec_tools = [tfcompile_tool],
|
||||
tools = [tfcompile_tool],
|
||||
visibility = visibility,
|
||||
testonly = testonly,
|
||||
local = 1,
|
||||
|
@ -508,7 +508,7 @@ cc_library(
|
||||
":flags",
|
||||
":jit_compilation_passes",
|
||||
"//tensorflow/compiler/jit/kernels:xla_ops_no_jit_rewrite_registration",
|
||||
"//tensorflow/compiler/mlir:mlir_bridge_rollout_policy",
|
||||
"//tensorflow/compiler/tf2xla:mlir_bridge_pass",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_op_registry",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
|
@ -115,7 +115,7 @@ xla::StatusOr<std::string> GetCompilerIr(
|
||||
|
||||
xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
|
||||
XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
constant_arg_indices, inputs, variable_infos);
|
||||
constant_arg_indices, inputs, variable_infos, dev);
|
||||
TF_RETURN_IF_ERROR(args.status());
|
||||
|
||||
switch (stage) {
|
||||
|
@ -206,8 +206,9 @@ static Status CompileToLocalExecutable(
|
||||
may_alias_resource_update;
|
||||
|
||||
xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
|
||||
XlaComputationLaunchContext::BuildXlaCompilerArguments(constants, inputs,
|
||||
variable_infos);
|
||||
XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
constants, inputs, variable_infos,
|
||||
static_cast<Device*>(ctx->device()));
|
||||
TF_RETURN_IF_ERROR(args.status());
|
||||
return cache->Compile(options, function, *args, compile_options,
|
||||
lazy ? XlaCompilationCache::CompileMode::kLazy
|
||||
@ -246,8 +247,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
se::Stream* stream =
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
|
||||
|
||||
VLOG(1) << "Executing XLA Computation...";
|
||||
|
||||
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
|
||||
se::DeviceMemoryAllocator* allocator = GetAllocator(
|
||||
&tf_allocator_adapter, ctx->device(),
|
||||
|
@ -1990,6 +1990,8 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
|
||||
"StatelessCase",
|
||||
"StatelessIf",
|
||||
"StatelessMultinomial",
|
||||
"StatelessRandomGetAlg",
|
||||
"StatelessRandomGetKeyCounter",
|
||||
"StatelessRandomGetKeyCounterAlg",
|
||||
"StatelessRandomNormal",
|
||||
"StatelessRandomNormalV2",
|
||||
@ -2040,6 +2042,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
|
||||
"UnsortedSegmentSum",
|
||||
"VarIsInitializedOp",
|
||||
"VariableShape",
|
||||
"Where",
|
||||
"While",
|
||||
"XlaBroadcastHelper",
|
||||
"XlaConv",
|
||||
|
@ -140,6 +140,7 @@ XlaCompilationCache::BuildSignature(
|
||||
for (const XlaCompiler::Argument& arg : args) {
|
||||
switch (arg.kind) {
|
||||
case XlaCompiler::Argument::kConstant:
|
||||
case XlaCompiler::Argument::kConstantResource:
|
||||
signature.arg_values.push_back(arg.constant_value);
|
||||
break;
|
||||
case XlaCompiler::Argument::kParameter:
|
||||
@ -288,7 +289,7 @@ Status XlaCompilationCache::CompileSingleOp(
|
||||
const ConfigProto* config = ctx->function_library()->config_proto();
|
||||
// TODO(b/171039585): Support tf.VarIsInitializedOp using MLIR.
|
||||
bool use_mlir = config &&
|
||||
GetMlirBridgeRolloutPolicy(*config) ==
|
||||
GetMlirBridgeRolloutPolicy(*graph, *config) ==
|
||||
MlirBridgeRolloutPolicy::kEnabledByUser &&
|
||||
node_def.op() != "VarIsInitializedOp";
|
||||
#ifdef LIBTPU_ON_GCE
|
||||
|
@ -153,7 +153,8 @@ Status XlaCompileOnDemandOp::Compile(
|
||||
ctx, variables_indices, variable_infos, variable_args));
|
||||
|
||||
args = XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
constant_input_indices, inputs, variable_infos);
|
||||
constant_input_indices, inputs, variable_infos,
|
||||
static_cast<Device*>(ctx->device()));
|
||||
TF_RETURN_IF_ERROR(args.status());
|
||||
}
|
||||
|
||||
|
@ -14,7 +14,6 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/compiler/jit/xla_kernel_creator.h"
|
||||
|
||||
#include "tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
@ -23,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
|
||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||
#include "tensorflow/compiler/tf2xla/mlir_bridge_pass.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
@ -89,10 +89,21 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
|
||||
// Make sure that kernels have been registered on the JIT device.
|
||||
XlaOpRegistry::RegisterCompilationKernels();
|
||||
|
||||
// Get function body, constant args, and resource args.
|
||||
NameAttrList function;
|
||||
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
|
||||
const FunctionBody* fbody = nullptr;
|
||||
std::vector<int> constant_arg_indices;
|
||||
std::vector<int> resource_arg_indices;
|
||||
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
|
||||
flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));
|
||||
|
||||
// Only check for compilability if the MLIR bridge is not enabled.
|
||||
MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy(absl::nullopt);
|
||||
if (policy == MlirBridgeRolloutPolicy::kDisabledByUser ||
|
||||
policy == MlirBridgeRolloutPolicy::kDisabledAfterGraphAnalysis) {
|
||||
absl::optional<ConfigProto> config_proto;
|
||||
if (flr->config_proto()) {
|
||||
config_proto = *flr->config_proto();
|
||||
}
|
||||
if (!IsMlirBridgePassEnabled(*fbody->graph, config_proto)) {
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
|
||||
if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
|
||||
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
|
||||
@ -121,15 +132,6 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
|
||||
}
|
||||
}
|
||||
|
||||
// Get function body, constant args, and resource args.
|
||||
NameAttrList function;
|
||||
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
|
||||
const FunctionBody* fbody = nullptr;
|
||||
std::vector<int> constant_arg_indices;
|
||||
std::vector<int> resource_arg_indices;
|
||||
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
|
||||
flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));
|
||||
|
||||
MemoryTypeVector input_memory_types =
|
||||
GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices);
|
||||
MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody);
|
||||
|
@ -449,15 +449,14 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
auto transfer_manager,
|
||||
xla::TransferManager::GetForPlatform(stream->parent()->platform()));
|
||||
|
||||
xla::Shape output_host_shape = output.on_host_shape();
|
||||
xla::Shape output_device_shape = output.on_device_shape();
|
||||
TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
|
||||
stream, &output, &output_host_shape, &output_device_shape));
|
||||
stream, &output, &output_device_shape));
|
||||
|
||||
output.set_shapes(output_host_shape, output_device_shape);
|
||||
output.set_shapes(output_device_shape, output_device_shape);
|
||||
for (int i = 0; i < ctx->num_outputs(); ++i) {
|
||||
const xla::Shape& subshape =
|
||||
xla::ShapeUtil::GetSubshape(output_host_shape, {i});
|
||||
xla::ShapeUtil::GetSubshape(output_device_shape, {i});
|
||||
TensorShape shape;
|
||||
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
|
||||
output_tensor_shapes.push_back(shape);
|
||||
@ -564,11 +563,26 @@ xla::StatusOr<std::vector<XlaCompiler::Argument>>
|
||||
XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
absl::Span<int const> must_be_constant_idxs,
|
||||
absl::Span<const Tensor* const> inputs,
|
||||
absl::Span<VariableInfo const> variable_args) {
|
||||
absl::Span<VariableInfo const> variable_args, Device* device) {
|
||||
CHECK(absl::c_is_sorted(must_be_constant_idxs));
|
||||
std::vector<XlaCompiler::Argument> out;
|
||||
out.resize(inputs.size());
|
||||
|
||||
// TODO(cheshire): Avoid duplication with framework/op_kernel.h
|
||||
DeviceContext* device_context = nullptr;
|
||||
TF_RETURN_IF_ERROR(device->TryGetDeviceContext(&device_context));
|
||||
bool using_default_context = false;
|
||||
auto cleanup = xla::MakeCleanup([&] {
|
||||
if (device_context != nullptr && !using_default_context) {
|
||||
device_context->Unref();
|
||||
}
|
||||
});
|
||||
if (device_context == nullptr) {
|
||||
using_default_context = true;
|
||||
auto* dev_info = device->tensorflow_gpu_device_info();
|
||||
if (dev_info) device_context = dev_info->default_context;
|
||||
}
|
||||
|
||||
absl::flat_hash_map<int, const VariableInfo*> variable_info_lookup;
|
||||
for (const VariableInfo& info : variable_args) {
|
||||
CHECK(!info.var() || info.lock_held())
|
||||
@ -581,18 +595,7 @@ XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
const Tensor* input = inputs[input_num];
|
||||
|
||||
XlaCompiler::Argument& arg = out[input_num];
|
||||
if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
|
||||
// Handles compile-time constants.
|
||||
|
||||
// TODO(b/157241314): Support constants located in resource variables.
|
||||
TF_RET_CHECK(input->dtype() != DT_RESOURCE)
|
||||
<< "tf2xla bridge does not support must-be-constants located in "
|
||||
"resource variables; try moving them to a tensor";
|
||||
arg.kind = XlaCompiler::Argument::kConstant;
|
||||
arg.type = input->dtype();
|
||||
arg.shape = input->shape();
|
||||
arg.constant_value = *input;
|
||||
} else if (variable_info_lookup.count(input_num)) {
|
||||
if (variable_info_lookup.count(input_num)) {
|
||||
// Handles resource variables.
|
||||
TF_RET_CHECK(input->dtype() == DT_RESOURCE);
|
||||
const VariableInfo& variable = *variable_info_lookup[input_num];
|
||||
@ -613,6 +616,25 @@ XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
arg.type = DT_INVALID;
|
||||
arg.shape = TensorShape();
|
||||
}
|
||||
|
||||
if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
|
||||
TF_RET_CHECK(variable.var() && variable.var()->is_initialized);
|
||||
const Tensor* value = variable.var()->tensor();
|
||||
Tensor value_on_host(value->dtype(), value->shape());
|
||||
if (!device_context) {
|
||||
value_on_host = *value;
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(device_context->CopyDeviceTensorToCPUSync(
|
||||
value, "", device, &value_on_host));
|
||||
}
|
||||
arg.kind = XlaCompiler::Argument::kConstantResource;
|
||||
arg.constant_value = value_on_host;
|
||||
}
|
||||
} else if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
|
||||
arg.kind = XlaCompiler::Argument::kConstant;
|
||||
arg.type = input->dtype();
|
||||
arg.shape = input->shape();
|
||||
arg.constant_value = *input;
|
||||
} else {
|
||||
// Normal inputs.
|
||||
TF_RET_CHECK(input->dtype() != DT_RESOURCE);
|
||||
|
@ -143,7 +143,8 @@ class XlaComputationLaunchContext {
|
||||
static xla::StatusOr<std::vector<XlaCompiler::Argument>>
|
||||
BuildXlaCompilerArguments(absl::Span<int const> must_be_constant_idxs,
|
||||
absl::Span<const Tensor* const> inputs,
|
||||
absl::Span<VariableInfo const> variable_args);
|
||||
absl::Span<VariableInfo const> variable_args,
|
||||
Device* device);
|
||||
|
||||
// Add all inputs within `ctx` as XLA arguments (returned by arguments()).
|
||||
// `variables` is a map from TensorFlow argument number to resource variable.
|
||||
|
@ -3,7 +3,11 @@
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "filegroup")
|
||||
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_cc_binary",
|
||||
"tf_cc_test",
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
@ -126,12 +130,14 @@ cc_library(
|
||||
srcs = ["mlir_graph_optimization_pass.cc"],
|
||||
hdrs = ["mlir_graph_optimization_pass.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir:mlir_bridge_rollout_policy",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
|
||||
"//tensorflow/compiler/mlir/tensorflow:device_util",
|
||||
"//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
@ -198,11 +204,22 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/jit:flags",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "mlir_graph_optimization_pass_test",
|
||||
srcs = ["mlir_graph_optimization_pass_test.cc"],
|
||||
deps = [
|
||||
":mlir_graph_optimization_pass",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "litfiles",
|
||||
srcs = glob(["runlit*py"]),
|
||||
|
@ -51,10 +51,10 @@ filegroup(
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td",
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td",
|
||||
"@llvm-project//mlir:OpBaseTdFiles",
|
||||
"@llvm-project//mlir:SideEffectTdFiles",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/CopyOpInterface.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/ViewLikeInterface.td",
|
||||
],
|
||||
)
|
||||
@ -464,7 +464,6 @@ cc_library(
|
||||
":hlo",
|
||||
":lhlo",
|
||||
":lhlo_gpu",
|
||||
"@llvm-project//mlir:IR",
|
||||
],
|
||||
)
|
||||
|
||||
@ -500,7 +499,6 @@ cc_library(
|
||||
"@llvm-project//mlir:SCFDialect",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
)
|
||||
|
||||
@ -639,12 +637,10 @@ cc_library(
|
||||
":lhlo",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:Affine",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:LinalgTransforms",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:SCFDialect",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:TransformUtils",
|
||||
"@llvm-project//mlir:ViewLikeInterface",
|
||||
],
|
||||
@ -668,6 +664,7 @@ cc_library(
|
||||
"@llvm-project//mlir:Shape",
|
||||
"@llvm-project//mlir:ShapeTransforms",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:StandardOpsTransforms",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
@ -702,12 +699,10 @@ cc_library(
|
||||
deps = [
|
||||
":cycle_detector",
|
||||
":hlo",
|
||||
"@llvm-project//llvm:Core",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:TransformUtils",
|
||||
],
|
||||
alwayslink = 1,
|
||||
@ -738,7 +733,6 @@ cc_library(
|
||||
deps = [
|
||||
":hlo",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
@ -759,7 +753,6 @@ cc_library(
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:TransformUtils",
|
||||
],
|
||||
alwayslink = 1,
|
||||
@ -777,8 +770,6 @@ cc_library(
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
alwayslink = 1,
|
||||
@ -797,7 +788,6 @@ cc_library(
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
alwayslink = 1,
|
||||
@ -838,11 +828,9 @@ cc_library(
|
||||
":hlo",
|
||||
":lower_complex_inc_gen",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
alwayslink = 1,
|
||||
|
@ -43,18 +43,6 @@ list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules")
|
||||
|
||||
option(MHLO_BUILD_EMBEDDED "Build MHLO as part of another project" OFF)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# MSVC defaults
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
if(MSVC)
|
||||
add_compile_options(
|
||||
$<$<CONFIG:>:/MD>
|
||||
$<$<CONFIG:Debug>:/MD>
|
||||
$<$<CONFIG:Release>:/MD>
|
||||
)
|
||||
endif()
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# MLIR/LLVM Configuration
|
||||
#-------------------------------------------------------------------------------
|
||||
|
@ -925,7 +925,7 @@ def HLO_CustomCallOp: HLO_Op<"custom_call", []>, BASE_HLO_CustomCallOp {
|
||||
DefaultValuedAttr<BoolAttr, "false">:$has_side_effect,
|
||||
DefaultValuedAttr<StrAttr, "">:$backend_config
|
||||
);
|
||||
let results = (outs HLO_Tensor);
|
||||
let results = (outs Variadic<HLO_Tensor>);
|
||||
let hasCustomHLOConverter = 1;
|
||||
}
|
||||
|
||||
|
@ -264,10 +264,11 @@ def LHLO_WhileOp: LHLO_Op<"while", [SameVariadicOperandSize]>,
|
||||
let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
|
||||
}
|
||||
|
||||
def LHLO_CustomCallOp : LHLO_Op<"custom_call", []>, BASE_HLO_CustomCallOp {
|
||||
def LHLO_CustomCallOp : LHLO_Op<"custom_call", [AttrSizedOperandSegments]>,
|
||||
BASE_HLO_CustomCallOp {
|
||||
let arguments = (ins
|
||||
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$args,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||
Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$output,
|
||||
StrAttr:$call_target_name,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$has_side_effect,
|
||||
DefaultValuedAttr<StrAttr, "">:$backend_config
|
||||
|
@ -1268,7 +1268,8 @@ class DynamicReshapeOpNotActuallyDynamic
|
||||
|
||||
void DynamicReshapeOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList& results, MLIRContext* context) {
|
||||
results.insert<DynamicReshapeOpNotActuallyDynamic>(context);
|
||||
results.insert<DynamicReshapeOpNotActuallyDynamic, ShapeOfDynamicReshape>(
|
||||
context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -28,3 +28,6 @@ def DynamicBroadcastToOwnShape_2 : Pat<
|
||||
(HLO_DynamicBroadcastInDimOp:$op $x, (Shape_ShapeOfOp $x), $attr),
|
||||
(replaceWithValue $x)>;
|
||||
|
||||
def ShapeOfDynamicReshape : Pat<
|
||||
(Shape_ShapeOfOp (HLO_DynamicReshapeOp $x, $shape)),
|
||||
(replaceWithValue $shape)>;
|
||||
|
@ -13,6 +13,12 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Enable the use of M_* math constants.
|
||||
// NOTE: this must be first in the file to ensure that if cmath is transitively
|
||||
// included by any other header it has the define set on first processing.
|
||||
// https://docs.microsoft.com/en-us/cpp/c-runtime-library/math-constants
|
||||
#define _USE_MATH_DEFINES
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/Shape/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
@ -87,6 +88,32 @@ Value InsertAlloc(Location loc, OpResult result,
|
||||
return alloc;
|
||||
}
|
||||
|
||||
/// Converts the results of the operation `op` to memref types and appends them
/// to the `results` vector.
|
||||
LogicalResult ConvertResults(Operation* op, SmallVectorImpl<Value>& results,
|
||||
ConversionPatternRewriter& rewriter) {
|
||||
for (auto result : llvm::enumerate(op->getResults())) {
|
||||
RankedTensorType resultType =
|
||||
result.value().getType().dyn_cast<RankedTensorType>();
|
||||
if (!resultType) return failure();
|
||||
|
||||
if (resultType.hasStaticShape()) {
|
||||
results.push_back(InsertAlloc(op->getLoc(), result.value(), &rewriter));
|
||||
continue;
|
||||
}
|
||||
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
|
||||
if (!shape_type_op) return failure();
|
||||
|
||||
SmallVector<Value, 1> results_shape;
|
||||
auto status = shape_type_op.reifyReturnTypeShapes(rewriter, results_shape);
|
||||
if (failed(status)) return failure();
|
||||
results.push_back(
|
||||
InsertDynamicAllocAndDealloc(op->getLoc(), result.value(),
|
||||
results_shape[result.index()], &rewriter));
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
template <typename HloOpTy>
|
||||
class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
|
||||
public:
|
||||
@ -95,29 +122,8 @@ class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
|
||||
HloOpTy hloOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
Operation* op = hloOp.getOperation();
|
||||
const auto& original_results = op->getResults();
|
||||
SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
|
||||
for (auto result : llvm::enumerate(original_results)) {
|
||||
RankedTensorType resultType =
|
||||
result.value().getType().dyn_cast<RankedTensorType>();
|
||||
if (!resultType) {
|
||||
return failure();
|
||||
}
|
||||
if (resultType.hasStaticShape()) {
|
||||
buffer_args.push_back(
|
||||
InsertAlloc(op->getLoc(), result.value(), &rewriter));
|
||||
} else {
|
||||
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
|
||||
if (!shape_type_op) return failure();
|
||||
|
||||
SmallVector<Value, 1> results_shape;
|
||||
auto status =
|
||||
shape_type_op.reifyReturnTypeShapes(rewriter, results_shape);
|
||||
if (failed(status)) return failure();
|
||||
buffer_args.push_back(InsertDynamicAllocAndDealloc(
|
||||
op->getLoc(), result.value(), results_shape.front(), &rewriter));
|
||||
}
|
||||
}
|
||||
if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
|
||||
rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
|
||||
buffer_args, op->getAttrs());
|
||||
rewriter.replaceOp(
|
||||
@ -139,28 +145,8 @@ class HloToLhloOpConverter<mhlo::DotOp> : public BaseOpConversion<mhlo::DotOp> {
|
||||
mhlo::DotOp hloOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
Operation* op = hloOp.getOperation();
|
||||
const auto& original_results = op->getResults();
|
||||
SmallVector<Value, 2> buffer_args(operands.begin(), operands.end());
|
||||
for (auto result : llvm::enumerate(original_results)) {
|
||||
RankedTensorType resultType =
|
||||
result.value().getType().dyn_cast<RankedTensorType>();
|
||||
if (!resultType) {
|
||||
return failure();
|
||||
}
|
||||
if (resultType.hasStaticShape()) {
|
||||
buffer_args.push_back(
|
||||
InsertAlloc(op->getLoc(), result.value(), &rewriter));
|
||||
} else {
|
||||
SmallVector<Value, 1> results_shape;
|
||||
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
|
||||
if (!shape_type_op) return failure();
|
||||
if (failed(
|
||||
shape_type_op.reifyReturnTypeShapes(rewriter, results_shape)))
|
||||
return failure();
|
||||
buffer_args.push_back(InsertDynamicAllocAndDealloc(
|
||||
op->getLoc(), result.value(), results_shape.front(), &rewriter));
|
||||
}
|
||||
}
|
||||
if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
|
||||
|
||||
// TODO(silvasean): Move this helper to MLIR core.
|
||||
auto make_elements_attr = [&rewriter](ArrayRef<int64_t> integers) {
|
||||
@ -180,6 +166,32 @@ class HloToLhloOpConverter<mhlo::DotOp> : public BaseOpConversion<mhlo::DotOp> {
|
||||
}
|
||||
};
|
||||
|
||||
struct HloToLhloCustomCallOpConverter
|
||||
: public BaseOpConversion<mhlo::CustomCallOp> {
|
||||
public:
|
||||
using BaseOpConversion<mhlo::CustomCallOp>::BaseOpConversion;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
mhlo::CustomCallOp hloOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
Operation* op = hloOp.getOperation();
|
||||
SmallVector<Value, 2> buffer_args(operands.begin(), operands.end());
|
||||
if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
|
||||
|
||||
auto lhloOp = rewriter.create<lmhlo::CustomCallOp>(
|
||||
op->getLoc(), llvm::None, buffer_args, op->getAttrs());
|
||||
// Setup AttrSizedOperandSegments attribute to indicate number of operands
|
||||
// for args and outputs.
|
||||
const int32_t segments[2] = {static_cast<int32_t>(operands.size()),
|
||||
static_cast<int32_t>(op->getNumResults())};
|
||||
lhloOp.setAttr(lhloOp.getOperandSegmentSizeAttr(),
|
||||
rewriter.getI32VectorAttr(segments));
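// For example, a custom call lowered with 2 operand buffers and 1 result
// buffer would get operand_segment_sizes = [2, 1] (hypothetical counts, for
// illustration only).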
|
||||
|
||||
rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct HloToLhloDynamicBroadcastInDimOpConverter
|
||||
: public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> {
|
||||
public:
|
||||
@ -194,8 +206,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
|
||||
|
||||
Value transformed_operand =
|
||||
InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
|
||||
rewriter.create<lmhlo::BroadcastInDimOp>(
|
||||
loc, transformed_operand, resultBuffer, op.broadcast_dimensions());
|
||||
rewriter.create<lmhlo::CopyOp>(loc, transformed_operand, resultBuffer);
|
||||
|
||||
rewriter.replaceOp(op, {resultBuffer});
|
||||
|
||||
@ -211,48 +222,76 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
|
||||
auto loc = op.getLoc();
|
||||
auto operand_type = operand.getType().cast<MemRefType>();
|
||||
auto operand_shape = operand_type.getShape();
|
||||
auto operand_rank = operand_type.getRank();
|
||||
|
||||
SmallVector<Value, 2> sizes, strides;
|
||||
sizes.reserve(operand_shape.size());
|
||||
strides.reserve(operand_shape.size());
|
||||
auto result_type = op.getType().cast<RankedTensorType>();
|
||||
auto result_rank = result_type.getRank();
|
||||
|
||||
Value zero = b->create<ConstantIndexOp>(loc, 0);
|
||||
Value one = b->create<ConstantIndexOp>(loc, 1);
|
||||
for (auto dim : llvm::enumerate(op.broadcast_dimensions())) {
|
||||
Value broadcast_dim_value =
|
||||
b->create<ConstantIndexOp>(loc, dim.value().getSExtValue());
|
||||
Value result_dim_size = b->create<ExtractElementOp>(
|
||||
loc, op.output_dimensions(), broadcast_dim_value);
|
||||
Value operand_dim_size =
|
||||
ShapedType::isDynamic(operand_shape[dim.index()])
|
||||
? b->create<DimOp>(loc, operand, dim.index()).getResult()
|
||||
: b->create<ConstantIndexOp>(loc, operand_shape[dim.index()])
|
||||
.getResult();
|
||||
|
||||
// TODO(pifon): Revisit if this cast is needed. Maybe we can use
|
||||
// tensor<index> for `output_dimensions` as well.
|
||||
// Compute a reversed scan product. Compute the stride for the dimensions so
|
||||
// far, working from minor to major dimensions. Additionally, save the
|
||||
// operand shape Values to use in the next loop.
|
||||
SmallVector<Value, 2> operand_strides(operand_rank, one);
|
||||
SmallVector<Value, 2> operand_sizes(operand_rank, one);
|
||||
Value stride_so_far = one;
|
||||
for (int i = operand_rank - 1; i >= 0; --i) {
|
||||
Value operand_dim_size =
|
||||
ShapedType::isDynamic(operand_shape[i])
|
||||
? b->create<DimOp>(loc, operand, i).getResult()
|
||||
: b->create<ConstantIndexOp>(loc, operand_shape[i]).getResult();
|
||||
operand_sizes[i] = operand_dim_size;
|
||||
|
||||
operand_strides[i] = stride_so_far;
|
||||
if (i > 0) {
|
||||
stride_so_far = b->create<MulIOp>(loc, stride_so_far, operand_dim_size);
|
||||
}
|
||||
}
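// Illustrative example of the reversed scan product above (hypothetical
// shape, not taken from this pass): an operand of shape [2, 3, 4] yields
// operand_strides = [12, 4, 1], i.e. each stride is the product of the sizes
// of all more-minor dimensions.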
|
||||
|
||||
SmallVector<Value, 2> sizes, strides;
|
||||
sizes.reserve(result_rank);
|
||||
strides.reserve(result_rank);
|
||||
|
||||
DenseMap<int, int> output_to_input_dim;
|
||||
for (auto dim : llvm::enumerate(op.broadcast_dimensions())) {
|
||||
output_to_input_dim[dim.value().getSExtValue()] = dim.index();
|
||||
}
|
||||
for (int i = 0; i < result_rank; ++i) {
|
||||
Value i_val = b->create<ConstantIndexOp>(loc, i);
|
||||
Value result_dim_size =
|
||||
b->create<ExtractElementOp>(loc, op.output_dimensions(), i_val);
|
||||
if (!result_dim_size.getType().isIndex()) {
|
||||
result_dim_size =
|
||||
b->create<IndexCastOp>(loc, result_dim_size, b->getIndexType());
|
||||
}
|
||||
sizes.push_back(result_dim_size);
|
||||
|
||||
auto it = output_to_input_dim.find(i);
|
||||
// If the rank of the output is greater than the rank of the input, i.e.
// there was no output dimension in the inverse broadcast_dimensions map,
// we also set the stride to 0 to emulate padding the shape with 1s and the
// corresponding expansion.
|
||||
if (it == output_to_input_dim.end()) {
|
||||
strides.push_back(zero);
|
||||
continue;
|
||||
}
|
||||
|
||||
// There can be two cases:
|
||||
// 1) Operand dim == result dim => expansion is not needed => stride := 1.
|
||||
// 1) Operand dim == result dim => expansion is not needed
//    => stride := flattened buffer stride.
|
||||
// 2) Operand dim < result dim => expansion is needed => stride := 0.
|
||||
Value is_expansion = b->create<CmpIOp>(loc, CmpIPredicate::slt,
|
||||
operand_dim_size, result_dim_size);
|
||||
strides.push_back(
|
||||
b->create<mlir::SelectOp>(loc, is_expansion, zero, one));
|
||||
|
||||
// Size of input dim can be set to the size of the corresponding output
|
||||
// dimension for both cases.
|
||||
sizes.push_back(result_dim_size);
|
||||
int dim = it->second;
|
||||
Value is_expansion = b->create<CmpIOp>(
|
||||
loc, CmpIPredicate::slt, operand_sizes[dim], result_dim_size);
|
||||
strides.push_back(b->create<mlir::SelectOp>(loc, is_expansion, zero,
|
||||
operand_strides[dim]));
|
||||
}
|
||||
|
||||
// Type-erased memref type with static rank, dynamic sizes and strides.
|
||||
SmallVector<int64_t, 2> dynamic_layout(operand_shape.size(),
|
||||
SmallVector<int64_t, 2> dynamic_layout(result_rank,
|
||||
MemRefType::kDynamicStrideOrOffset);
|
||||
SmallVector<int64_t, 2> dynamic_shape(operand_shape.size(),
|
||||
SmallVector<int64_t, 2> dynamic_shape(result_rank,
|
||||
MemRefType::kDynamicSize);
|
||||
auto type_erased_memref_type = MemRefType::get(
|
||||
dynamic_shape, operand_type.getElementType(),
|
||||
@ -517,11 +556,8 @@ struct HloLegalizeToLhlo
|
||||
ConversionTarget target(context);
|
||||
target.addLegalDialect<lmhlo::LmhloDialect>();
|
||||
target.addLegalDialect<StandardOpsDialect>();
|
||||
target.addLegalOp<ModuleOp>();
|
||||
target.addIllegalOp<mlir::TensorLoadOp>();
|
||||
target.addIllegalOp<mlir::TensorStoreOp>();
|
||||
target.addLegalOp<ModuleTerminatorOp>();
|
||||
target.addLegalOp<TensorFromElementsOp>();
|
||||
target.addIllegalDialect<mhlo::MhloDialect>();
|
||||
|
||||
BufferizeTypeConverter converter;
|
||||
@ -543,9 +579,8 @@ struct HloLegalizeToLhlo
|
||||
});
|
||||
|
||||
populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
|
||||
populateWithBufferizeOpConversionPatterns<mlir::ReturnOp, mlir::ReturnOp,
|
||||
lmhlo::CopyOp>(
|
||||
&context, converter, patterns);
|
||||
populateFuncOpTypeConversionPattern(patterns, &context, converter);
|
||||
populateCallOpTypeConversionPattern(patterns, &context, converter);
|
||||
populateShapeStructuralTypeConversionsAndLegality(&context, converter,
|
||||
patterns, target);
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
@ -560,6 +595,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
|
||||
OwningRewritePatternList* patterns) {
|
||||
// clang-format off
|
||||
patterns->insert<
|
||||
HloToLhloCustomCallOpConverter,
|
||||
HloToLhloDotGeneralOpConverter,
|
||||
HloToLhloDynamicBroadcastInDimOpConverter,
|
||||
HloToLhloDynamicReshapeConverter,
|
||||
@ -576,7 +612,6 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
|
||||
HloToLhloOpConverter<mhlo::ConvertOp>,
|
||||
HloToLhloOpConverter<mhlo::CopyOp>,
|
||||
HloToLhloOpConverter<mhlo::CosOp>,
|
||||
HloToLhloOpConverter<mhlo::CustomCallOp>,
|
||||
HloToLhloOpConverter<mhlo::DivOp>,
|
||||
HloToLhloOpConverter<mhlo::DotOp>,
|
||||
HloToLhloOpConverter<mhlo::ExpOp>,
|
||||
@ -607,7 +642,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
|
||||
HloToLhloReturnOpConverter,
|
||||
HloToLhloTensorLoadOpConverter,
|
||||
HloToLhloTensorStoreOpConverter
|
||||
>(context);
|
||||
>(*converter, context);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
|
@ -85,13 +85,66 @@ class LhloFuseLinalgPass
|
||||
if (!definingOp) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto viewLike = dyn_cast<ViewLikeOpInterface>(definingOp)) {
|
||||
auto alias = viewLike.getViewSource();
|
||||
if (result_buffers.insert(alias).second) {
|
||||
worklist.push_back(alias);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto tensor_load = dyn_cast<TensorLoadOp>(definingOp)) {
|
||||
auto alias = tensor_load.memref();
|
||||
if (result_buffers.insert(alias).second) {
|
||||
worklist.push_back(alias);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto tensor_to_memref = dyn_cast<TensorToMemrefOp>(definingOp)) {
|
||||
auto alias = tensor_to_memref.tensor();
|
||||
if (result_buffers.insert(alias).second) {
|
||||
worklist.push_back(alias);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto tensor_cast = dyn_cast<TensorCastOp>(definingOp)) {
|
||||
auto alias = tensor_cast.source();
|
||||
if (result_buffers.insert(alias).second) {
|
||||
worklist.push_back(alias);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto regionInterface =
|
||||
dyn_cast<RegionBranchOpInterface>(definingOp)) {
|
||||
for (Region& region : regionInterface.getOperation()->getRegions()) {
|
||||
// Only consider regions that can return to the parent region.
|
||||
SmallVector<RegionSuccessor, 2> successorRegions;
|
||||
regionInterface.getSuccessorRegions(region.getRegionNumber(),
|
||||
successorRegions);
|
||||
if (llvm::none_of(successorRegions, [&](auto successorRegion) {
|
||||
return successorRegion.isParent();
|
||||
}))
|
||||
continue;
|
||||
|
||||
// Iterate over all immediate terminators and record the values
|
||||
// corresponding to result_buffers of interest.
|
||||
for (Block& block : region) {
|
||||
if (block.empty()) continue;
|
||||
Operation& operation = block.back();
|
||||
if (!operation.hasTrait<OpTrait::ReturnLike>()) continue;
|
||||
auto idx = result.dyn_cast<OpResult>().getResultNumber();
|
||||
if (result_buffers.insert(operation.getOperand(idx)).second) {
|
||||
worklist.push_back(operation.getOperand(idx));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
MLIRContext* ctx = func.getContext();
|
||||
OpBuilder b(func);
|
||||
func.walk([&](linalg::GenericOp generic_op) {
|
||||
@ -114,10 +167,10 @@ class LhloFuseLinalgPass
|
||||
|
||||
// Fuse producers of tiled linalg ops.
|
||||
llvm::SmallDenseSet<Operation*> erase_set;
|
||||
SmallVector<Operation*, 8> linalg_ops;
|
||||
SmallVector<LinalgOp, 8> linalg_ops;
|
||||
func.walk([&](LinalgOp op) { linalg_ops.push_back(op); });
|
||||
for (auto* op : llvm::reverse(linalg_ops)) {
|
||||
for (unsigned id = 0, e = LinalgOp(op).getNumInputs(); id < e; ++id) {
|
||||
for (LinalgOp op : llvm::reverse(linalg_ops)) {
|
||||
for (unsigned id = 0, e = op.getNumInputs(); id < e; ++id) {
|
||||
linalg::Aliases aliases;
|
||||
linalg::LinalgDependenceGraph graph(aliases, linalg_ops);
|
||||
if (auto info = fuseProducerOfBuffer(b, op, id, graph)) {
|
||||
|
@ -50,6 +50,8 @@ class SinkConstantsToControlFlowPass
|
||||
} else if (auto if_op = llvm::dyn_cast<IfOp>(op)) {
|
||||
SinkToRegion(&if_op.true_branch());
|
||||
SinkToRegion(&if_op.false_branch());
|
||||
} else if (auto reduce_window_op = llvm::dyn_cast<ReduceWindowOp>(op)) {
|
||||
SinkToRegion(&reduce_window_op.body());
|
||||
} else if (auto sort_op = llvm::dyn_cast<SortOp>(op)) {
|
||||
SinkToRegion(&sort_op.comparator());
|
||||
}
|
||||
|
@ -575,6 +575,16 @@ func @dynamic_reshape_not_actually_dynamic(%arg0: tensor<4xf32>, %shape: tensor<
|
||||
return %0 : tensor<4x1xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @shape_of_dynamic_reshape
|
||||
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]]
|
||||
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9]+]]
|
||||
func @shape_of_dynamic_reshape(%arg0: tensor<*xf32>, %shape: tensor<2xindex>) -> tensor<2xindex> {
|
||||
// CHECK: return [[ARG1]]
|
||||
%0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
%1 = shape.shape_of %0 : tensor<?x?xf32> -> tensor<2xindex>
|
||||
return %1 : tensor<2xindex>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: do_not_dce_while_with_outfeed
|
||||
func @do_not_dce_while_with_outfeed(%arg0: tensor<i64>) -> tensor<i64> {
|
||||
// CHECK: mhlo.while
|
||||
|
@ -1,4 +1,6 @@
|
||||
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck %s
|
||||
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting \
|
||||
// RUN: -buffer-deallocation -split-input-file -cse %s -o - \
|
||||
// RUN: | FILECHECK_OPTS="" FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @attrs
|
||||
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
@ -153,64 +155,41 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
|
||||
|
||||
// -----
|
||||
|
||||
func @external_func() -> tensor<3xi64>
|
||||
|
||||
// CHECK: #[[MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
|
||||
// CHECK: #[[MAP:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + d2 * s2)>
|
||||
|
||||
// CHECK-LABEL: func @dyn_broadcast
func @dyn_broadcast(%operand: memref<?x?xf32>) {
// CHECK-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
func @dyn_broadcast(%operand: memref<?x?xf32>) -> index {
// CHECK-SAME: %[[OPERAND:.*]]: memref<?x?xf32>
%tensor_operand = tensor_load %operand : memref<?x?xf32>
%c1 = constant 1 : i64
%shape = tensor_from_elements %c1, %c1, %c1 : tensor<3xi64>
%tensor_result = "mhlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
} : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
// CHECK: %[[SHAPE:.*]] = tensor_from_elements
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
// CHECK: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<3xi64>
// CHECK: %[[IC1:.*]] = index_cast %[[EL1]] : i64 to index
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]][%[[C2]]] : tensor<3xi64>
// CHECK: %[[IC2:.*]] = index_cast %[[EL2]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[IC0]], %[[IC1]], %[[IC2]])

// CHECK: %[[C0_:.*]] = constant 0 : index
// CHECK: %[[C1_:.*]] = constant 1 : index
// CHECK: %[[C1__:.*]] = constant 1 : index
// CHECK: %[[EL1_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1__]]] : tensor<3xi64>
// CHECK: %[[C0___:.*]] = constant 0 : index
// CHECK: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], %[[C0___]] : memref<?x?xf32>
// CHECK: %[[RESULT_DIM_1:.*]] = index_cast %[[EL1_]] : i64 to index
// CHECK: %[[EXPAND_0:.*]] = cmpi "slt", %[[OPERAND_DIM_0]], %[[RESULT_DIM_1]]
// CHECK: %[[STRIDE_0:.*]] = select %[[EXPAND_0]], %[[C0_]], %[[C1_]] : index

// CHECK: %[[C2_:.*]] = constant 2 : index
// CHECK: %[[EL2_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2_]]] : tensor<3xi64>
// CHECK: %[[C1___:.*]] = constant 1 : index
// CHECK: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], %[[C1___]] : memref<?x?xf32>
// CHECK: %[[RESULT_DIM_2:.*]] = index_cast %[[EL2_]] : i64 to index
// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index

// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to
// CHECK-SAME: offset: [0],
// CHECK-SAME: sizes: {{\[}}%[[RESULT_DIM_1]], %[[RESULT_DIM_2]]]
// CHECK-SAME: strides: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
// CHECK-SAME: : memref<?x?xf32> to memref<?x?xf32, #map>

// CHECK: "lmhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
// CHECK-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
// CHECK-SAME: } : (memref<?x?xf32, #[[MAP]]>, memref<?x?x?xf32>) -> ()
// Do not store the value back to avoid the tensor-store being rewritten to
// a copy into the pre-allocated argument.
return
%rank = rank %tensor_result : tensor<?x?x?xf32>
return %rank : index
}
// CHECK: %[[SHAPE:.*]] = tensor_from_elements
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64>
// CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1]]] : tensor<3xi64>
// CHECK: %[[SIZE_1:.*]] = index_cast %[[EL1]] : i64 to index
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2]]] : tensor<3xi64>
// CHECK: %[[SIZE_2:.*]] = index_cast %[[EL2]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>
// CHECK: %[[OPER_DIM_1:.*]] = dim %[[OPERAND]], %[[C1]] : memref<?x?xf32>
// CHECK: %[[OP_STRIDE_0:.*]] = muli %[[C1]], %[[OPER_DIM_1]] : index
// CHECK: %[[OPER_DIM_0:.*]] = dim %[[OPERAND]], %[[C0]] : memref<?x?xf32>
// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPER_DIM_0]], %[[SIZE_1]] : index
// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0]], %[[OP_STRIDE_0]] : index
// CHECK: %[[EXPAND_2:.*]] = cmpi "slt", %[[OPER_DIM_1]], %[[SIZE_2]] : index
// CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]]: memref<?x?xf32> to memref<?x?x?xf32, #map>
// CHECK: "lmhlo.copy"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) : (memref<?x?x?xf32, #map>, memref<?x?x?xf32>) -> ()
// CHECK: dealloc %[[RESULT]] : memref<?x?x?xf32>

// -----
@ -483,11 +462,9 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
// CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
// CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
// CHECK: %[[C0_:.*]] = constant 0 : index
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64>
// CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
// CHECK: %[[C1_:.*]] = constant 1 : index
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64>
// CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// CHECK: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()

@ -508,11 +485,9 @@ func @tanh_dyn(%arg0: tensor<?x?xf32>) {
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
// CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
// CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
// CHECK: %[[C0_:.*]] = constant 0 : index
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64>
// CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
// CHECK: %[[C1_:.*]] = constant 1 : index
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64>
// CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// CHECK: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
@ -613,7 +588,7 @@ func @transpose(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
func @custom_call(%arg0: memref<2x2xf32>, %arg1: memref<2x3xf32>, %result: memref<4x4xf16>) {
%arg0_tensor = tensor_load %arg0 : memref<2x2xf32>
%arg1_tensor = tensor_load %arg1 : memref<2x3xf32>
// CHECK: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false}
// CHECK: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false, operand_segment_sizes = dense<[2, 1]> : vector<2xi32>}
%result_tensor = "mhlo.custom_call"(%arg0_tensor, %arg1_tensor)
{backend_config = "", call_target_name = "foo", has_side_effect = false}
: (tensor<2x2xf32>, tensor<2x3xf32>) -> tensor<4x4xf16>

@ -623,6 +598,22 @@ func @custom_call(%arg0: memref<2x2xf32>, %arg1: memref<2x3xf32>, %result: memre

// ----

// CHECK-LABEL: func @custom_call_multiout
// CHECK-SAME:([[ARG0:%.*]]: memref<2x2xf32>, [[ARG1:%.*]]: memref<2x3xf32>, [[RESULT:%.*]]: memref<4x4xf16>)
func @custom_call_multiout(%arg0: memref<2x2xf32>, %arg1: memref<2x3xf32>, %result: memref<4x4xf16>) {
%arg0_tensor = tensor_load %arg0 : memref<2x2xf32>
%arg1_tensor = tensor_load %arg1 : memref<2x3xf32>
// CHECK: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}, %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false, operand_segment_sizes = dense<2> : vector<2xi32>}
%temp:2 = "mhlo.custom_call"(%arg0_tensor, %arg1_tensor)
{backend_config = "", call_target_name = "foo", has_side_effect = false}
: (tensor<2x2xf32>, tensor<2x3xf32>) -> (tensor<4x4xf16>, tensor<4x4xf16>)
%result_tensor = "mhlo.add"(%temp#0, %temp#1) : (tensor<4x4xf16>, tensor<4x4xf16>) -> tensor<4x4xf16>
tensor_store %result_tensor, %result: memref<4x4xf16>
return
}

// ----

// CHECK-LABEL: func @isfinite
func @isfinite(%arg0: memref<2x2xf32>, %result: memref<2x2xi1>) {
%arg0_tensor = tensor_load %arg0 : memref<2x2xf32>

@ -645,7 +636,7 @@ func @shape_assuming_memref(%arg0: tensor<?xf16>) -> tensor<?xf16> {
%4 = tensor_cast %3 : tensor<?xindex> to tensor<1xindex>
%5 = "mhlo.dynamic_broadcast_in_dim"(%0, %4) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f16>, tensor<1xindex>) -> tensor<?xf16>
%6 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %4) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf16>, tensor<1xindex>) -> tensor<?xf16>
// CHECK: "lmhlo.maximum"(%6, %9, %20) : (memref<?xf16>, memref<?xf16>, memref<?xf16>) -> ()
// CHECK: "lmhlo.maximum"(%{{.*}}, %{{.*}}, %{{.*}}) : (memref<?xf16>, memref<?xf16>, memref<?xf16>) -> ()
%7 = mhlo.maximum %5, %6 : tensor<?xf16>
// CHECK: shape.assuming_yield %{{.*}} : memref<?xf16>
shape.assuming_yield %7 : tensor<?xf16>
@ -299,3 +299,131 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
// PLOOP: absf
// PLOOP: memref_reshape

// -----

// Confirm that tiling information is passed through RegionBranchOpInterfaces.
// This test also uses memref_reshape, just to have a value to return through
// the if statement.
func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
-> memref<*xf32> {
%c1 = constant 1 : index
%c0 = constant 0 : index
%1 = alloc(%arg2) : memref<?xf32>
linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]}
ins(%arg0 : memref<?xf32>) outs(%1 : memref<?xf32>) {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
%13 = absf %arg3 : f32
linalg.yield %13 : f32
}
%true = constant 1 : i1
%3 = scf.if %true -> memref<*xf32> {
%2 = memref_reshape %1(%arg1)
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
scf.yield %2 : memref<*xf32>
} else {
%2 = memref_reshape %1(%arg1)
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
scf.yield %2 : memref<*xf32>
}
return %3 : memref<*xf32>
}

// CHECK-LABEL: func @branching_result
// CHECK: %[[C1:.*]] = constant 1
// CHECK-NOT: linalg.generic
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: absf
// CHECK: scf.if
// CHECK: memref_reshape
// CHECK: scf.yield
// CHECK: else
// CHECK: memref_reshape
// CHECK: scf.yield

// TILED-LABEL: func @branching_result
// TILED-DAG: %[[C2:.*]] = constant 2
// TILED-NOT: linalg.generic
// TILED: scf.for {{.*}} step %[[C2]]
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: absf
// TILED: scf.if
// TILED: memref_reshape
// TILED: scf.yield
// TILED: else
// TILED: memref_reshape
// TILED: scf.yield

// PLOOP-LABEL: func @branching_result
// PLOOP-NOT: linalg.generic
// PLOOP: scf.parallel
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: absf
// PLOOP: scf.if
// PLOOP: memref_reshape
// PLOOP: scf.yield
// PLOOP: else
// PLOOP: memref_reshape
// PLOOP: scf.yield
// -----

// Confirm that tiling information is passed through tensor_load, tensor_cast
// and memref_to_tensor operations.
func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
-> memref<?xf32> {
%c1 = constant 1 : index
%1 = alloc() : memref<32xf32>
linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]}
ins(%arg0 : memref<32xf32>) outs(%1 : memref<32xf32>) {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
%13 = absf %arg3 : f32
linalg.yield %13 : f32
}
%2 = tensor_load %1 : memref<32xf32>
%3 = tensor_cast %2 : tensor<32xf32> to tensor<?xf32>
%4 = tensor_to_memref %3 : memref<?xf32>
return %4 : memref<?xf32>
}

// CHECK-LABEL: func @tensor_ops
// CHECK: %[[C1:.*]] = constant 1
// CHECK-NOT: linalg.generic
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: absf
// CHECK: tensor_load
// CHECK: tensor_cast
// CHECK: tensor_to_memref

// TILED-LABEL: func @tensor_ops
// TILED-DAG: %[[C2:.*]] = constant 2
// TILED-NOT: linalg.generic
// TILED: scf.for {{.*}} step %[[C2]]
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: absf
// TILED: tensor_load
// TILED: tensor_cast
// TILED: tensor_to_memref

// PLOOP-LABEL: func @tensor_ops
// PLOOP-NOT: linalg.generic
// PLOOP: scf.parallel
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: absf
// PLOOP: tensor_load
// PLOOP: tensor_cast
// PLOOP: tensor_to_memref
@ -3,13 +3,13 @@
|
||||
// Tests for types, ops with custom constraints, verifiers, printer or parser
|
||||
// methods.
|
||||
|
||||
// CHECK-LABEL: func @token_type() -> !mhlo.token
|
||||
func @token_type() -> !mhlo.token
|
||||
// CHECK-LABEL: func private @token_type() -> !mhlo.token
|
||||
func private @token_type() -> !mhlo.token
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@+1 {{unknown mhlo type: foobar}}
|
||||
func @invalid_type() -> !mhlo.foobar
|
||||
func private @invalid_type() -> !mhlo.foobar
|
||||
|
||||
// -----
|
||||
|
||||
@ -1281,3 +1281,12 @@ func @set_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<1x128x512xf32> {
|
||||
%result = "mhlo.set_dimension_size"(%I, %dim) {dimension = 3 : i64} : (tensor<1x128x512xf32>, tensor<i32>) -> tensor<1x128x512xf32>
|
||||
return %result : tensor<1x128x512xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func @custom_call_multiple_outputs
|
||||
func @custom_call_multiple_outputs(%x: tensor<2xf32>) -> tensor<2xf32> {
|
||||
%0:2 = "mhlo.custom_call"(%x) {backend_config="", call_target_name = "foo", has_side_effect = false} : (tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
||||
%1 = "mhlo.add"(%0#0, %0#1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
|
||||
return %1 : tensor<2xf32>
|
||||
}
|
||||
|
@ -35,9 +35,9 @@ filegroup(
|
||||
"ir/tfl_ops.td",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
|
||||
"@llvm-project//mlir:OpBaseTdFiles",
|
||||
"@llvm-project//mlir:SideEffectTdFiles",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
|
||||
],
|
||||
)
|
||||
|
||||
@ -390,6 +390,7 @@ cc_library(
|
||||
"transforms/generated_legalize_tf.inc",
|
||||
"transforms/generated_lower_static_tensor_list.inc",
|
||||
"transforms/generated_prepare_tf.inc",
|
||||
"transforms/insert_call_once_op.cc",
|
||||
"transforms/legalize_tf.cc",
|
||||
"transforms/legalize_tf_while.cc",
|
||||
"transforms/lower_static_tensor_list.cc",
|
||||
@ -427,6 +428,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
|
||||
"//tensorflow/compiler/mlir/tensorflow:unroll_batch_matmul_pass",
|
||||
"//tensorflow/compiler/mlir/tensorflow:verification_utils",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
@ -464,6 +466,7 @@ cc_library(
|
||||
":validators",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:verification_utils",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
|
@ -167,7 +167,8 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
|
||||
case 32:
|
||||
return tflite::TensorType_INT32;
|
||||
case 64:
|
||||
return tflite::TensorType_INT64;
|
||||
return itype.isUnsigned() ? tflite::TensorType_UINT64
|
||||
: tflite::TensorType_INT64;
|
||||
}
|
||||
} else if (auto q_uniform_type =
|
||||
type.dyn_cast<mlir::quant::UniformQuantizedType>()) {
|
||||
@ -453,6 +454,11 @@ class Translator {
|
||||
mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
|
||||
const std::vector<int32_t>& results);
|
||||
|
||||
// Build call once operator.
|
||||
BufferOffset<tflite::Operator> BuildCallOnceOperator(
|
||||
mlir::TFL::CallOnceOp op, const std::vector<int32_t>& operands,
|
||||
const std::vector<int32_t>& results);
|
||||
|
||||
// Builds custom operators.
|
||||
// Templated on a) data type of custom_option to be stored into flatbuffer,
|
||||
// and b) TFL custom op type.
|
||||
@ -787,6 +793,22 @@ BufferOffset<tflite::Operator> Translator::BuildIfOperator(
|
||||
builtin_options);
|
||||
}
|
||||
|
||||
BufferOffset<tflite::Operator> Translator::BuildCallOnceOperator(
|
||||
mlir::TFL::CallOnceOp op, const std::vector<int32_t>& operands,
|
||||
const std::vector<int32_t>& results) {
|
||||
auto opcode_index =
|
||||
GetOpcodeIndex("call_once", tflite::BuiltinOperator_CALL_ONCE);
|
||||
int init_subgraph_index =
|
||||
subgraph_index_map_.at(op.session_init_function().str());
|
||||
auto builtin_options =
|
||||
tflite::CreateCallOnceOptions(builder_, init_subgraph_index).Union();
|
||||
auto inputs = builder_.CreateVector(operands);
|
||||
auto outputs = builder_.CreateVector(results);
|
||||
return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
|
||||
tflite::BuiltinOptions_CallOnceOptions,
|
||||
builtin_options);
|
||||
}
|
||||
|
||||
BufferOffset<tflite::Operator> Translator::BuildWhileOperator(
|
||||
mlir::TF::WhileOp op, const std::vector<int32_t>& operands,
|
||||
const std::vector<int32_t>& results) {
|
||||
@ -1026,6 +1048,12 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
if (*builtin_code == tflite::BuiltinOperator_CALL_ONCE) {
|
||||
if (auto initOp = dyn_cast<mlir::TFL::CallOnceOp>(inst)) {
|
||||
return BuildCallOnceOperator(initOp, operands, results);
|
||||
}
|
||||
}
|
||||
|
||||
std::string op_name = inst->getName().getStringRef().str();
|
||||
uint32_t opcode_index = GetOpcodeIndex(op_name, *builtin_code);
|
||||
|
||||
|
@ -448,13 +448,54 @@ StatusOr<Operation*> BuildExternalConstOp(const tflite::TensorT& tensor,
|
||||
return op.getOperation();
|
||||
}
|
||||
|
||||
// Gets a constant splat for the given value of type. Requires value to be of
|
||||
// type static shaped RankedTensorType. `unique_index` is used to get the unique
|
||||
// value for the attribute.
|
||||
static mlir::ElementsAttr GetSplat(RankedTensorType type, int unique_index,
|
||||
OpBuilder builder) {
|
||||
mlir::Type element_ty = getElementTypeOrSelf(type);
|
||||
|
||||
if (element_ty.isSignlessInteger())
|
||||
return DenseElementsAttr::get(
|
||||
type, builder.getIntegerAttr(element_ty, unique_index));
|
||||
|
||||
if (element_ty.isa<mlir::FloatType>())
|
||||
return DenseElementsAttr::get(
|
||||
type, builder.getFloatAttr(element_ty, unique_index));
|
||||
|
||||
if (auto qtype = element_ty.dyn_cast<QuantizedType>()) {
|
||||
mlir::RankedTensorType new_type =
|
||||
RankedTensorType::get(type.getShape(), qtype.getStorageType());
|
||||
return DenseElementsAttr::get(
|
||||
new_type, builder.getIntegerAttr(qtype.getStorageType(), unique_index));
|
||||
}
|
||||
llvm_unreachable("unhandled element type");
|
||||
}
|
||||
|
||||
// TODO(b/172664358): Creates a new op instead of reusing constant op.
|
||||
// Creates a constant op to represent stateful variable. The function static
|
||||
// variable `stateful_variable_idx` is used as a unique value for each constant
|
||||
// to avoid being CSE'd. `tensor` is the data structure of flatbuffer. `shaped_type`
|
||||
// is the ShapedType for the const op.
|
||||
Operation* BuildVariableOp(const tflite::TensorT& tensor,
|
||||
mlir::RankedTensorType shaped_type,
|
||||
OpBuilder builder, Location loc) {
|
||||
static int stateful_variable_idx = 0;
|
||||
mlir::ElementsAttr value =
|
||||
GetSplat(shaped_type, stateful_variable_idx++, builder);
|
||||
if (IsQuantized(tensor)) {
|
||||
auto op = builder.create<tfl::QConstOp>(
|
||||
loc, mlir::TypeAttr::get(shaped_type), value);
|
||||
return op.getOperation();
|
||||
}
|
||||
auto op = builder.create<tfl::ConstOp>(loc, value);
|
||||
return op.getOperation();
|
||||
}
|
||||
|
||||
StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
|
||||
const std::vector<uint8_t>& buffer,
|
||||
OpBuilder builder, Location loc) {
|
||||
if (buffer.empty()) {
|
||||
return errors::InvalidArgument("Constant's buffer may not be empty");
|
||||
}
|
||||
|
||||
bool is_variable, OpBuilder builder,
|
||||
Location loc) {
|
||||
TF_ASSIGN_OR_RETURN(auto type, GetTensorType(tensor, builder,
|
||||
/*shapeless_are_scalars=*/true,
|
||||
/*is_constant=*/true));
|
||||
@ -466,7 +507,9 @@ StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
|
||||
auto elem_type = shaped_type.getElementType();
|
||||
|
||||
mlir::ElementsAttr value;
|
||||
if (auto float_type = elem_type.dyn_cast<mlir::FloatType>()) {
|
||||
if (is_variable) {
|
||||
return BuildVariableOp(tensor, shaped_type, builder, loc);
|
||||
} else if (auto float_type = elem_type.dyn_cast<mlir::FloatType>()) {
|
||||
TF_ASSIGN_OR_RETURN(value,
|
||||
ConvertFloatBuffer(shaped_type, float_type, buffer));
|
||||
} else if (elem_type.isa<mlir::IntegerType, QuantizedType>()) {
|
||||
@ -846,19 +889,8 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
GetTensorIndices(subgraph, ordered_input_arrays));
|
||||
}
|
||||
|
||||
// Add state variables to inputs.
|
||||
absl::flat_hash_set<int32_t> input_index_set(func_inputs.begin(),
|
||||
func_inputs.end());
|
||||
for (int i = 0, end = subgraph.tensors.size(); i < end; i++) {
|
||||
auto& tensor = *subgraph.tensors.at(i);
|
||||
if (tensor.is_variable && !input_index_set.contains(i)) {
|
||||
func_inputs.emplace_back(i);
|
||||
input_index_set.insert(i);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto input_or_variable : func_inputs) {
|
||||
auto& tensor = *subgraph.tensors.at(input_or_variable);
|
||||
for (int input : func_inputs) {
|
||||
auto& tensor = *subgraph.tensors.at(input);
|
||||
// TODO(b/138222071) Graph inputs must have static shape per the exporter,
|
||||
// but we cannot differentiate scalars from unranked tensors.
|
||||
// Here we reverse the default assumption that shape = [] means unranked.
|
||||
@ -889,7 +921,8 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
}
|
||||
|
||||
for (auto output : func_outputs) {
|
||||
const bool is_func_input = input_index_set.contains(output);
|
||||
const bool is_func_input = std::find(func_inputs.begin(), func_inputs.end(),
|
||||
output) != func_inputs.end();
|
||||
bool is_constant = !is_op_output[output] && !is_func_input;
|
||||
// There are 2 cases where a tensor is scalar when it doesn't have a shape in
|
||||
// flatbuffer:
|
||||
@ -955,7 +988,7 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
}
|
||||
func.setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
|
||||
} else {
|
||||
func.setVisibility(FuncOp::Visibility::Private);
|
||||
func.setPrivate();
|
||||
}
|
||||
|
||||
absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;
|
||||
@ -991,7 +1024,7 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
? BuildExternalConstOp(const_tensor, const_tensor.buffer,
|
||||
op_builder, const_loc)
|
||||
: BuildConstOp(const_tensor, buffers[const_tensor.buffer]->data,
|
||||
op_builder, const_loc);
|
||||
const_tensor.is_variable, op_builder, const_loc);
|
||||
if (!op_or_err.ok()) {
|
||||
return emitError(const_loc, op_or_err.status().ToString()),
|
||||
op_or_err.status();
|
||||
@ -1051,7 +1084,7 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
? BuildExternalConstOp(const_tensor, const_tensor.buffer,
|
||||
op_builder, const_loc)
|
||||
: BuildConstOp(const_tensor, buffers[const_tensor.buffer]->data,
|
||||
op_builder, const_loc);
|
||||
const_tensor.is_variable, op_builder, const_loc);
|
||||
if (!op_or_err.ok()) {
|
||||
return emitError(const_loc, op_or_err.status().ToString()),
|
||||
op_or_err.status();
|
||||
|
@ -1972,6 +1972,43 @@ OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
|
||||
return value();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 1);
|
||||
// For now, only supports cast between integer types.
|
||||
auto elements_attr = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
|
||||
if (!elements_attr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto result_element_type =
|
||||
getType().cast<ShapedType>().getElementType().dyn_cast<IntegerType>();
|
||||
auto operand_element_type = input()
|
||||
.getType()
|
||||
.cast<ShapedType>()
|
||||
.getElementType()
|
||||
.dyn_cast<IntegerType>();
|
||||
// Returns nullptr if either result/operand element type is not integer.
|
||||
if (!result_element_type || !operand_element_type) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const bool is_input_unsigned = operand_element_type.isUnsigned();
|
||||
const int output_bitwidth = result_element_type.getWidth();
|
||||
// The integer cast op is the same as C integer cast. Depends on the operand
|
||||
// type's signedness, we will determine whether or not sign extension is
|
||||
// needed.
|
||||
auto cast = [&](APInt value) {
|
||||
return is_input_unsigned ? value.zextOrTrunc(output_bitwidth)
|
||||
: value.sextOrTrunc(output_bitwidth);
|
||||
};
|
||||
|
||||
return elements_attr.mapValues(result_element_type, cast);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SelectV2Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -3405,7 +3405,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", [
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8]>:$input,
|
||||
TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8, TFL_Str]>:$input,
|
||||
TFL_I32Tensor:$begin,
|
||||
TFL_I32Tensor:$end,
|
||||
TFL_I32Tensor:$strides,
|
||||
@ -3418,7 +3418,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", [
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8]>:$output
|
||||
TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8, TFL_Str]>:$output
|
||||
);
|
||||
|
||||
let hasOptions = 1;
|
||||
@ -3443,6 +3443,8 @@ def TFL_CastOp : TFL_Op<"cast", [
|
||||
// TFLite's cast op does not utilize CastOptions, instead derives types
|
||||
// from the TfLiteTensors.
|
||||
let hasOptions = 0;
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [
|
||||
@ -3877,7 +3879,7 @@ def TFL_UnidirectionalSequenceLSTMOp :
|
||||
TFL_OperandHasRank<14, 1>, // cell_gate_bias
|
||||
TFL_OperandHasRank<15, 1>, // output_gate_bias
|
||||
TFL_OperandIsNoneOrHasRank<16, 2>, // projection_weights
|
||||
TFL_OperandIsNoneOrHasRank<17, 2>, // projection_bias
|
||||
TFL_OperandIsNoneOrHasRank<17, 1>, // projection_bias
|
||||
TFL_StatefulOp]> {
|
||||
let summary = "Unidirectional sequence lstm operator";
|
||||
|
||||
@ -4358,6 +4360,21 @@ def TFL_WhileOp : Op<TFL_Dialect, "while", [
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TFL_CallOnceOp : TFL_Op<"call_once", []> {
|
||||
let summary = "Invokes an initialization function";
|
||||
|
||||
let description = [{
|
||||
This operation invokes the given initialization function for the session
|
||||
initializer in tf saved model dialect.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
StrAttr:$session_init_function
|
||||
);
|
||||
|
||||
let results = (outs);
|
||||
}
|
||||
|
||||
def TFL_CustomOp : Op<TFL_Dialect, "custom", [
|
||||
NoSideEffect, NoQuantizableResult]> {
|
||||
let summary = "Custom op";
|
||||
|
@ -55,7 +55,7 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
// Parse input arrays.
|
||||
std::vector<string> node_names;
|
||||
std::vector<string> node_dtypes;
|
||||
std::vector<std::vector<int>> node_shapes;
|
||||
std::vector<llvm::Optional<std::vector<int>>> node_shapes;
|
||||
std::vector<llvm::Optional<double>> node_mins;
|
||||
std::vector<llvm::Optional<double>> node_maxs;
|
||||
|
||||
|
@ -128,7 +128,7 @@ Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
// Parse input arrays.
|
||||
std::vector<string> node_names;
|
||||
std::vector<string> node_dtypes;
|
||||
std::vector<std::vector<int>> node_shapes;
|
||||
std::vector<llvm::Optional<std::vector<int>>> node_shapes;
|
||||
std::vector<llvm::Optional<double>> node_mins;
|
||||
std::vector<llvm::Optional<double>> node_maxs;
|
||||
|
||||
|
@ -119,6 +119,8 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
|
||||
return DT_INT32;
|
||||
case toco::IODataType::INT64:
|
||||
return DT_INT64;
|
||||
case toco::IODataType::UINT64:
|
||||
return DT_UINT64;
|
||||
case toco::IODataType::STRING:
|
||||
return DT_STRING;
|
||||
case toco::IODataType::BOOL:
|
||||
@ -185,7 +187,7 @@ Status PopulateQuantizationSpecs(
|
||||
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
|
||||
mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
|
||||
std::vector<string>* node_dtypes,
|
||||
std::vector<std::vector<int>>* node_shapes,
|
||||
std::vector<llvm::Optional<std::vector<int>>>* node_shapes,
|
||||
std::vector<llvm::Optional<double>>* node_mins,
|
||||
std::vector<llvm::Optional<double>>* node_maxs) {
|
||||
quant_specs->inference_input_type =
|
||||
@ -210,8 +212,12 @@ Status PopulateQuantizationSpecs(
|
||||
node_dtypes->push_back(
|
||||
DataType_Name(ConvertIODataTypeToDataType(toco_data_type)));
|
||||
}
|
||||
node_shapes->push_back(std::vector<int>(flag.shape().dims().begin(),
|
||||
flag.shape().dims().end()));
|
||||
if (flag.shape().unknown_rank()) {
|
||||
node_shapes->push_back(llvm::None);
|
||||
} else {
|
||||
node_shapes->push_back(std::vector<int>(flag.shape().dims().begin(),
|
||||
flag.shape().dims().end()));
|
||||
}
|
||||
// Currently, only UINT8 and INT8 require inputs stats
|
||||
if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) {
|
||||
if (flag.has_mean_value() && flag.has_std_value()) {
|
||||
|
@ -41,7 +41,7 @@ Status PopulateQuantizationSpecs(
|
||||
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
|
||||
mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
|
||||
std::vector<string>* node_dtypes,
|
||||
std::vector<std::vector<int>>* node_shapes,
|
||||
std::vector<llvm::Optional<std::vector<int>>>* node_shapes,
|
||||
std::vector<llvm::Optional<double>>* node_mins,
|
||||
std::vector<llvm::Optional<double>>* node_maxs);
|
||||
|
||||
|
@ -52,6 +52,12 @@ struct QuantizationSpecs {
|
||||
// weight FakeQuant).
|
||||
bool disable_per_channel = false;
|
||||
|
||||
// When set to true, the fixed output ranges of the activation ops (tanh,
|
||||
// sigmoid, etc.) are not enforced. Then, to quantize these ops, quantization
|
||||
// emulation ops should be specified after the ops in the input graph. This
|
||||
// flag should be set to false for post-training quantization.
|
||||
bool disable_enforced_fixed_output_range = false;
|
||||
|
||||
// The node type when the model is exported. Currently this is limited to
|
||||
// DT_FLOAT, DT_HALF, DT_QINT8, and DT_QUINT8. When DT_HALF is used, the
|
||||
// `weight_quantization` flag needs to be set to true. When DT_QUINT8 is used,
|
||||
|
@ -587,3 +587,55 @@ func @rsqrt_bf16() -> tensor<bf16> {
|
||||
// CHECK: %[[CST:.*]] = constant dense<5.000000e-01> : tensor<bf16>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @cast_i64_to_i32
|
||||
func @cast_i64_to_i32() -> tensor<5xi32> {
|
||||
%cst = constant dense<[-1, 0, 1, 2147483647, 2147483648]> : tensor<5xi64>
|
||||
%0 = "tfl.cast"(%cst) : (tensor<5xi64>) -> tensor<5xi32>
|
||||
return %0 : tensor<5xi32>
|
||||
|
||||
// CHECK: %[[CST:.*]] = constant dense<[-1, 0, 1, 2147483647, -2147483648]> : tensor<5xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @cast_i32_to_ui8
|
||||
func @cast_i32_to_ui8() -> tensor<6xui8> {
|
||||
%cst = constant dense<[0, -1, 256, 127, -128, -129]> : tensor<6xi32>
|
||||
%0 = "tfl.cast"(%cst) : (tensor<6xi32>) -> tensor<6xui8>
|
||||
return %0 : tensor<6xui8>
|
||||
|
||||
// CHECK: %[[CST:.*]] = constant dense<[0, 255, 0, 127, 128, 127]> : tensor<6xui8>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @cast_ui8_to_i8
|
||||
func @cast_ui8_to_i8() -> tensor<4xi8> {
|
||||
%cst = constant dense<[0, 255, 127, 128]> : tensor<4xui8>
|
||||
%0 = "tfl.cast"(%cst) : (tensor<4xui8>) -> tensor<4xi8>
|
||||
return %0 : tensor<4xi8>
|
||||
|
||||
// CHECK: %[[CST:.*]] = constant dense<[0, -1, 127, -128]> : tensor<4xi8>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @cast_i8_to_i32
|
||||
func @cast_i8_to_i32() -> tensor<4xi32> {
|
||||
%cst = constant dense<[0, 128, -1, -128]> : tensor<4xi8>
|
||||
%0 = "tfl.cast"(%cst) : (tensor<4xi8>) -> tensor<4xi32>
|
||||
return %0 : tensor<4xi32>
|
||||
|
||||
// CHECK: %[[CST:.*]] = constant dense<[0, -128, -1, -128]> : tensor<4xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @cast_ui8_to_i32
|
||||
func @cast_ui8_to_i32() -> tensor<4xi32> {
|
||||
%cst = constant dense<[0, 128, 129, 255]> : tensor<4xui8>
|
||||
%0 = "tfl.cast"(%cst) : (tensor<4xui8>) -> tensor<4xi32>
|
||||
return %0 : tensor<4xi32>
|
||||
|
||||
// CHECK: %[[CST:.*]] = constant dense<[0, 128, 129, 255]> : tensor<4xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
|
||||
|
@ -411,11 +411,11 @@ versions {
|
||||
# CHECK-NEXT: constant dense<[5.000000e+00, 6.000000e+00, 7.000000e+00, 8.000000e+00]>
|
||||
# CHECK: "tf.If"{{.+}}else_branch = @cond_false_10{{.+}}is_stateless = true{{.+}}then_branch = @cond_true_10
|
||||
# CHECK: "tf.If"{{.+}}else_branch = @cond_false0{{.+}}is_stateless = false{{.+}}then_branch = @cond_true0
|
||||
# CHECK: func @cond_false_10
|
||||
# CHECK: func private @cond_false_10
|
||||
# CHECK-NEXT: tfl.div
|
||||
# CHECK: func @cond_true_10
|
||||
# CHECK: func private @cond_true_10
|
||||
# CHECK-NEXT: tfl.sub
|
||||
# CHECK: func @cond_false0
|
||||
# CHECK: func private @cond_false0
|
||||
# CHECK-NEXT: tfl.mul
|
||||
# CHECK: func @cond_true0
|
||||
# CHECK: func private @cond_true0
|
||||
# CHECK-NEXT: tfl.add
|
||||
|
@ -78,14 +78,14 @@ versions {
|
||||
}
|
||||
|
||||
# CHECK: func @main(%[[VAL_0:.*]]: tensor<2x5x3xf32>, %[[VAL_1:.*]]: tensor<3x7xf32>) -> tensor<2x5x7xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "Placeholder,Placeholder_1", outputs = "MatMul"}} {
|
||||
# CHECK: %[[VAL_2:.*]] = constant dense<[1, 0]> : tensor<2xi32>
|
||||
# CHECK: %[[VAL_3:.*]] = constant dense<[5, 3]> : tensor<2xi32>
|
||||
# CHECK: %[[VAL_4:.*]] = constant dense<[3, 7]> : tensor<2xi32>
|
||||
# CHECK: %[[VAL_5:.*]] = constant unit
|
||||
# CHECK: %[[VAL_6:.*]] = constant dense<[1, 0, 0]> : tensor<3xi32>
|
||||
# CHECK: %[[VAL_7:.*]] = constant dense<[1, 5, 3]> : tensor<3xi32>
|
||||
# CHECK: %[[VAL_8:.*]] = constant dense<0> : tensor<3xi32>
|
||||
# CHECK: %[[VAL_9:.*]] = constant dense<[1, 3, 7]> : tensor<3xi32>
|
||||
# CHECK-DAG: %[[VAL_2:.*]] = constant dense<[1, 0]> : tensor<2xi32>
|
||||
# CHECK-DAG: %[[VAL_3:.*]] = constant dense<[5, 3]> : tensor<2xi32>
|
||||
# CHECK-DAG: %[[VAL_4:.*]] = constant dense<[3, 7]> : tensor<2xi32>
|
||||
# CHECK-DAG: %[[VAL_5:.*]] = constant unit
|
||||
# CHECK-DAG: %[[VAL_6:.*]] = constant dense<[1, 0, 0]> : tensor<3xi32>
|
||||
# CHECK-DAG: %[[VAL_7:.*]] = constant dense<[1, 5, 3]> : tensor<3xi32>
|
||||
# CHECK-DAG: %[[VAL_8:.*]] = constant dense<0> : tensor<3xi32>
|
||||
# CHECK-DAG: %[[VAL_9:.*]] = constant dense<[1, 3, 7]> : tensor<3xi32>
|
||||
# CHECK: %[[VAL_10:.*]] = "tfl.slice"(%[[VAL_0]], %[[VAL_8]], %[[VAL_7]]) : (tensor<2x5x3xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x5x3xf32>
|
||||
# CHECK: %[[VAL_11:.*]] = "tfl.reshape"(%[[VAL_10]], %[[VAL_3]]) : (tensor<1x5x3xf32>, tensor<2xi32>) -> tensor<5x3xf32>
|
||||
# CHECK: %[[VAL_12:.*]] = "tfl.slice"(%[[VAL_0]], %[[VAL_6]], %[[VAL_7]]) : (tensor<2x5x3xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x5x3xf32>
|
||||
|
@ -8,9 +8,11 @@ func @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32
|
||||
return %24 : tensor<1x4xf32>
|
||||
// CHECK-LABEL: main
|
||||
// separate lines since there is no region for this op. third_party/tensorflow/compiler/mlir/lite/ir/tfl_ops.td: 3252
|
||||
// CHECK: %[[RES0:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg22, %arg23, %arg18, %arg19, %arg20, %arg21) ( {
|
||||
// CHECK: %[[RES0:.*]] = "tfl.pseudo_const"() {value = dense<{{.*}}> : tensor<1x4xf32>} : () -> tensor<1x4xf32>
|
||||
// CHECK: %[[RES1:.*]] = "tfl.pseudo_const"() {value = dense<{{.*}}> : tensor<1x4xf32>} : () -> tensor<1x4xf32>
|
||||
// CHECK: %[[RES2:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %[[RES0]], %[[RES1]], %arg18, %arg19, %arg20, %arg21) ( {
|
||||
// CHECK: }) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
|
||||
// CHECK: return %[[RES0]]
|
||||
// CHECK: return %[[RES2]]
|
||||
|
||||
}
|
||||
|
||||
@ -29,9 +31,9 @@ func @testFullyQuantizedLSTM(%arg0: tensor<1x528x!quant.uniform<i8:f32, 0.037248
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testUnidirectionalSequenceLstmWithIntermediates
|
||||
func @testUnidirectionalSequenceLstmWithIntermediates(%arg0: tensor<? x ? x f32>, %arg1: tensor<? x ? x f32>, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x ? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
|
||||
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
func @testUnidirectionalSequenceLstmWithIntermediates(%arg0: tensor<? x ? x f32>, %arg1: tensor<? x ? x f32>, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
|
||||
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
|
@ -5,8 +5,8 @@
|
||||
// CHECK: func @main(%arg0: tensor<1xf32>) -> tensor<*xf32>
|
||||
// CHECK: %0 = "tf.While"(%arg0) {body = @body, cond = @cond, is_stateless = false} : (tensor<1xf32>) -> tensor<*xf32>
|
||||
// CHECK: return %0 : tensor<*xf32>
|
||||
// CHECK: func @cond(%arg0: tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK: func @body(%arg0: tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK: func private @cond(%arg0: tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK: func private @body(%arg0: tensor<*xf32>) -> tensor<*xf32>
|
||||
|
||||
func @main(%arg0: tensor<1xf32>) -> tensor<*xf32> {
|
||||
%0 = "tf.While"(%arg0) {cond = @cond, body = @body, is_stateless = false} : (tensor<1xf32>) -> tensor<*xf32>
|
||||
|
@ -1,6 +1,6 @@
|
||||
// RUN: tf-opt -tfl-prepare-composite-funcs-tf -tfl-fuse-tftext=true %s | FileCheck %s
|
||||
|
||||
func @whitespace_tokenizer_rank1(%arg0: tensor<1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<1>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
|
||||
func private @whitespace_tokenizer_rank1(%arg0: tensor<1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {tf._input_shapes = [#tf.shape<1>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
|
||||
%0 = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
%1 = "tf.Const"() {value = dense<[]> : tensor<0xi64>} : () -> tensor<0xi64>
|
||||
%2 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
|
||||
@ -1026,11 +1026,11 @@ func @WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_true_23810(%arg0: t
|
||||
return %1 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK: func @whitespace_tokenizer_rank1(%arg0: tensor<1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<1>], tf.signature.is_stateful} {
|
||||
// CHECK: func private @whitespace_tokenizer_rank1(%arg0: tensor<1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<1>], tf.signature.is_stateful} {
|
||||
// CHECK: %0:2 = "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor<1x!tf.string>) -> (tensor<?x!tf.string>, tensor<?xi64>)
|
||||
// CHECK: return %0#0, %0#1 : tensor<?x!tf.string>, tensor<?xi64>
|
||||
|
||||
func @whitespace_tokenizer_rank2(%arg0: tensor<?x1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>) attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<?x1>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
|
||||
func private @whitespace_tokenizer_rank2(%arg0: tensor<?x1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>) attributes {tf._input_shapes = [#tf.shape<?x1>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
|
||||
%0 = "tf.Const"() {value = dense<[]> : tensor<0xi64>} : () -> tensor<0xi64>
|
||||
%1 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
|
||||
%2 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
|
||||
@ -2160,11 +2160,11 @@ func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_As
|
||||
|
||||
|
||||
|
||||
// CHECK: func @whitespace_tokenizer_rank2(%arg0: tensor<?x1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>) attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<?x1>], tf.signature.is_stateful} {
|
||||
// CHECK: func private @whitespace_tokenizer_rank2(%arg0: tensor<?x1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>) attributes {tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<?x1>], tf.signature.is_stateful} {
|
||||
// CHECK: %0:3 = "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor<?x1x!tf.string>) -> (tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>)
|
||||
// CHECK: return %0#0, %0#1, %0#2 : tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>
|
||||
|
||||
func @whitespace_tokenizer_rank0(%arg0: tensor<!tf.string> {tf._user_specified_name = "input"}) -> tensor<?x!tf.string> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
|
||||
func private @whitespace_tokenizer_rank0(%arg0: tensor<!tf.string> {tf._user_specified_name = "input"}) -> tensor<?x!tf.string> attributes {tf._input_shapes = [#tf.shape<>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
|
||||
%0 = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
%1 = "tf.Const"() {value = dense<[]> : tensor<0xi64>} : () -> tensor<0xi64>
|
||||
%2 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
|
||||
@ -3190,7 +3190,7 @@ func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_
|
||||
return %1 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK: func @whitespace_tokenizer_rank0(%arg0: tensor<!tf.string> {tf._user_specified_name = "input"}) -> tensor<?x!tf.string> attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<>], tf.signature.is_stateful} {
|
||||
// CHECK: func private @whitespace_tokenizer_rank0(%arg0: tensor<!tf.string> {tf._user_specified_name = "input"}) -> tensor<?x!tf.string> attributes {tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<>], tf.signature.is_stateful} {
|
||||
// CHECK: %0 = "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor<!tf.string>) -> tensor<?x!tf.string>
|
||||
// CHECK: return %0 : tensor<?x!tf.string>
|
||||
|
||||
@ -3213,7 +3213,7 @@ func @ngrams(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "input"}) ->
|
||||
// CHECK: return %0 : tensor<?x!tf.string>
|
||||
// CHECK: }
|
||||
|
||||
func @ngrams_ragged_rank_2(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<3xi64> {tf._user_specified_name = "args_0"}, %arg2: tensor<?xi64> {tf._user_specified_name = "args_1"}) -> (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>) attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = "", width = 2 : i64}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<3>, #tf.shape<?>], tf.signature.is_stateful} {
|
||||
func private @ngrams_ragged_rank_2(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<3xi64> {tf._user_specified_name = "args_0"}, %arg2: tensor<?xi64> {tf._user_specified_name = "args_1"}) -> (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>) attributes {tf._implements = #tf.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = "", width = 2 : i64}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<3>, #tf.shape<?>], tf.signature.is_stateful} {
|
||||
%0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
|
||||
%1 = "tf.Const"() {value = dense<-1> : tensor<i64>} : () -> tensor<i64>
|
||||
%2 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
@ -3330,12 +3330,12 @@ func @ngrams_ragged_rank_2(%arg0: tensor<?x!tf.string> {tf._user_specified_name
|
||||
%71 = "tf.Identity"(%70) {device = ""} : (tensor<3xi64>) -> tensor<3xi64>
|
||||
return %68, %71, %64 : tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>
|
||||
}
|
||||
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_27770(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
|
||||
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_27770(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
|
||||
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||
return %1 : tensor<i1>
|
||||
}
|
||||
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_27780(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
|
||||
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_27780(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
|
||||
%0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||
%1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||
@ -3345,12 +3345,12 @@ func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_as
|
||||
%5 = "tf.Identity"(%4) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||
return %5 : tensor<i1>
|
||||
}
|
||||
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_28130(%arg0: tensor<i1>, %arg1: tensor<?xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<?>]} {
|
||||
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_28130(%arg0: tensor<i1>, %arg1: tensor<?xi64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<?>]} {
|
||||
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||
return %1 : tensor<i1>
|
||||
}
|
||||
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_28140(%arg0: tensor<i1>, %arg1: tensor<?xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<?>], tf.signature.is_stateful} {
|
||||
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_28140(%arg0: tensor<i1>, %arg1: tensor<?xi64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<?>], tf.signature.is_stateful} {
|
||||
%0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/sub:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
@ -3359,12 +3359,12 @@ func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_as
%4 = "tf.Identity"(%3) {device = ""} : (tensor<i1>) -> tensor<i1>
return %4 : tensor<i1>
}
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_28500(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_28500(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
return %1 : tensor<i1>
}
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_28510(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_28510(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits/strided_slice_1:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
@ -3374,12 +3374,12 @@ func @RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_Assert
%5 = "tf.Identity"(%4) {device = ""} : (tensor<i1>) -> tensor<i1>
return %5 : tensor<i1>
}
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_28900(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_28900(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
return %1 : tensor<i1>
}
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_28910(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_28910(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
@ -3389,12 +3389,12 @@ func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_
%5 = "tf.Identity"(%4) {device = ""} : (tensor<i1>) -> tensor<i1>
return %5 : tensor<i1>
}
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_29260(%arg0: tensor<i1>, %arg1: tensor<2xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<2>]} {
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_29260(%arg0: tensor<i1>, %arg1: tensor<2xi64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<2>]} {
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
return %1 : tensor<i1>
}
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_29270(%arg0: tensor<i1>, %arg1: tensor<2xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<2>], tf.signature.is_stateful} {
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_29270(%arg0: tensor<i1>, %arg1: tensor<2xi64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<2>], tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/sub:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
@ -3403,12 +3403,12 @@ func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_
%4 = "tf.Identity"(%3) {device = ""} : (tensor<i1>) -> tensor<i1>
return %4 : tensor<i1>
}
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_true_29650(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_true_29650(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
return %1 : tensor<i1>
}
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_false_29660(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_false_29660(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits_1/strided_slice:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
@ -3418,12 +3418,12 @@ func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_Asse
%5 = "tf.Identity"(%4) {device = ""} : (tensor<i1>) -> tensor<i1>
return %5 : tensor<i1>
}
func @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_true_30330(%arg0: tensor<i1>, %arg1: tensor<?xi64>, %arg2: tensor<?xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<?>, #tf.shape<?>]} {
func private @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_true_30330(%arg0: tensor<i1>, %arg1: tensor<?xi64>, %arg2: tensor<?xi64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<?>, #tf.shape<?>]} {
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
return %1 : tensor<i1>
}
func @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_false_30340(%arg0: tensor<i1>, %arg1: tensor<?xi64>, %arg2: tensor<?xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<?>, #tf.shape<?>], tf.signature.is_stateful} {
func private @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_false_30340(%arg0: tensor<i1>, %arg1: tensor<?xi64>, %arg2: tensor<?xi64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<?>, #tf.shape<?>], tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<"Inputs must have identical ragged splits"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%2 = "tf.Const"() {value = dense<"x (NGrams/SlidingWindow/RaggedGetItem/RaggedRange:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
@ -3433,12 +3433,12 @@ func @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_false_
%5 = "tf.Identity"(%4) {device = ""} : (tensor<i1>) -> tensor<i1>
return %5 : tensor<i1>
}
// CHECK: func @ngrams_ragged_rank_2(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<3xi64> {tf._user_specified_name = "args_0"}, %arg2: tensor<?xi64> {tf._user_specified_name = "args_1"}) -> (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>) attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = "", width = 2 : i64}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<3>, #tf.shape<?>], tf.signature.is_stateful} {
// CHECK: func private @ngrams_ragged_rank_2(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<3xi64> {tf._user_specified_name = "args_0"}, %arg2: tensor<?xi64> {tf._user_specified_name = "args_1"}) -> (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>) attributes {tf._implements = #tf.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = "", width = 2 : i64}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<3>, #tf.shape<?>], tf.signature.is_stateful} {
// CHECK: %0:3 = "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = "tftext:Ngrams", custom_option = opaque<"tfl", "0x776964746800737472696E675F736570617261746F720000006178697300726564756374696F6E5F74797065000B535452494E475F4A4F494E0004221E373E040104FF152C0204141404082401"> : tensor<77xi8>} : (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>) -> (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>)
// CHECK: return %0#0, %0#1, %0#2 : tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>

func @sgnn_projection(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<?xi64> {tf._user_specified_name = "row_splits"}) -> tensor<?x10xf64> attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:custom:SgnnProjection", {buckets = 2147483647 : i64, hash_seed = [1902835825, -1475704015, 473120514, 1254202069, 1558833093, 1756181982, 1906603252, -1034142694, 542842690, 535515822]}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<?>], tf.signature.is_stateful} {
func private @sgnn_projection(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<?xi64> {tf._user_specified_name = "row_splits"}) -> tensor<?x10xf64> attributes {tf._implements = #tf.func<@"tftext:custom:SgnnProjection", {buckets = 2147483647 : i64, hash_seed = [1902835825, -1475704015, 473120514, 1254202069, 1558833093, 1756181982, 1906603252, -1034142694, 542842690, 535515822]}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<?>], tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<[[1902835825], [-1475704015], [473120514], [1254202069], [1558833093], [1756181982], [1906603252], [-1034142694], [542842690], [535515822]]> : tensor<10x1xi64>} : () -> tensor<10x1xi64>
%1 = "tf.StringToHashBucketFast"(%arg0) {device = "", num_buckets = 2147483647 : i64} : (tensor<?x!tf.string>) -> tensor<?xi64>
%2 = "tf.Sgnn"(%1, %0) {device = ""} : (tensor<?xi64>, tensor<10x1xi64>) -> tensor<10x?xf64>
@ -3448,6 +3448,6 @@ func @sgnn_projection(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "va
}

// CHECK: func @sgnn_projection(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<?xi64> {tf._user_specified_name = "row_splits"}) -> tensor<?x10xf64> attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:custom:SgnnProjection", {buckets = 2147483647 : i64, hash_seed = [1902835825, -1475704015, 473120514, 1254202069, 1558833093, 1756181982, 1906603252, -1034142694, 542842690, 535515822]}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<?>], tf.signature.is_stateful} {
// CHECK: func private @sgnn_projection(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<?xi64> {tf._user_specified_name = "row_splits"}) -> tensor<?x10xf64> attributes {tf._implements = #tf.func<@"tftext:custom:SgnnProjection", {buckets = 2147483647 : i64, hash_seed = [1902835825, -1475704015, 473120514, 1254202069, 1558833093, 1756181982, 1906603252, -1034142694, 542842690, 535515822]}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<?>], tf.signature.is_stateful} {
// CHECK: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "tftext:custom:SgnnProjection", custom_option = opaque<"tfl", "0x686173685F736565640000000A00000071F86A71318B0AA8023F331CD59AC14AC5E7E95CDE35AD68F474A4711A3C5CC2421F5B20AE52EB1F6275636B6574730002094200030000000100000002000000FFFFFF7F44000000062E0A2601"> : tensor<93xi8>} : (tensor<?x!tf.string>, tensor<?xi64>) -> tensor<?x10xf64>
// CHECK: return %0 : tensor<?x10xf64>
40
tensorflow/compiler/mlir/lite/tests/insert_call_once_op.mlir
Normal file
40
tensorflow/compiler/mlir/lite/tests/insert_call_once_op.mlir
Normal file
@ -0,0 +1,40 @@
// RUN: tf-opt -split-input-file -tfl-insert-call-once-op %s | FileCheck %s

// Tests that new call_once op is added when there is a session initializer.

module attributes {tf_saved_model.semantics} {
"tf_saved_model.session_initializer"() {initializers = [@init_all_tables]} : () -> ()

func @init_all_tables()
attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer"]} {
%cst = constant dense<[1, 2, 3, 4]> : tensor<4xi64>
%cst_0 = constant dense<["a", "b", "c", "d"]> : tensor<4x!tf.string>
%0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = i64, shared_name = "hash_table_dba2ccaa-f1b1-46d6-b276-98008f69da71", use_node_name_sharing = false, value_dtype = !tf.string} : () -> tensor<!tf.resource>
"tf.LookupTableImportV2"(%0, %cst, %cst_0) {device = ""} : (tensor<!tf.resource>, tensor<4xi64>, tensor<4x!tf.string>) -> ()
return
// CHECK-LABEL: @init_all_tables
}

func @serving_default(%arg0: tensor<i64> {tf_saved_model.index_path = ["x"]}) -> (tensor<*x!tf.string> {tf_saved_model.index_path = ["r"]})
attributes {tf.entry_function = {control_outputs = "", inputs = "input:0", outputs = "hash_table_Lookup/LookupTableFindV2:0"}, tf_saved_model.exported_names = ["serving_default"]} {
%cst = constant dense<"f"> : tensor<!tf.string>
%0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = i64, shared_name = "hash_table_dba2ccaa-f1b1-46d6-b276-98008f69da71", use_node_name_sharing = false, value_dtype = !tf.string} : () -> tensor<!tf.resource>
%1 = "tf.LookupTableFindV2"(%0, %arg0, %cst) {device = ""} : (tensor<!tf.resource>, tensor<i64>, tensor<!tf.string>) -> tensor<*x!tf.string>
return %1 : tensor<*x!tf.string>
// CHECK-LABEL: @serving_default
// CHECK: "tfl.call_once"() {session_init_function = "init_all_tables"} : () -> ()
}
}

// -----

// Tests that no call_once op is added.

module attributes {tf_saved_model.semantics} {
func @no_call_once(%arg0: tensor<i64> {tf_saved_model.index_path = ["x"]}) -> (tensor<i64> {tf_saved_model.index_path = ["r"]})
attributes {tf.entry_function = {control_outputs = "", inputs = "input:0", outputs = "output:0"}, tf_saved_model.exported_names = ["serving_default"]} {
return %arg0 : tensor<i64>
// CHECK-LABEL: no_call_once
// CHECK-NOT: "tfl.call_once"
}
}
@ -435,6 +435,16 @@ func @scatterNdHigherRankIndices(%arg0: tensor<4x2x2xi32>, %arg1: tensor<4x2x3xf
|
||||
// CHECK: return %[[RES]]
|
||||
}
func @scatter_nd_i64(%arg0: tensor<4x2x2xi64>, %arg1: tensor<4x2x3xf32>, %arg2: tensor<3xi64>) -> tensor<10x2x3xf32> {
|
||||
%0 = "tf.ScatterNd"(%arg0, %arg1, %arg2) : (tensor<4x2x2xi64>, tensor<4x2x3xf32>, tensor<3xi64>) -> tensor<10x2x3xf32>
|
||||
return %0 : tensor<10x2x3xf32>
// CHECK-LABEL:scatter_nd_i64
|
||||
// CHECK: "tfl.cast"
|
||||
// CHECK: "tfl.cast"
|
||||
// CHECK: "tfl.scatter_nd"
|
||||
}
func @gatherV2VectorIndices(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x3x5x20xf32> {
|
||||
%0 = "tf.Const"() { value = dense<[1]> : tensor<1xi32> } : () -> tensor<1xi32>
|
||||
%1 = "tf.GatherV2"(%arg0, %arg1, %0) : (tensor<1x2x20xf32>, tensor<3x5xi32>, tensor<1xi32>) -> tensor<1x3x5x20xf32>
|
||||
@ -689,6 +699,16 @@ func @reverse_v2(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1xi32>) -> tensor<1x2
|
||||
// CHECK: return
|
||||
}
func @reverse_v2_i64(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1xi64>) -> tensor<1x2x3x4xf32> {
|
||||
%0 = "tf.ReverseV2"(%arg0, %arg1) : (tensor<1x2x3x4xf32>, tensor<1xi64>) -> tensor<1x2x3x4xf32>
|
||||
return %0 : tensor<1x2x3x4xf32>
// CHECK-LABEL:reverse_v2_i64
|
||||
// CHECK: "tfl.cast"
|
||||
// CHECK: "tfl.reverse_v2"
|
||||
// CHECK: return
|
||||
}
func @matrix_diag(%arg0: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
|
||||
%0 = "tf.MatrixDiag"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16x16xf32>
|
||||
return %0 : tensor<8x16x16xf32>
|
||||
@ -763,13 +783,31 @@ func @matrix_diag_v3(%arg0: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
|
||||
// CHECK: return [[VAL_6]] : tensor<8x16x16xf32>
|
||||
}
func @matrix_set_diag(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
|
||||
%0 = "tf.MatrixSetDiag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
|
||||
return %0 : tensor<3x3xi32>
|
||||
func @matrix_set_diag_v3(%arg0: tensor<3x3xi64>, %arg1: tensor<3xi32>) -> tensor<3x3xi64> {
|
||||
%cst = constant dense<0> : tensor<i32>
|
||||
%0 = "tf.MatrixSetDiagV3"(%arg0, %arg1, %cst) {align = "RIGHT_LEFT"} : (tensor<3x3xi64>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi64>
|
||||
return %0 : tensor<3x3xi64>
// CHECK-LABEL: func @matrix_set_diag(
|
||||
// CHECK: [[VAL_0:%.*]] = "tfl.matrix_set_diag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
|
||||
// CHECK: return [[VAL_0]]
|
||||
// CHECK-LABEL: func @matrix_set_diag_v3
|
||||
// CHECK: "tfl.matrix_set_diag"(%arg0, %arg1) : (tensor<3x3xi64>, tensor<3xi32>) -> tensor<3x3xi64>
|
||||
}
func @matrix_set_diag_v3_non_zero_k(%arg0: tensor<3x3xi64>, %arg1: tensor<3xi32>) -> tensor<3x3xi64> {
|
||||
%cst = constant dense<1> : tensor<i32>
|
||||
%0 = "tf.MatrixSetDiagV3"(%arg0, %arg1, %cst) : (tensor<3x3xi64>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi64>
|
||||
return %0 : tensor<3x3xi64>
// CHECK-LABEL: @matrix_set_diag_v3_non_zero_k
|
||||
// CHECK: tf.MatrixSetDiagV3
|
||||
}
func @matrix_set_diag_v3_default_align(%arg0: tensor<3x3xi64>, %arg1: tensor<3xi32>) -> tensor<3x3xi64> {
|
||||
%cst = constant dense<0> : tensor<i32>
|
||||
%0 = "tf.MatrixSetDiagV3"(%arg0, %arg1, %cst) : (tensor<3x3xi64>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi64>
|
||||
return %0 : tensor<3x3xi64>
// CHECK-LABEL: @matrix_set_diag_v3_default_align
|
||||
// CHECK: "tfl.matrix_set_diag"(%arg0, %arg1) : (tensor<3x3xi64>, tensor<3xi32>) -> tensor<3x3xi64>
|
||||
}
func @maximum(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> {
|
||||
@ -996,6 +1034,15 @@ func @batch_to_space_nd_unsupported(%arg0: tensor<?x1x1x1x4xf32>, %arg1: tensor<
|
||||
// CHECK: "tf.BatchToSpaceND"
|
||||
}
func @batch_to_space_nd_i64(%arg0: tensor<4x2x2x3xf32>, %arg1: tensor<2xi64>, %arg2: tensor<2x2xi64>) -> tensor<?xf32> {
|
||||
%0 = "tf.BatchToSpaceND"(%arg0, %arg1, %arg2) : (tensor<4x2x2x3xf32>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
// CHECK-LABEL: batch_to_space_nd_i64
|
||||
// CHECK: "tfl.cast"
|
||||
// CHECK: "tfl.cast"
|
||||
// CHECK: "tfl.batch_to_space_nd"
|
||||
}
func @space_to_batch_nd(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<2x2xi32>) -> tensor<*xf32> {
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
@ -1003,6 +1050,15 @@ func @space_to_batch_nd(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<2xi32>, %arg2:
|
||||
// CHECK: "tfl.space_to_batch_nd"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<*xf32>
|
||||
}
func @space_to_batch_nd_i64(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<2xi64>, %arg2: tensor<2x2xi64>) -> tensor<*xf32> {
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
// CHECK-LABEL: space_to_batch_nd_i64
|
||||
// CHECK: "tfl.cast"
|
||||
// CHECK: "tfl.cast"
|
||||
// CHECK: "tfl.space_to_batch_nd"
|
||||
}
func @split(%arg0: tensor<i32>, %arg1: tensor<1x4x3x3xf32>) -> tensor<1x4x3xf32> {
|
||||
%0:3 = "tf.Split"(%arg0, %arg1) : (tensor<i32>, tensor<1x4x3x3xf32>) -> (tensor<1x4x3xf32>, tensor<1x4x3xf32>, tensor<1x4x3xf32>)
|
||||
return %0#0 : tensor<1x4x3xf32>
|
||||
@ -1122,6 +1178,13 @@ func @strided_slice_with_constant_attributes(%arg0: tensor<10x10x10xf32>, %arg1:
|
||||
// CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 6 : i32, ellipsis_mask = 0 : i32, end_mask = 6 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32} : (tensor<10x10x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<10x10xf32>
|
||||
}
func @strided_slice_with_string(%arg0: tensor<12x2x2x5x!tf.string>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf.string> {
|
||||
%0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<12x2x2x5x!tf.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.string>
|
||||
return %0 : tensor<1x2x2x5x!tf.string>
|
||||
// CHECK-LABEL: strided_slice_with_string
|
||||
// CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!tf.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.string>
|
||||
}
func @slice1Tensor(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>) -> tensor<?x3x5xf32> {
|
||||
%0 = "tf.Slice"(%arg0, %arg1, %arg2) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x3x5xf32>
|
||||
return %0 : tensor<?x3x5xf32>
|
||||
@ -1354,8 +1417,7 @@ func @conv2d_backprop_input(%arg0: tensor<4xi32>, %arg1: tensor<3x3x1x32xf32>, %
// CHECK-LABEL: conv2d_backprop_input
|
||||
// CHECK: %[[CST:.*]] = constant dense<[2, 0, 1, 3]> : tensor<4xi32>
|
||||
// CHECK: %[[CAST:.*]] = "tfl.cast"(%[[CST]]) : (tensor<4xi32>) -> tensor<4xi32>
|
||||
// CHECK: %[[ARG0:.*]] = "tfl.transpose"(%arg1, %[[CAST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32>
|
||||
// CHECK: %[[ARG0:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32>
|
||||
// CHECK: %[[CST_0:.*]] = constant unit
|
||||
// CHECK: %[[ARG1:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32>
|
||||
// CHECK: %[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32>
|
||||
@ -1790,10 +1852,25 @@ func @cumsum(%arg0: tensor<3x3xf32>, %arg1: tensor<i32>) -> tensor<3x3xf32> {
|
||||
// CHECK: "tfl.cumsum"(%arg0, %arg1) {exclusive = false, reverse = false} : (tensor<3x3xf32>, tensor<i32>) -> tensor<3x3xf32>
|
||||
}
func @cumsum_invalid(%arg0: tensor<3x3xf32>, %arg1: tensor<i64>) -> tensor<3x3xf32> {
|
||||
func @cumsum_i64(%arg0: tensor<3x3xf32>, %arg1: tensor<i64>) -> tensor<3x3xf32> {
|
||||
%0 = "tf.Cumsum"(%arg0, %arg1) {exclusive = false, reverse = false} : (tensor<3x3xf32>, tensor<i64>) -> tensor<3x3xf32>
|
||||
return %0 : tensor<3x3xf32>
|
||||
// CHECK-LABEL: cumsum_invalid
|
||||
// CHECK-NOT: "tfl.cumsum"
|
||||
// CHECK-LABEL: cumsum_i64
|
||||
// CHECK: "tfl.cast"
|
||||
// CHECK: "tfl.cumsum"
|
||||
}
func @segmentsum(%arg0: tensor<3x3xf32>, %arg1: tensor<i32>) -> tensor<*xf32> {
|
||||
%0 = "tf.SegmentSum"(%arg0, %arg1) : (tensor<3x3xf32>, tensor<i32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
// CHECK-LABEL: segmentsum
|
||||
// CHECK: "tfl.segment_sum"(%arg0, %arg1) : (tensor<3x3xf32>, tensor<i32>) -> tensor<*xf32>
|
||||
}
func @segmentsum_i64(%arg0: tensor<3x3xf32>, %arg1: tensor<i64>) -> tensor<*xf32> {
|
||||
%0 = "tf.SegmentSum"(%arg0, %arg1) : (tensor<3x3xf32>, tensor<i64>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
// CHECK-LABEL: segmentsum_i64
|
||||
// CHECK: "tfl.cast"
|
||||
// CHECK: "tfl.segment_sum"
|
||||
}
@ -1,6 +1,6 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s --dump-input=always
func @main(tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> {
func @main(tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
@ -129,7 +129,7 @@ func @main(tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 18,
// CHECK-NEXT: name: "arg17",
// CHECK-NEXT: quantization: {
@ -282,9 +282,36 @@ func @main(tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t
// CHECK-NEXT: }
// CHECK-EMPTY:

^bb0(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4xf32>, %arg10: tensor<4xf32>, %arg11: tensor<4xf32>, %arg12: tensor<4xf32>, %arg13: tensor<4xf32>, %arg14: tensor<4xf32>, %arg15: tensor<4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<4x4xf32>, %arg18: tensor<4x4xf32>, %arg19: tensor<4x4xf32>, %arg20: tensor<4x4xf32>, %arg21: tensor<4x4xf32>):
^bb0(%arg0: tensor<4x4xf32>,
%arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>,
%arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>,
%arg9: tensor<4xf32>, %arg10: tensor<4xf32>, %arg11: tensor<4xf32>,
%arg12: tensor<4xf32>, %arg13: tensor<4xf32>, %arg14: tensor<4xf32>, %arg15: tensor<4xf32>,
%arg16: tensor<4x4xf32>, %arg17: tensor<4xf32>,
%arg18: tensor<4x4xf32>, %arg19: tensor<4x4xf32>, %arg20: tensor<4x4xf32>, %arg21: tensor<4x4xf32>):
%0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> loc("Const")
%1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> loc("Const")
%2 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %0, %1, %arg18, %arg19, %arg20, %arg21) {effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "NONE", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
%2 = "tfl.unidirectional_sequence_lstm"(%arg0,
%arg1, %arg2, %arg3, %arg4,
%arg5, %arg6, %arg7, %arg8,
%arg9, %arg10, %arg11,
%arg12, %arg13, %arg14, %arg15,
%arg16, %arg17,
%0, %1,
%arg18, %arg19,%arg20, %arg21) {
effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>,
fused_activation_function = "NONE",
input_to_cell_intermediate = tensor<0xf32>,
input_to_forget_intermediate = tensor<0xf32>,
input_to_input_intermediate = tensor<0xf32>,
input_to_output_intermediate = tensor<0xf32>, time_major = true}
: (tensor<4x4xf32>,
tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>,
tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>,
tensor<4xf32>, tensor<4xf32>, tensor<4xf32>,
tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>,
tensor<4x4xf32>, tensor<4xf32>,
tensor<4x4xf32>, tensor<4x4xf32>,
tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
return %2 : tensor<4x4xf32>
}
@ -663,25 +663,25 @@ func @testUnidirectionalSequenceLstmWithoutProjection(%arg0: tensor<? x ? x f32>
// -----

// CHECK-LABEL: testUnidirectionalSequenceLstm
func @testUnidirectionalSequenceLstm(%arg0: tensor<? x ? x f32>, %arg1: tensor<? x ? x f32>, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x ? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
|
||||
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
func @testUnidirectionalSequenceLstm(%arg0: tensor<? x ? x f32>, %arg1: tensor<? x ? x f32>, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
|
||||
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
}

// -----

// CHECK-LABEL: testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr
func @testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x ? x f32>, %arg1: none, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x ? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
|
||||
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
func @testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x ? x f32>, %arg1: none, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
|
||||
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
}

// CHECK-LABEL: testUnidirectionalSequenceLstmWithIntermediates
func @testUnidirectionalSequenceLstmWithIntermediates(%arg0: tensor<? x ? x f32>, %arg1: none, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x ? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
|
||||
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
func @testUnidirectionalSequenceLstmWithIntermediates(%arg0: tensor<? x ? x f32>, %arg1: none, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
|
||||
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
}

@ -1458,6 +1458,12 @@ func @testStridedSliceTFType(%arg0: tensor<12x2x2x5xui8>, %arg1: tensor<1xi32>,
return %0 : tensor<1x2x2x5x!tf.quint8>
}

// CHECK-LABEL: testStridedSliceWithString
func @testStridedSliceWithString(%arg0: tensor<12x2x2x5x!tf.string>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf.string> {
%0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!tf.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.string>
return %0 : tensor<1x2x2x5x!tf.string>
}

// -----

func @testStridedSliceWithInvalidOutputType(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5xi32> {
@ -407,16 +407,16 @@ func @fuseMulIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112
|
||||
}
// CHECK-LABEL: @notFuseMulIntoDepthwiseConv2d
|
||||
func @notFuseMulIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32> {
|
||||
func @notFuseMulIntoDepthwiseConv2d(%arg0: tensor<1x4x4x2xf32>) -> tensor<1x4x4x2xf32> {
|
||||
%cst0 = constant dense<[[[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0], [17.0, 18.0]]]]> : tensor<1x3x3x2xf32>
|
||||
%cst1 = constant dense<2.0> : tensor<2xf32>
|
||||
%cst2 = constant dense<3.0> : tensor<112x2xf32>
|
||||
%cst2 = constant dense<[[3.1, 3.2], [3.1, 3.2], [3.1, 3.2], [3.1, 3.2]]> : tensor<4x2xf32>
%0 = "tfl.depthwise_conv_2d"(%arg0, %cst0, %cst1) {depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x112x2xf32>, tensor<1x3x3x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
|
||||
%0 = "tfl.depthwise_conv_2d"(%arg0, %cst0, %cst1) {depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x4x4x2xf32>, tensor<1x3x3x2xf32>, tensor<2xf32>) -> tensor<1x4x4x2xf32>
|
||||
// We cannot fuse this tfl.mul into the preceding conv op because %cst2 is not broadcast-compatible to %cst0.
|
||||
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x112x112x2xf32>, tensor<112x2xf32>) -> tensor<1x112x112x2xf32>
|
||||
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x4x4x2xf32>, tensor<4x2xf32>) -> tensor<1x4x4x2xf32>
return %1 : tensor<1x112x112x2xf32>
|
||||
return %1 : tensor<1x4x4x2xf32>
// CHECK: %0 = "tfl.depthwise_conv_2d"(%arg0, %cst, %cst_0)
|
||||
// CHECK: %1 = "tfl.mul"(%0, %cst_1)
|
||||
@ -484,17 +484,17 @@ func @FuseFullyConnectedAddWithScalarRhs(%arg0: tensor<40x37xf32>, %arg1: tensor
|
||||
}
// CHECK-LABEL: @FuseFullyConnectedAddWithUnfusableRhs
|
||||
func @FuseFullyConnectedAddWithUnfusableRhs(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
|
||||
func @FuseFullyConnectedAddWithUnfusableRhs(%arg0: tensor<4x37xf32>, %arg1: tensor<4x37xf32>) -> tensor<4x4xf32> {
|
||||
%cst = constant unit
|
||||
%cst2 = constant dense<2.0> : tensor<40x40xf32>
|
||||
%cst2 = constant dense<[[2.0, 2.1, 2.2, 2.3], [2.0, 2.1, 2.2, 2.3], [2.0, 2.1, 2.2, 2.3], [2.0, 2.1, 2.2, 2.3]]> : tensor<4x4xf32>
%0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> (tensor<40x40xf32>)
|
||||
%1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<40x40xf32>) -> tensor<40x40xf32>
|
||||
%0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x37xf32>, tensor<4x37xf32>, none) -> (tensor<4x4xf32>)
|
||||
%1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
return %1 : tensor<40x40xf32>
|
||||
return %1 : tensor<4x4xf32>
// CHECK: %[[unit:.*]] = constant unit
|
||||
// CHECK: %[[filter:.*]] = constant dense<2.000000e+00> : tensor<40x40xf32>
|
||||
// CHECK: %[[filter:.*]] = constant dense<{{.*}}> : tensor<4x4xf32>
|
||||
// CHECK: %[[fc_result:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[unit]])
|
||||
// CHECK: %[[add_result:.*]] = tfl.add %[[fc_result]], %[[filter]]
|
||||
// CHECK: return %[[add_result]]
|
||||
@ -578,6 +578,32 @@ func @NotReorderReshapeAddIfNotTailingDimAfter(%arg0: tensor<1x30x1x96xf32>) ->
|
||||
// CHECK: return %[[rs2]]
|
||||
}
// CHECK-LABEL: @NotReorderReshapeAddIf5DInputs
|
||||
func @NotReorderReshapeAddIf5DInputs(%arg0: tensor<1x1x1x1x1xf32>) -> tensor<1x1x1x1x2xf32> {
|
||||
%cst = constant dense<2.0> : tensor<1x1x1x1x2xf32>
|
||||
%shape = constant dense<[1, 1, 1, 1, 2]> : tensor<5xi32>
|
||||
%1 = "tfl.reshape"(%arg0, %shape) : (tensor<1x1x1x1x1xf32>, tensor<5xi32>) -> tensor<1x1x1x1x2xf32>
|
||||
%2 = "tfl.add"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<1x1x1x1x2xf32>, tensor<1x1x1x1x2xf32>) -> tensor<1x1x1x1x2xf32>
|
||||
return %2 : tensor<1x1x1x1x2xf32>
// CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0
|
||||
// CHECK: %[[rs2:.*]] = tfl.add %[[rs1]]
|
||||
// CHECK: return %[[rs2]]
|
||||
}
// CHECK-LABEL: @NotReorderReshapeFloorDivIf5DInputs
|
||||
func @NotReorderReshapeFloorDivIf5DInputs(%arg0: tensor<1x1x1x1x1xf32>) -> tensor<1x1x1x1x2xf32> {
|
||||
%cst = constant dense<2.0> : tensor<1x1x1x1x2xf32>
|
||||
%shape = constant dense<[1, 1, 1, 1, 2]> : tensor<5xi32>
|
||||
%1 = "tfl.reshape"(%arg0, %shape) : (tensor<1x1x1x1x1xf32>, tensor<5xi32>) -> tensor<1x1x1x1x2xf32>
|
||||
%2 = "tfl.floor_div"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<1x1x1x1x2xf32>, tensor<1x1x1x1x2xf32>) -> tensor<1x1x1x1x2xf32>
|
||||
return %2 : tensor<1x1x1x1x2xf32>
// CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0
|
||||
// CHECK: %[[rs2:.*]] = tfl.floor_div %[[rs1]]
|
||||
// CHECK: return %[[rs2]]
|
||||
}
// CHECK-LABEL: @NotReorderReshapeAddIfNotTailingDim
|
||||
func @NotReorderReshapeAddIfNotTailingDim(%arg0: tensor<40x40x1xf32>) -> tensor<40x40xf32> {
|
||||
%cst = constant dense<2.0> : tensor<1x40xf32>
|
||||
@ -851,17 +877,17 @@ func @fuseDivIntoConv2d_Scalar(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x1
|
||||
}
// CHECK-LABEL: @fuseMulIntoConv2d_Scalar
|
||||
func @fuseMulIntoConv2d_Scalar(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32> {
|
||||
func @fuseMulIntoConv2d_Scalar(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x1xf32> {
|
||||
%cst0 = constant dense<[[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]]> : tensor<1x2x2x2xf32>
|
||||
%cst1 = constant dense<1.0> : tensor<2xf32>
|
||||
%cst1 = constant dense<1.0> : tensor<1xf32>
|
||||
%cst2 = constant dense<2.0> : tensor<f32>
|
||||
%0 = "tfl.conv_2d"(%arg0, %cst0, %cst1) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x2xf32>, tensor<1x2x2x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
|
||||
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<1x112x112x2xf32>, tensor<f32>) -> tensor<1x112x112x2xf32>
|
||||
%0 = "tfl.conv_2d"(%arg0, %cst0, %cst1) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x2xf32>, tensor<1x2x2x2xf32>, tensor<1xf32>) -> tensor<1x112x112x1xf32>
|
||||
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<1x112x112x1xf32>, tensor<f32>) -> tensor<1x112x112x1xf32>
return %1 : tensor<1x112x112x2xf32>
|
||||
return %1 : tensor<1x112x112x1xf32>
|
||||
// CHECK: %[[CST1:.*]] = constant dense<{{\[\[\[\[}}2.000000e+00, 4.000000e+00], [6.000000e+00, 8.000000e+00]], {{\[\[}}1.000000e+01, 1.200000e+01], [1.400000e+01, 1.600000e+01]]]]> : tensor<1x2x2x2xf32>
|
||||
// CHECK: %[[CST2:.*]] = constant dense<2.000000e+00> : tensor<2xf32>
|
||||
// CHECK: %[[RES:[0-9].*]] = "tfl.conv_2d"(%arg0, %[[CST1]], %[[CST2]]) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x2xf32>, tensor<1x2x2x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
|
||||
// CHECK: %[[CST2:.*]] = constant dense<2.000000e+00> : tensor<1xf32>
|
||||
// CHECK: %[[RES:[0-9].*]] = "tfl.conv_2d"(%arg0, %[[CST1]], %[[CST2]]) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x2xf32>, tensor<1x2x2x2xf32>, tensor<1xf32>) -> tensor<1x112x112x1xf32>
|
||||
// CHECK: return %[[RES]]
|
||||
}
@ -896,6 +922,36 @@ func @fuseTileWithBinaryOp1(%arg0: tensor<1x1xf32>, %arg1: tensor<1x128xf32>) ->
|
||||
// CHECK: return %[[RES]]
|
||||
}
// CHECK-LABEL: notFuseTileWithBinaryOpOn5DInputs
|
||||
func @notFuseTileWithBinaryOpOn5DInputs(%arg0: tensor<1x1xf32>) -> tensor<1x1x1x1x2xf32> {
|
||||
%cst = constant dense<[1, 1, 1, 1, 2]> : tensor<5xi32>
|
||||
%cst1 = constant dense<3.0> : tensor<1x1x1x1x2xf32>
|
||||
%0 = "tfl.sqrt"(%arg0) : (tensor<1x1xf32>) -> tensor<1x1xf32>
|
||||
%1 = "tfl.tile"(%0, %cst) : (tensor<1x1xf32>, tensor<5xi32>) -> tensor<1x1x1x1x2xf32>
|
||||
%2 = "tfl.add"(%cst1, %1) {fused_activation_function = "NONE"} : (tensor<1x1x1x1x2xf32>, tensor<1x1x1x1x2xf32>) -> tensor<1x1x1x1x2xf32>
|
||||
return %2 : tensor<1x1x1x1x2xf32>
|
||||
|
||||
// CHECK: "tfl.sqrt"
|
||||
// CHECK: "tfl.tile"
|
||||
// CHECK: tfl.add
|
||||
}
|
||||
|
||||
// CHECK-LABEL: notFuseTileWithBinaryOp1On5DInputs
|
||||
func @notFuseTileWithBinaryOp1On5DInputs(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1x1x1x128xf32>) -> tensor<1x1x1x1x128xf32> {
|
||||
%cst_0 = constant dense<1.0> : tensor<f32>
|
||||
%cst_1 = constant dense<[1, 1, 1, 1, 128]> : tensor<5xi32>
|
||||
%0 = "tfl.add"(%arg0, %cst_0) {fused_activation_function = "NONE"} : (tensor<1x1xf32>, tensor<f32>) -> tensor<1x1xf32>
|
||||
%1 = "tfl.sqrt"(%0) : (tensor<1x1xf32>) -> tensor<1x1xf32>
|
||||
%2 = "tfl.tile"(%1, %cst_1) : (tensor<1x1xf32>, tensor<5xi32>) -> tensor<1x1x1x1x128xf32>
|
||||
%3 = "tfl.div"(%2, %arg1) {fused_activation_function = "NONE"} : (tensor<1x1x1x1x128xf32>, tensor<1x1x1x1x128xf32>) -> tensor<1x1x1x1x128xf32>
|
||||
return %3 : tensor<1x1x1x1x128xf32>
|
||||
|
||||
// CHECK: "tfl.add"
|
||||
// CHECK: "tfl.sqrt"
|
||||
// CHECK: "tfl.tile"
|
||||
// CHECK: tfl.div
|
||||
}
|
||||
|
||||
// CHECK-LABEL: InvalidFuseTileWithBinaryOp
|
||||
func @InvalidFuseTileWithBinaryOp(%arg0: tensor<2x3xf32>) -> tensor<2x6xf32> {
|
||||
%cst = constant dense<[[1,2]]> : tensor<1x2xi32>
|
||||
@ -1155,6 +1211,18 @@ func @ReorderAddWithConstant(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: %[[RESULT:.*]] = tfl.add %arg0, %[[CONST]] {fused_activation_function = "NONE"} : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
func @NotReorderAddWithConstantOn5D(%arg0: tensor<2x2x2x2x2xf32>) -> tensor<2x2x2x2x2xf32> {
|
||||
%cst = constant dense<1.0> : tensor<2x2x2x2x2xf32>
|
||||
%cst_1 = constant dense<2.0> : tensor<2x2x2x2x2xf32>
|
||||
%0 = "tfl.add"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<2x2x2x2x2xf32>, tensor<2x2x2x2x2xf32>) -> tensor<2x2x2x2x2xf32>
|
||||
%1 = "tfl.add"(%0, %cst_1) {fused_activation_function = "NONE"} : (tensor<2x2x2x2x2xf32>, tensor<2x2x2x2x2xf32>) -> tensor<2x2x2x2x2xf32>
|
||||
return %1 : tensor<2x2x2x2x2xf32>
|
||||
|
||||
// CHECK-LABEL: NotReorderAddWithConstantOn5D
|
||||
// CHECK: tfl.add
|
||||
// CHECK: tfl.add
|
||||
}
|
||||
|
||||
func @RemoveCast(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
%1 = "tfl.cast"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %1 : tensor<2x2xf32>
|
||||
@ -1397,3 +1465,50 @@ func @fuseExpanded1DMulIntoConv2d(%arg0: tensor<1x8x8x207xf32>) -> tensor<1x8x8x
|
||||
// CHECK: "tfl.conv_2d"(%arg0, %[[CST_0]], %[[CST_1]])
|
||||
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @FuseFullyConnectedAddWithSplat2D
|
||||
func @FuseFullyConnectedAddWithSplat2D(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
|
||||
%cst = constant unit
|
||||
%cst2 = constant dense<2.0> : tensor<40x40xf32>
|
||||
|
||||
%0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> (tensor<40x40xf32>)
|
||||
%1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<40x40xf32>) -> tensor<40x40xf32>
|
||||
|
||||
return %1 : tensor<40x40xf32>
|
||||
|
||||
// CHECK: %[[BIAS:.*]] = constant dense<2.000000e+00> : tensor<40xf32>
|
||||
// CHECK: %[[FC_RESULT:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[BIAS]])
|
||||
// CHECK: return %[[FC_RESULT]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @fuseMulIntoConv2d_Splat2D
|
||||
func @fuseMulIntoConv2d_Splat2D(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32> {
|
||||
%cst0 = constant dense<[[[[1.0, 2.0]]], [[[3.0, 4.0]]]]> : tensor<2x1x1x2xf32>
|
||||
%cst1 = constant dense<1.0> : tensor<2xf32>
|
||||
%cst2 = constant dense<2.0> : tensor<1x112x112x2xf32>
|
||||
%0 = "tfl.conv_2d"(%arg0, %cst0, %cst1) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x2xf32>, tensor<2x1x1x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
|
||||
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<1x112x112x2xf32>, tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32>
|
||||
|
||||
return %1 : tensor<1x112x112x2xf32>
|
||||
// CHECK: %[[CST1:.*]] = constant dense<{{\[\[\[\[}}2.000000e+00, 4.000000e+00]]], {{\[\[\[}}6.000000e+00, 8.000000e+00]]]]> : tensor<2x1x1x2xf32>
|
||||
// CHECK: %[[CST2:.*]] = constant dense<2.000000e+00> : tensor<2xf32>
|
||||
// CHECK: %[[RES:[0-9].*]] = "tfl.conv_2d"(%arg0, %[[CST1]], %[[CST2]]) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x2xf32>, tensor<2x1x1x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
|
||||
// CHECK: return %[[RES]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @AvoidFuseFullyConnectedAddWithSplat2D
|
||||
func @AvoidFuseFullyConnectedAddWithSplat2D(%arg0: tensor<1x1x1x1x1xf32>, %arg1: tensor<1x1xf32>) -> tensor<1x1x1x1x1xf32> {
|
||||
%cst = constant unit
|
||||
%cst2 = constant dense<2.0> : tensor<1x1x1x1x1xf32>
|
||||
|
||||
%0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x1x1x1x1xf32>, tensor<1x1xf32>, none) -> tensor<1x1x1x1x1xf32>
|
||||
%1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<1x1x1x1x1xf32>, tensor<1x1x1x1x1xf32>) -> tensor<1x1x1x1x1xf32>
|
||||
|
||||
return %1 : tensor<1x1x1x1x1xf32>
|
||||
|
||||
// CHECK: %[[CST1:.*]] = constant unit
|
||||
// CHECK: %[[CST2:.*]] = constant dense<2.000000e+00> : tensor<1x1x1x1x1xf32>
|
||||
// CHECK: %[[FC_RESULT:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[CST1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x1x1x1x1xf32>, tensor<1x1xf32>, none) -> tensor<1x1x1x1x1xf32>
|
||||
// CHECK: %[[ADD:.*]] = tfl.add %[[FC_RESULT]], %[[CST2]] {fused_activation_function = "NONE"} : tensor<1x1x1x1x1xf32>
|
||||
// CHECK: return %[[ADD]] : tensor<1x1x1x1x1xf32>
|
||||
}
|
||||
|
@ -77,3 +77,32 @@ func @HandleReturnedDequantizeWithAnotherUse(%arg0: tensor<128x16xf32>) -> (tens
|
||||
// CHECK-NEXT: return %[[softmax]], %[[argmax]] : tensor<128x16xf32>, tensor<128xi32>
|
||||
return %2, %3 : tensor<128x16xf32>, tensor<128xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: PruneUnusedLstm
|
||||
func @PruneUnusedLstm(%arg0: tensor<1x28x28xf32>) -> (tensor<1x28x28xf32>) {
|
||||
%input = "tfl.quantize"(%arg0) {qtype = tensor<1x28x28x!quant.uniform<i8:f32, 0.003:-128>>} : (tensor<1x28x28xf32>) -> tensor<1x28x28x!quant.uniform<i8:f32, 0.003:-128>>
|
||||
%cst_1 = "tfl.pseudo_qconst"() {qtype = tensor<1x20x!quant.uniform<i8:f32, 0.006:-34>>, value = dense<1> : tensor<1x20xi8>} : () -> tensor<1x20x!quant.uniform<i8:f32, 0.006:-34>>
|
||||
%cst_2 = constant unit
|
||||
%cst_3 = "tfl.pseudo_qconst"() {qtype = tensor<20x20x!quant.uniform<i8:f32, 0.006:-34>>, value = dense<1> : tensor<20x20xi8>} : () -> tensor<20x20x!quant.uniform<i8:f32, 0.006:-34>>
|
||||
%cst_7 = "tfl.pseudo_qconst"() {qtype = tensor<20x!quant.uniform<i8:f32, 0.006:-34>>, value = dense<1> : tensor<20xi8>} : () -> tensor<20x!quant.uniform<i8:f32, 0.006:-34>>
|
||||
%cst_11 = "tfl.pseudo_qconst"() {qtype = tensor<20x28x!quant.uniform<i8:f32, 0.006:-34>>, value = dense<1> : tensor<20x28xi8>} : () -> tensor<20x28x!quant.uniform<i8:f32, 0.006:-34>>
|
||||
%cell_input = "tfl.pseudo_qconst"() {qtype = tensor<1x20x!quant.uniform<i16:f32, 0.006:-34>>, value = dense<1> : tensor<1x20xi6>} : () -> tensor<1x20x!quant.uniform<i16:f32, 0.006:-34>>
|
||||
%0 = "tfl.unidirectional_sequence_lstm"(%input,
|
||||
%cst_11, %cst_11, %cst_11, %cst_11,
|
||||
%cst_3, %cst_3, %cst_3, %cst_3,
|
||||
%cst_2, %cst_2, %cst_2,
|
||||
%cst_7, %cst_7, %cst_7, %cst_7,
|
||||
%cst_2, %cst_2,
|
||||
%cst_1, %cell_input,
|
||||
%cst_2, %cst_2, %cst_2, %cst_2) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false}
|
||||
: ( tensor<1x28x28x!quant.uniform<i8:f32, 0.003:-128>>,
|
||||
tensor<20x28x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x28x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x28x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x28x!quant.uniform<i8:f32, 0.006:-34>>,
|
||||
tensor<20x20x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x20x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x20x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x20x!quant.uniform<i8:f32, 0.006:-34>>,
|
||||
none, none, none,
|
||||
tensor<20x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x!quant.uniform<i8:f32, 0.006:-34>>,
|
||||
none, none,
|
||||
tensor<1x20x!quant.uniform<i8:f32, 0.006:-34>>, tensor<1x20x!quant.uniform<i16:f32, 0.006:-34>>,
|
||||
none, none, none, none) -> tensor<1x28x20x!quant.uniform<i8:f32, 0.006:-34>>
|
||||
return %arg0 : tensor<1x28x28xf32>
|
||||
// CHECK-NEXT: return %arg0
|
||||
}
|
||||
|
@ -500,21 +500,21 @@ func @nms_padded(%arg0: tensor<100x4xf32>, %arg1: tensor<100xf32>, %arg2: tensor
|
||||
|
||||
module {
|
||||
// expected-error @+1 {{Invalid number of results from non_max_suppression_padded_v2}}
|
||||
func @nms_padded_invalid_num_results(%arg0: tensor<100x4xf32>, %arg1: tensor<100xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<i1>, %arg6: tensor<i1>, %arg7: tensor<i1>, %arg8: tensor<i32>) -> () attributes {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"}
|
||||
func private @nms_padded_invalid_num_results(%arg0: tensor<100x4xf32>, %arg1: tensor<100xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<i1>, %arg6: tensor<i1>, %arg7: tensor<i1>, %arg8: tensor<i32>) -> () attributes {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"}
|
||||
|
||||
// expected-error @+1 {{Invalid number of arguments to non_max_suppression_padded_v2}}
|
||||
func @nms_padded_invalid_num_args(%arg0: tensor<100x4xf32>, %arg1: tensor<100xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>) -> (tensor<1x10xi32>, tensor<i32>) attributes {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"}
|
||||
func private @nms_padded_invalid_num_args(%arg0: tensor<100x4xf32>, %arg1: tensor<100xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>) -> (tensor<1x10xi32>, tensor<i32>) attributes {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"}
|
||||
|
||||
// expected-error @+1 {{TFLite does not support batched input for non_max_suppression_padded}}
|
||||
func @nms_padded_with_batches(%arg0: tensor<2x100x4xf32>, %arg1: tensor<2x100xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<i1>, %arg6: tensor<i1>, %arg7: tensor<i1>, %arg8: tensor<i32>) -> (tensor<2x10xi32>, tensor<i32>) attributes {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"}
|
||||
func private @nms_padded_with_batches(%arg0: tensor<2x100x4xf32>, %arg1: tensor<2x100xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<i1>, %arg6: tensor<i1>, %arg7: tensor<i1>, %arg8: tensor<i32>) -> (tensor<2x10xi32>, tensor<i32>) attributes {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
// CHECK-LABEL: func @some_func
|
||||
// CHECK-LABEL: func private @some_func
|
||||
// CHECK-LABEL: func @func_with_call
|
||||
func @some_func(%arg0: tensor<100xf32>) -> tensor<100xf32> attributes {tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c"}
|
||||
func private @some_func(%arg0: tensor<100xf32>) -> tensor<100xf32> attributes {tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c"}
|
||||
func @func_with_call(%arg0: tensor<100xf32>) -> tensor<100xf32> {
|
||||
%0 = call @some_func(%arg0) : (tensor<100xf32>) -> tensor<100xf32>
|
||||
return %0 : tensor<100xf32>
|
||||
@ -545,13 +545,13 @@ func @tflite_custom_nms(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>,
|
||||
|
||||
module {
|
||||
// expected-error @+1 {{Invalid number of results from TFLite_Detection_PostProcess}}
|
||||
func @tflite_custom_nms_invalid_results(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>, %arg2: tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) attributes {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, max_classes_per_detection = 1 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"}
|
||||
func private @tflite_custom_nms_invalid_results(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>, %arg2: tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) attributes {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, max_classes_per_detection = 1 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"}
|
||||
|
||||
// expected-error @+1 {{Invalid number of arguments to TFLite_Detection_PostProcess}}
|
||||
func @tflite_custom_nms_invalid_args(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) attributes {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, max_classes_per_detection = 1 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"}
|
||||
func private @tflite_custom_nms_invalid_args(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) attributes {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, max_classes_per_detection = 1 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"}
|
||||
|
||||
// expected-error @+1 {{max_classes_per_detection attribute is not set or not an integer}}
|
||||
func @tflite_custom_nms_missing_func_args(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>, %arg2: tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) attributes {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"} {
|
||||
func private @tflite_custom_nms_missing_func_args(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>, %arg2: tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) attributes {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"} {
|
||||
%0 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
|
||||
%1 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
|
||||
%2 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
|
||||
|
@ -166,3 +166,37 @@ func @QuantizeTransposeConv(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<4xi32>)
|
||||
// PerTensor: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) : (tensor<1x32x42x128x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>) -> tensor<1x32x42x128xf32>
|
||||
// PerTensor: "tfl.transpose_conv"(%arg1, %arg0, %[[DEQUANTIZE]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: QuantizeLstmCellInput
|
||||
func @QuantizeLstmCellInput(%arg0: tensor<1x28x28xf32>) -> tensor<1x28x20xf32> {
|
||||
%cst_1 = constant dense<1.0> : tensor<1x20xf32>
|
||||
%cst_2 = constant unit
|
||||
%cst_3 = constant dense<1.0> : tensor<20x20xf32>
|
||||
%cst_7 = constant dense<1.0> : tensor<20xf32>
|
||||
%cst_11 = constant dense<1.0> : tensor<20x28xf32>
|
||||
%cell_input = constant dense<0.0> : tensor<1x20xf32>
|
||||
%cell_stats = "quant.stats"(%cell_input) {layerStats = dense<[-2.73090601, 7.94872093]> : tensor<2xf32>} : (tensor<1x20xf32>) -> tensor<1x20xf32>
|
||||
%0 = "tfl.unidirectional_sequence_lstm"(%arg0,
|
||||
%cst_11, %cst_11, %cst_11, %cst_11,
|
||||
%cst_3, %cst_3, %cst_3, %cst_3,
|
||||
%cst_2, %cst_2, %cst_2,
|
||||
%cst_7, %cst_7, %cst_7, %cst_7,
|
||||
%cst_2, %cst_2,
|
||||
%cst_1, %cell_stats,
|
||||
%cst_2, %cst_2, %cst_2, %cst_2) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false}
|
||||
: ( tensor<1x28x28xf32>,
|
||||
tensor<20x28xf32>, tensor<20x28xf32>, tensor<20x28xf32>, tensor<20x28xf32>,
|
||||
tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>,
|
||||
none, none, none,
|
||||
tensor<20xf32>, tensor<20xf32>, tensor<20xf32>, tensor<20xf32>,
|
||||
none, none,
|
||||
tensor<1x20xf32>, tensor<1x20xf32>,
|
||||
none, none, none, none) -> tensor<1x28x20xf32>
|
||||
return %0 : tensor<1x28x20xf32>
|
||||
// CHECK: %[[none:.*]] = constant unit
|
||||
// CHECK: %[[cell_input:.*]] = constant dense<0.000000e+00> : tensor<1x20xf32>
|
||||
// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cell_input]]) {qtype = tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>} : (tensor<1x20xf32>) -> tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>
|
||||
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) : (tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>) -> tensor<1x20xf32>
|
||||
// Checks if input 19 is correctly passed from a dequantize op.
|
||||
// CHECK: %[[lstm:.*]] = "tfl.unidirectional_sequence_lstm"(%arg0, {{(%[^%,]+, )+}}%[[dq]], %[[none]], %[[none]], %[[none]], %[[none]])
|
||||
}
|
||||
|
@ -520,6 +520,17 @@ func @PadStridedSliceNewAxisMask2(%arg0: tensor<4x64x64x1xf32>) -> tensor<1x4x64
|
||||
return %1 : tensor<1x4x64x64xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @AvoidPadStridedSliceNewAxisMaskOnUnknownShapes
|
||||
func @AvoidPadStridedSliceNewAxisMaskOnUnknownShapes(%arg0: tensor<?x?xf32>) -> tensor<1x?x?x1xf32> {
|
||||
%cst = constant dense<0> : tensor<4xi32>
|
||||
%cst_0 = constant dense<1> : tensor<4xi32>
|
||||
%0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 9 : i64, shrink_axis_mask = 0 : i64} : (tensor<?x?xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x?x?x1xf32>
|
||||
return %0 : tensor<1x?x?x1xf32>
|
||||
|
||||
// CHECK-NOT: "tf.Reshape"
|
||||
// CHECK: "tf.StridedSlice"
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @StridedSliceRewriteMasks
|
||||
func @StridedSliceRewriteMasks(%arg0: tensor<8x4x16x2xf32>) -> tensor<8x4x16x1xf32> {
|
||||
%cst = "tf.Const"() {device = "", value = dense<[1, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||
@ -540,37 +551,6 @@ func @StridedSliceRewriteMasks(%arg0: tensor<8x4x16x2xf32>) -> tensor<8x4x16x1xf
|
||||
return %0 : tensor<8x4x16x1xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @MatrixSetDiagV2Conversion
|
||||
func @MatrixSetDiagV2Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
|
||||
%cst = constant dense<0> : tensor<i32>
|
||||
%0 = "tf.MatrixSetDiagV2"(%arg0, %arg1, %cst) : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
|
||||
return %0 : tensor<3x3xi32>
|
||||
|
||||
// CHECK: %[[RES:.*]] = "tf.MatrixSetDiag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
|
||||
// CHECK: return %[[RES]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @MatrixSetDiagV2NonZeroK
|
||||
func @MatrixSetDiagV2NonZeroK(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
|
||||
%cst = constant dense<1> : tensor<i32>
|
||||
%0 = "tf.MatrixSetDiagV2"(%arg0, %arg1, %cst) : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
|
||||
return %0 : tensor<3x3xi32>
|
||||
|
||||
// CHECK: %[[CST:.*]] = constant dense<1> : tensor<i32>
|
||||
// CHECK: %[[RES:.*]] = "tf.MatrixSetDiagV2"(%arg0, %arg1, %[[CST]]) : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
|
||||
// CHECK: return %[[RES]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @MatrixSetDiagV3Conversion
|
||||
func @MatrixSetDiagV3Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
|
||||
%cst = constant dense<0> : tensor<i32>
|
||||
%0 = "tf.MatrixSetDiagV3"(%arg0, %arg1, %cst) : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
|
||||
return %0 : tensor<3x3xi32>
|
||||
|
||||
// CHECK: %[[RES:.*]] = "tf.MatrixSetDiag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
|
||||
// CHECK: return %[[RES]]
|
||||
}
|
||||
|
||||
func @broadcast_to_f32_low_dim(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {
|
||||
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
|
||||
return %0: tensor<3x3xf32>
|
||||
|
@ -4,10 +4,10 @@ func @testSingleLstm(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>) -> tensor<4x4
|
||||
// CHECK-LABEL: testSingleLstm
|
||||
// CHECK: %[[CST_0:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
|
||||
// CHECK: %[[CST_1:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
|
||||
// CHECK: %[[LSTM:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %[[CST_0]], %[[CST_1]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
|
||||
// CHECK: %[[LSTM:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg1, %[[CST_0]], %[[CST_1]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
|
||||
|
||||
%0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> loc("Const")
|
||||
%1 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
|
||||
%1 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg1, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
|
||||
return %1 : tensor<4x4xf32>
|
||||
}
|
||||
|
||||
@ -15,13 +15,13 @@ func @testMultipleLstms(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>) -> tensor<
|
||||
// CHECK-LABEL: testMultipleLstms
|
||||
// CHECK: %[[CST_0:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
|
||||
// CHECK: %[[CST_1:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
|
||||
// CHECK: %[[LSTM_1:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %[[CST_0]], %[[CST_1]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
|
||||
// CHECK: %[[LSTM_1:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg1, %[[CST_0]], %[[CST_1]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
|
||||
// CHECK: %[[CST_2:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
|
||||
// CHECK: %[[CST_3:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
|
||||
// CHECK: %[[LSTM_2:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%[[LSTM_1]], %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %[[CST_2]], %[[CST_3]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
|
||||
// CHECK: %[[LSTM_2:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%[[LSTM_1]], %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg1, %[[CST_2]], %[[CST_3]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
|
||||
|
||||
%0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> loc("Const")
|
||||
%1 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
|
||||
%2 = "tfl.unidirectional_sequence_lstm"(%1, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
|
||||
%1 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg1, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
|
||||
%2 = "tfl.unidirectional_sequence_lstm"(%1, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg1, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
|
||||
return %2 : tensor<4x4xf32>
|
||||
}
|
||||
|
@ -30,9 +30,9 @@ func @while() -> tensor<1xf32>
|
||||
}) : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>, tensor<1xf32>) loc("WhileOp")
|
||||
return %0#1 : tensor<1xf32>
|
||||
}
|
||||
// CHECK-LABEL: func @WhileOp_cond(
|
||||
// CHECK-LABEL: func private @WhileOp_cond(
|
||||
// CHECK: tfl.greater
|
||||
// CHECK-LABEL: func @WhileOp_body(
|
||||
// CHECK-LABEL: func private @WhileOp_body(
|
||||
// CHECK: tfl.sub
|
||||
// CHECK: tfl.add
|
||||
|
||||
@ -63,21 +63,21 @@ func @while2(%cst : tensor<i32>) -> tensor<1xf32> attributes {tf.entry_function
|
||||
return %0#1 : tensor<1xf32>
|
||||
}
|
||||
|
||||
func @WhileOp_cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>, %arg2: tensor<i32>) -> tensor<i1> attributes {sym_visibility = "private"} {
|
||||
func private @WhileOp_cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>, %arg2: tensor<i32>) -> tensor<i1> {
|
||||
%cst = constant dense<0> : tensor<i32>
|
||||
%0 = "tfl.greater"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
|
||||
func @WhileOp_body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>, %arg2: tensor<i32>) -> (tensor<*xi32>, tensor<*xf32>, tensor<i32>) attributes {sym_visibility = "private"} {
|
||||
func private @WhileOp_body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>, %arg2: tensor<i32>) -> (tensor<*xi32>, tensor<*xf32>, tensor<i32>) {
|
||||
%0 = "tfl.sub"(%arg0, %arg2) {fused_activation_function = "NONE"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
||||
%1 = tfl.add %arg1, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32>
|
||||
return %0, %1, %arg2 : tensor<*xi32>, tensor<*xf32>, tensor<i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @WhileOp_cond(
|
||||
// CHECK-LABEL: func private @WhileOp_cond(
|
||||
// CHECK: tfl.greater
|
||||
// CHECK-LABEL: func @WhileOp_body(
|
||||
// CHECK-LABEL: func private @WhileOp_body(
|
||||
// CHECK: tfl.sub
|
||||
// CHECK: tfl.add
|
||||
|
||||
@ -152,14 +152,14 @@ func @rnn(%arg0: tensor<4x4x3xf32> {tf.device = "/device:CPU:0"}) -> tensor<4x?x
|
||||
// CHECK: tfl.yield
|
||||
// CHECK-SAME: (tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<4x4x3xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: func @tfl.while_cond(
|
||||
// CHECK-SAME: [[VAL_35:%.*]]: tensor<i32>, [[VAL_36:%.*]]: tensor<i32>, [[VAL_37:%.*]]: tensor<*xf32>, [[VAL_38:%.*]]: tensor<4x2xf32>, [[VAL_39:%.*]]: tensor<4x2xf32>, [[VAL_40:%.*]]: tensor<*xf32>, [[VAL_41:%.*]]: tensor<4x4x3xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
|
||||
// CHECK-LABEL: func private @tfl.while_cond(
|
||||
// CHECK-SAME: [[VAL_35:%.*]]: tensor<i32>, [[VAL_36:%.*]]: tensor<i32>, [[VAL_37:%.*]]: tensor<*xf32>, [[VAL_38:%.*]]: tensor<4x2xf32>, [[VAL_39:%.*]]: tensor<4x2xf32>, [[VAL_40:%.*]]: tensor<*xf32>, [[VAL_41:%.*]]: tensor<4x4x3xf32>) -> tensor<i1> {
|
||||
// CHECK: return
|
||||
// CHECK-SAME: tensor<i1>
|
||||
// CHECK: }
|
||||
|
||||
// CHECK-LABEL: func @tfl.while_body(
|
||||
// CHECK-SAME: [[VAL_46:%.*]]: tensor<i32>, [[VAL_47:%.*]]: tensor<i32>, [[VAL_48:%.*]]: tensor<*xf32>, [[VAL_49:%.*]]: tensor<4x2xf32>, [[VAL_50:%.*]]: tensor<4x2xf32>, [[VAL_51:%.*]]: tensor<*xf32>, [[VAL_52:%.*]]: tensor<4x4x3xf32>) -> (tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<4x4x3xf32>) attributes {sym_visibility = "private"} {
|
||||
// CHECK-LABEL: func private @tfl.while_body(
|
||||
// CHECK-SAME: [[VAL_46:%.*]]: tensor<i32>, [[VAL_47:%.*]]: tensor<i32>, [[VAL_48:%.*]]: tensor<*xf32>, [[VAL_49:%.*]]: tensor<4x2xf32>, [[VAL_50:%.*]]: tensor<4x2xf32>, [[VAL_51:%.*]]: tensor<*xf32>, [[VAL_52:%.*]]: tensor<4x4x3xf32>) -> (tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<4x4x3xf32>) {
|
||||
// CHECK: [[VAL_91:%.*]] = "tfl.cast"
|
||||
// CHECK: return
|
||||
// CHECK-SAME: [[VAL_91]], [[VAL_52]] : tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<4x4x3xf32>
|
||||
|
@ -234,6 +234,11 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
// tf.variable to model this.
|
||||
pass_manager->addNestedPass<mlir::FuncOp>(
|
||||
mlir::TFL::CreateSplitMergedOperandsPass());
|
||||
|
||||
// Add CallOnceOp when there is a session initializer function in tf saved
|
||||
// model dialect.
|
||||
pass_manager->addPass(
|
||||
mlir::TFL::CreateInsertCallOnceOpFromSessionInitializerPass());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,78 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir/IR/OperationSupport.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
namespace {
|
||||
|
||||
// This pass inserts a TFL::CallOnce op when tf_saved_model's session
|
||||
// initializer is given.
|
||||
class InsertCallOnceOpFromSessionInitializerPass
|
||||
: public mlir::PassWrapper<InsertCallOnceOpFromSessionInitializerPass,
|
||||
OperationPass<ModuleOp>> {
|
||||
private:
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
void InsertCallOnceOpFromSessionInitializerPass::runOnOperation() {
|
||||
ModuleOp module = getOperation();
|
||||
tf_saved_model::SessionInitializerOp session_init_op =
|
||||
tf_saved_model::GetSessionInitializerOp(module);
|
||||
|
||||
if (!session_init_op) return;
|
||||
|
||||
SymbolTable symbol_table(module);
|
||||
|
||||
for (auto sym_ref : session_init_op.initializers()) {
|
||||
FuncOp init_func_op = symbol_table.lookup<mlir::FuncOp>(
|
||||
sym_ref.cast<FlatSymbolRefAttr>().getValue());
|
||||
|
||||
if (!init_func_op) {
|
||||
module.emitError("no session initializer function found");
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
for (auto func : module.getOps<FuncOp>()) {
|
||||
auto dict_attr =
|
||||
func.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
|
||||
if (!dict_attr) continue;
|
||||
|
||||
OpBuilder builder(func.getContext());
|
||||
builder.setInsertionPointToStart(&func.getBlocks().front());
|
||||
builder.create<TFL::CallOnceOp>(func.getLoc(), init_func_op.getName());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Inserts a TFL::CallOnce op when tf_saved_model's session initializer is
|
||||
// given.
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
CreateInsertCallOnceOpFromSessionInitializerPass() {
|
||||
return std::make_unique<InsertCallOnceOpFromSessionInitializerPass>();
|
||||
}
|
||||
|
||||
static PassRegistration<InsertCallOnceOpFromSessionInitializerPass> pass(
|
||||
"tfl-insert-call-once-op",
|
||||
"Insert CallOnce op when tf_saved_model's session initializer is given");
|
||||
|
||||
} // namespace TFL
} // namespace mlir
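For orientation, a minimal sketch of this pass's intended effect; it is not part of the patch, the tfl.call_once attribute name is assumed, and tf_saved_model attributes normally required by the verifier are omitted for brevity:

// Input (simplified): a module with a session initializer and an entry function.
"tf_saved_model.session_initializer"() {initializers = [@init]} : () -> ()
func @init() { return }
func @main(%arg0: tensor<1xf32>) -> tensor<1xf32> attributes {tf.entry_function = {inputs = "x", outputs = "y"}} {
  return %arg0 : tensor<1xf32>
}
// After the pass, every function carrying tf.entry_function starts with a
// call-once op referencing the initializer, e.g.
// "tfl.call_once"() {session_init_function = "init"} : () -> ()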
@ -54,7 +54,7 @@ def ExtractSingleElementAsInt32 : NativeCodeCall<
|
||||
"$_builder.getI32IntegerAttr(ExtractSingleElementAsInteger($_self.cast<ElementsAttr>()).getInt())">;
|
||||
|
||||
// Converts tensor with int64 to int32.
|
||||
def CreateTFLCastToInt32Op : NativeCodeCall<
|
||||
def CreateTFCastToInt32Op : NativeCodeCall<
|
||||
"CreateCastToInt32($0, $_loc, $_builder)">;
|
||||
|
||||
// Checks whether the given operation has static shapes and same shapes of all inputs.
|
||||
@ -193,8 +193,8 @@ def LegalizeRound : Pat<(TF_RoundOp $arg), (TFL_RoundOp $arg)>;
|
||||
def LegalizeRsqrt : Pat<(TF_RsqrtOp $arg), (TFL_RsqrtOp $arg)>;
|
||||
def LegalizeSqrt : Pat<(TF_SqrtOp $arg), (TFL_SqrtOp $arg)>;
|
||||
def LegalizeSquare : Pat<(TF_SquareOp $arg), (TFL_SquareOp $arg)>;
|
||||
def LegalizeSegmentSum : Pat<(TF_SegmentSumOp $data, I32Tensor:$segment_ids),
|
||||
(TFL_SegmentSumOp $data, $segment_ids)>;
|
||||
def LegalizeSegmentSum : Pat<(TF_SegmentSumOp $data, $segment_ids),
|
||||
(TFL_SegmentSumOp $data, (CreateTFCastToInt32Op $segment_ids))>;
|
||||
def LegalizeSelect : Pat<(TF_SelectOp $cond, $x, $y),
|
||||
(TFL_SelectOp $cond, $x, $y)>;
|
||||
def LegalizeSelectV2SameStaticShape : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y),
|
||||
@ -221,7 +221,7 @@ def LegalizeTanh : Pat<(TF_TanhOp $arg), (TFL_TanhOp $arg)>;
|
||||
|
||||
def LegalizeTranspose : Pat<(TF_TransposeOp $arg, $perm),
|
||||
(TFL_TransposeOp $arg,
|
||||
(CreateTFLCastToInt32Op $perm))>;
|
||||
(CreateTFCastToInt32Op $perm))>;
|
||||
|
||||
def LegalizeWhere : Pat<(TF_WhereOp $arg), (TFL_WhereOp $arg)>;
|
||||
def LegalizeZerosLike : Pat<(TF_ZerosLikeOp $arg), (TFL_ZerosLikeOp $arg)>;
|
||||
@ -309,8 +309,9 @@ def LegalizeRank : Pat<(TF_RankOp $input), (TFL_RankOp $input)>;
|
||||
def LegalizeSquaredDifference : Pat<(TF_SquaredDifferenceOp $l, $r),
|
||||
(TFL_SquaredDifferenceOp $l, $r)>;
|
||||
|
||||
def LegalizeReverseV2 : Pat<(TF_ReverseV2Op $arg0, $arg1),
|
||||
(TFL_ReverseV2Op $arg0, $arg1)>;
|
||||
def LegalizeReverseV2 : Pat<
|
||||
(TF_ReverseV2Op $arg0, $axis),
|
||||
(TFL_ReverseV2Op $arg0, (CreateTFCastToInt32Op $axis))>;
|
||||
|
||||
def LegalizeEqual : Pat<(TF_EqualOp $arg0, $arg1,
|
||||
/*incompatible_shape_error=*/ConstBoolAttrTrue),
|
||||
@ -349,11 +350,13 @@ def LegalizeCast : Pat<(TF_CastOp $arg0, BoolAttr:$arg1), (TFL_CastOp $arg0)>;
|
||||
|
||||
def LegalizeBatchToSpaceND : Pat<
|
||||
(TF_BatchToSpaceNDOp $input, $block_shape, $crops),
|
||||
(TFL_BatchToSpaceNdOp $input, $block_shape, $crops)>;
|
||||
(TFL_BatchToSpaceNdOp $input, (CreateTFCastToInt32Op $block_shape),
|
||||
(CreateTFCastToInt32Op $crops))>;
|
||||
|
||||
def LegalizeSpaceToBatchND : Pat<
|
||||
(TF_SpaceToBatchNDOp $input, $block_shape, $paddings),
|
||||
(TFL_SpaceToBatchNdOp $input, $block_shape, $paddings)>;
|
||||
(TFL_SpaceToBatchNdOp $input, (CreateTFCastToInt32Op $block_shape),
|
||||
(CreateTFCastToInt32Op $paddings))>;
|
||||
|
||||
def LegalizeSpaceToDepth : Pat<
|
||||
(TF_SpaceToDepthOp $input, $block_size, IsDataFormatNHWC:$data_format),
|
||||
@ -437,14 +440,34 @@ def LegalizeConv2DBackpropInput : Pat<
|
||||
/*stride_h=*/ ExtractI32At<1>:$strides,
|
||||
/*stride_w=*/ ExtractI32At<2>:$strides)>;
|
||||
|
||||
def IsRankZeroAttr
|
||||
: CPred<"$_self.cast<DenseElementsAttr>().getType().getRank() == 0">;
|
||||
|
||||
def HasValueZero
|
||||
: CPred<"$_self.cast<DenseElementsAttr>().getSplatValue()."
|
||||
"cast<::mlir::IntegerAttr>().getInt() == 0">;
|
||||
|
||||
// TFLite only supports MatrixSetDiag ops with a scalar zero k attribute.
|
||||
def IsSupportedByTFLiteMatrixSetDiag
|
||||
: ElementsAttrBase<And<[ElementsAttr.predicate,
|
||||
IsRankZeroAttr, HasValueZero]>,
|
||||
"MatrixSetDiag attribute verification">;
|
||||
|
||||
// The align attribute doesn't matter when k is zero.
|
||||
def LegalizeMatrixSetDiag : Pat<
|
||||
(TF_MatrixSetDiagOp $input, $diagonal),
|
||||
(TF_MatrixSetDiagV3Op $input, $diagonal,
|
||||
(ConstantLikeMatcher IsSupportedByTFLiteMatrixSetDiag:$k), $align),
|
||||
(TFL_MatrixSetDiagOp $input, $diagonal)>;
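As an illustration only (not taken from the patch; the TFL op name in the expected output is assumed from the pattern above), a zero-k tf.MatrixSetDiagV3 that this pattern is meant to match and rewrite:

func @matrix_set_diag_zero_k(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
  %k = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
  %0 = "tf.MatrixSetDiagV3"(%arg0, %arg1, %k) {align = "RIGHT_LEFT"} : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
  return %0 : tensor<3x3xi32>
  // Expected after TF-to-TFL legalization:
  // %0 = "tfl.matrix_set_diag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
}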
def LegalizeScatterNd : Pat<
|
||||
(TF_ScatterNdOp I32Tensor:$indices, $updates, $shape),
|
||||
(TFL_ScatterNdOp I32Tensor:$indices, $updates, $shape)>;
|
||||
(TF_ScatterNdOp $indices, $updates, $shape),
|
||||
(TFL_ScatterNdOp (CreateTFCastToInt32Op $indices), $updates,
|
||||
(CreateTFCastToInt32Op $shape))>;
|
||||
|
||||
def LegalizeCumsum : Pat<
|
||||
(TF_CumsumOp $input, $axis, $exclusive, $reverse),
|
||||
(TFL_CumsumOp $input, $axis, $exclusive, $reverse)>;
|
||||
(TFL_CumsumOp $input, (CreateTFCastToInt32Op $axis), $exclusive, $reverse)>;
|
||||
|
||||
def LegalizeReshape : Pat<
|
||||
(TF_ReshapeOp $input, $shape),
|
||||
(TFL_ReshapeOp $input, (CreateTFCastToInt32Op $shape))>;
|
||||
|
@ -123,7 +123,8 @@ Value CreateCastToInt32(Value val, Location loc, PatternRewriter& rewriter) {
|
||||
auto shape = val.getType().dyn_cast<RankedTensorType>().getShape();
|
||||
IntegerType new_ele_type = rewriter.getIntegerType(32);
|
||||
ShapedType new_type = RankedTensorType::get(shape, new_ele_type);
|
||||
return rewriter.create<TFL::CastOp>(loc, new_type, val);
|
||||
return rewriter.createOrFold<TF::CastOp>(loc, new_type, val,
|
||||
rewriter.getBoolAttr(false));
|
||||
}
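A rough, assumed illustration of why createOrFold is used here: when the index-like operand is a constant, the inserted tf.Cast folds away, so patterns such as LegalizeReshape hand the TFL op an i32 constant directly:

func @reshape_i64_shape(%arg0: tensor<4xf32>) -> tensor<2x2xf32> {
  %shape = "tf.Const"() {value = dense<2> : tensor<2xi64>} : () -> tensor<2xi64>
  %0 = "tf.Reshape"(%arg0, %shape) : (tensor<4xf32>, tensor<2xi64>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
  // Expected after legalization: no intermediate cast op remains, e.g.
  // %cst = constant dense<2> : tensor<2xi32>
  // %0 = "tfl.reshape"(%arg0, %cst) : (tensor<4xf32>, tensor<2xi32>) -> tensor<2x2xf32>
}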
#include "tensorflow/compiler/mlir/lite/transforms/generated_legalize_tf.inc"
|
||||
@ -145,7 +146,6 @@ DECL_CONVERT_OP(MatMul);
|
||||
DECL_CONVERT_OP(MatrixDiagV2);
|
||||
DECL_CONVERT_OP(MatrixDiagV3);
|
||||
DECL_CONVERT_OP(Pack);
|
||||
DECL_CONVERT_OP(Reshape);
|
||||
DECL_CONVERT_OP(Split);
|
||||
DECL_CONVERT_OP(SplitV);
|
||||
DECL_CONVERT_OP(StridedSlice);
|
||||
@ -299,30 +299,6 @@ LogicalResult ConvertTFPackOp::matchAndRewrite(
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ConvertTFReshapeOp::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
auto tf_reshape_op = cast<TF::ReshapeOp>(op);
|
||||
|
||||
auto input = tf_reshape_op.tensor();
|
||||
auto shape = tf_reshape_op.shape();
|
||||
|
||||
ShapedType shape_type = shape.getType().cast<ShapedType>();
|
||||
// The tfl reshape's #2 operand needs to i32 tensor type, so we have to cast.
|
||||
if (!shape_type.getElementType().isSignlessInteger(32)) {
|
||||
auto new_shape = shape_type.getShape();
|
||||
IntegerType new_ele_type = rewriter.getIntegerType(32);
|
||||
ShapedType new_type = RankedTensorType::get(new_shape, new_ele_type);
|
||||
// Uses TF::CastOp to be folded if the shape input is a constant.
|
||||
shape = rewriter
|
||||
.create<TF::CastOp>(op->getLoc(), new_type, shape,
|
||||
rewriter.getBoolAttr(false))
|
||||
.y();
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<ReshapeOp>(op, tf_reshape_op.output().getType(),
|
||||
input, shape);
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ConvertTFSplitOp::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
auto tf_split_op = cast<TF::SplitOp>(op);
|
||||
@ -792,10 +768,9 @@ void addPatterns(MLIRContext* context, OwningRewritePatternList& patterns) {
|
||||
populateWithGenerated(context, patterns);
|
||||
patterns
|
||||
.insert<ConvertTFConcatV2Op, ConvertTFMatMulOp, ConvertTFMatrixDiagV2Op,
|
||||
ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFReshapeOp,
|
||||
ConvertTFSplitOp, ConvertTFSplitVOp, ConvertTFStridedSliceOp,
|
||||
ConvertTFUnpackOp, ConvertTFAssertOp, ConvertTFRandomUniformOp>(
|
||||
context);
|
||||
ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFSplitOp,
|
||||
ConvertTFSplitVOp, ConvertTFStridedSliceOp, ConvertTFUnpackOp,
|
||||
ConvertTFAssertOp, ConvertTFRandomUniformOp>(context);
|
||||
|
||||
// Ophint python converter converted tf node pattern.
|
||||
patterns.insert<LegalizeUnidirectionalSequenceLstm,
|
||||
|
@ -62,7 +62,7 @@ void RunOnWhile(TF::WhileOp while_op) {
|
||||
auto call = builder.create<CallOp>(while_op.getLoc(), func, new_operands);
|
||||
builder.create<YieldOp>(while_op.getLoc(), call.getResults());
|
||||
// Mark old function as private so that it can be DCE'd if not called.
|
||||
func.setVisibility(SymbolTable::Visibility::Private);
|
||||
func.setPrivate();
|
||||
};
|
||||
create_region_with_call(while_op.cond_function(), new_op.cond());
|
||||
create_region_with_call(while_op.body_function(), new_op.body());
|
||||
|
@ -27,11 +27,14 @@ limitations under the License.
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/ADT/APInt.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/None.h"
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Matchers.h" // from @llvm-project
|
||||
@ -286,6 +289,18 @@ static bool ShapeMatchesReduceWithKeepAxes(Value input,
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool FloatValueEquals(const Attribute &attr, double value) {
|
||||
auto fp_attr = attr.dyn_cast_or_null<DenseFPElementsAttr>();
|
||||
if (!fp_attr) return false;
|
||||
|
||||
if (fp_attr.isSplat()) {
|
||||
return fp_attr.getSplatValue<APFloat>().isExactlyValue(value);
|
||||
}
|
||||
return llvm::all_of(fp_attr.getFloatValues(), [value](const APFloat &f) {
|
||||
return f.isExactlyValue(value);
|
||||
});
|
||||
}
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc"
|
||||
|
||||
// Fuse Add with the preceding FullyConnected.
|
||||
@ -729,6 +744,144 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
|
||||
}
|
||||
};
|
||||
|
||||
// If the operand to a broadcastable op is a splat constant, try to replace it
|
||||
// with a 0-d constant, e.g. before this optimization,
|
||||
// %cst = constant dense<1.0> : tensor<16x16x4xf32>
|
||||
// %0 = "tfl.conv_2d"...
|
||||
// %1 = "tfl.add"(%0, %cst) : (tensor<16x16x4xf32>, tensor<16x16x4xf32>)
|
||||
// After this optimization:
|
||||
// %cst = constant dense<1.0> : tensor<f32>
|
||||
// %0 = "tfl.conv_2d"...
|
||||
// %1 = "tfl.add"(%0, %cst) : (tensor<16x16x4xf32>, tensor<f32>)
|
||||
// This pattern can enable more fusing opportunities when the binary op
// follows conv ops.
|
||||
template <typename BinaryOpType>
|
||||
struct ScalarizeSplatConstantForBroadcastableOps
|
||||
: public OpRewritePattern<BinaryOpType> {
|
||||
using OpRewritePattern<BinaryOpType>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(BinaryOpType binary_op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
DenseElementsAttr splat_elements_attr;
|
||||
if (!IsScalarizableSplatConstant(binary_op.rhs(), &splat_elements_attr)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
constexpr int kSplatOperandIndex = 1;
|
||||
auto result_type =
|
||||
binary_op.getResult().getType().template cast<ShapedType>();
|
||||
mlir::Value non_splat_operand =
|
||||
binary_op.getOperand(1 - kSplatOperandIndex);
|
||||
auto non_splat_operand_type =
|
||||
non_splat_operand.getType().cast<ShapedType>();
// If the other operand's shape does not equal the result shape, then we
// cannot scalarize the splat constant, because the result shape relies on
// the splat constant op's shape for broadcasting.
if (!non_splat_operand_type.hasStaticShape() ||
|
||||
non_splat_operand_type.getShape() != result_type.getShape() ||
|
||||
non_splat_operand_type.getRank() > 4) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// If the non-splat operand is not a fusable affine op, then there is no need
// to apply this transformation.
|
||||
if (!CanFuseAffineOp(non_splat_operand.getDefiningOp(), binary_op)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Creates a new scalar constant op using the splat value.
|
||||
mlir::Value splat_operand = binary_op.getOperand(kSplatOperandIndex);
|
||||
auto scalar_elements_attr = DenseElementsAttr::get(
|
||||
RankedTensorType::get({},
|
||||
splat_elements_attr.getType().getElementType()),
|
||||
splat_elements_attr.getSplatValue());
|
||||
|
||||
auto scalar_constant_op = rewriter.create<ConstantOp>(
|
||||
splat_operand.getLoc(), scalar_elements_attr.getType(),
|
||||
scalar_elements_attr);
|
||||
|
||||
binary_op.setOperand(kSplatOperandIndex, scalar_constant_op);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
// Returns true if this value is a splat constant op which can be scalarized.
|
||||
// Also returns the elements attr if this value is indeed a splat constant.
|
||||
bool IsScalarizableSplatConstant(mlir::Value value,
|
||||
DenseElementsAttr *elements_attr) const {
|
||||
if (!matchPattern(value, m_Constant(elements_attr))) {
|
||||
return false;
|
||||
}
|
||||
auto element_type = value.getType().cast<ShapedType>().getElementType();
// Ignore per-axis quantized constants because, after converting to a scalar,
// we would lose the per-axis quantization parameters.
if (element_type.isa<quant::UniformQuantizedPerAxisType>()) {
|
||||
return false;
|
||||
}
|
||||
if (IsScalar(value)) {
|
||||
return false;
|
||||
}
|
||||
return elements_attr->isSplat();
|
||||
}
|
||||
|
||||
// Returns true if this value has a static shape with exactly one element.
|
||||
bool IsScalar(mlir::Value value) const {
|
||||
auto type = value.getType().dyn_cast<ShapedType>();
|
||||
if (!type) {
|
||||
return false;
|
||||
}
|
||||
if (!type.hasStaticShape()) {
|
||||
return false;
|
||||
}
|
||||
return type.getNumElements() == 1;
|
||||
}
|
||||
|
||||
// Returns true if we can fuse an affine op with consuming binary op.
|
||||
bool CanFuseAffineOp(Operation *affine_op, Operation *binary_op) const {
|
||||
if (!isa_and_nonnull<TFL::Conv2DOp, TFL::DepthwiseConv2DOp,
|
||||
TFL::FullyConnectedOp>(affine_op)) {
|
||||
return false;
|
||||
}
|
||||
DenseElementsAttr value;
|
||||
// Check that the bias is a constant if it is not none.
|
||||
Value bias = affine_op->getOperand(2);
|
||||
if (!bias.getType().isa<NoneType>() &&
|
||||
!matchPattern(bias, m_Constant(&value))) {
|
||||
return false;
|
||||
}
|
||||
// If the binary op is mul/div, also check that the filter is a constant.
|
||||
if (isa<TFL::MulOp, TFL::DivOp>(binary_op) &&
|
||||
!matchPattern(affine_op->getOperand(1), m_Constant(&value))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// We can only fuse F32/BF16.
|
||||
auto is_fusable_type = [](Type t) {
|
||||
Type element_type = t;
|
||||
if (auto shaped_type = t.dyn_cast<ShapedType>()) {
|
||||
element_type = shaped_type.getElementType();
|
||||
}
|
||||
return element_type.isBF16() || element_type.isF32();
|
||||
};
|
||||
for (Type t : binary_op->getOperandTypes()) {
|
||||
if (!is_fusable_type(t)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
using ScalarizeSplatConstantForSub =
|
||||
ScalarizeSplatConstantForBroadcastableOps<TFL::SubOp>;
|
||||
using ScalarizeSplatConstantForAdd =
|
||||
ScalarizeSplatConstantForBroadcastableOps<TFL::AddOp>;
|
||||
using ScalarizeSplatConstantForMul =
|
||||
ScalarizeSplatConstantForBroadcastableOps<TFL::MulOp>;
|
||||
using ScalarizeSplatConstantForDiv =
|
||||
ScalarizeSplatConstantForBroadcastableOps<TFL::DivOp>;
|
||||
|
||||
struct ConvertTrivialTransposeOpToReshapeOp
|
||||
: public OpRewritePattern<TFL::TransposeOp> {
|
||||
using OpRewritePattern<TFL::TransposeOp>::OpRewritePattern;
|
||||
@ -818,6 +971,8 @@ void Optimize::runOnFunction() {
|
||||
OwningRewritePatternList phase_2_patterns;
|
||||
TFL::populateWithGenerated(ctx, phase_2_patterns);
|
||||
phase_2_patterns.insert<
|
||||
ScalarizeSplatConstantForAdd, ScalarizeSplatConstantForSub,
|
||||
ScalarizeSplatConstantForMul, ScalarizeSplatConstantForDiv,
|
||||
FuseFullyConnectedAndAdd, FuseFullyConnectedAndReluX<TFL::ReluOp, kRelu>,
|
||||
FuseFullyConnectedAndReluX<TFL::Relu6Op, kRelu6>,
|
||||
FuseFullyConnectedAndReluX<TFL::Relu1Op, kRelu1>,
|
||||
|
@ -376,13 +376,17 @@ multiclass FuseTileBroadcastIntoFollowingBinary<dag BinaryOp> {
|
||||
(BinaryOp:$result (TFL_TileOp $input, (ConstantOp $tile)),
|
||||
$operand, $act_func),
|
||||
(BinaryOp $input, $operand, $act_func),
|
||||
[(OperandsBroadcastToOutputType $input, $operand, $result)]>;
|
||||
[(OperandsBroadcastToOutputType $input, $operand, $result),
|
||||
(HasRankAtMost<4> $input),
|
||||
(HasRankAtMost<4> $operand)]>;
|
||||
|
||||
def FuseTileBroadcastToBinaryOp2#BinaryOp : Pat<
|
||||
(BinaryOp:$result $operand,
|
||||
(TFL_TileOp $input, (ConstantOp $tile)), $act_func),
|
||||
(BinaryOp $operand, $input, $act_func),
|
||||
[(OperandsBroadcastToOutputType $operand, $input, $result)]>;
|
||||
[(OperandsBroadcastToOutputType $operand, $input, $result),
|
||||
(HasRankAtMost<4> $operand),
|
||||
(HasRankAtMost<4> $input)]>;
|
||||
}
|
||||
|
||||
// Multi-pattern consisting of matching stand-alone op or op followed by relu.
|
||||
@ -427,8 +431,9 @@ foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in {
|
||||
// `input`. In other words, the shape of the `Reshape` op is not
|
||||
// changed after the transformation.
|
||||
(IsTailOfShape $rhs, $input),
|
||||
(HasRankAtMost<5> $input),
|
||||
(HasRankAtMost<5> $rhs)]>;
|
||||
(HasRankAtMost<4> $input),
|
||||
(HasRankAtMost<4> $lhs),
|
||||
(HasRankAtMost<4> $rhs)]>;
|
||||
}
|
||||
|
||||
foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
|
||||
@ -457,7 +462,10 @@ foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
|
||||
// The result of the new "BinaryOp" will have the same shape as
|
||||
// `input`. In other words, the shape of the `Reshape` op are not
|
||||
// changed after the transformation.
|
||||
(IsTailOfShape $rhs, $input)]>;
|
||||
(IsTailOfShape $rhs, $input),
|
||||
(HasRankAtMost<4> $input),
|
||||
(HasRankAtMost<4> $lhs),
|
||||
(HasRankAtMost<4> $rhs)]>;
|
||||
}
|
||||
|
||||
// Reorder the element-wise value operations and the element move operations,
|
||||
@ -495,9 +503,7 @@ def ConvertExpandDimsToReshape : Pat<
|
||||
[(AnyStaticShapeTensor $expand_dims_op)]>;
|
||||
|
||||
class FloatValueEquals<string val> : Constraint<CPred<
|
||||
"$0.isa<DenseFPElementsAttr>() && "
|
||||
"llvm::all_of($0.cast<DenseElementsAttr>().getFloatValues(), "
|
||||
"[](const APFloat& f) { return f.isExactlyValue(" # val # "); })">>;
|
||||
"FloatValueEquals($0, " # val # ")">>;
|
||||
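For reference, the inline `CPred` removed above is roughly equivalent to a small C++ helper. The sketch below is only an illustration of that predicate and assumes a dense floating-point attribute; the actual `FloatValueEquals` helper the generated patterns call may have a different signature and may also handle splat attributes.

```cpp
// Sketch of a C++ predicate equivalent to the removed inline CPred.
// Hypothetical standalone helper; not the exact function used by the pass.
#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/Attributes.h"

namespace {

bool FloatValueEquals(mlir::Attribute attr, double value) {
  auto fp_attr = attr.dyn_cast_or_null<mlir::DenseFPElementsAttr>();
  if (!fp_attr) return false;
  // Every element must be exactly `value`, matching the old lambda
  // `f.isExactlyValue(val)`.
  return llvm::all_of(fp_attr.getFloatValues(),
                      [value](const llvm::APFloat &f) {
                        return f.isExactlyValue(value);
                      });
}

}  // namespace
```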

// ReLU patterns
def MatchReluPattern : Pat<
@ -570,7 +576,10 @@ foreach ActFun = [TFL_AF_Relu, TFL_AF_Relu6, TFL_AF_Relu1, TFL_AF_None] in {
    (TFL_AddOp $input,
      (TFL_AddOp (ConstantOp $a), (ConstantOp $b), TFL_AF_None),
      ActFun),
    [(HasOneUse $first_output)]>;
    [(HasOneUse $first_output),
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $a),
     (HasRankAtMost<4> $b)]>;
}

// We can eliminate Relu from Relu(SquaredDifference(x, y)),

@ -94,6 +94,10 @@ std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass();
// Creates a raise-custom-ops pass, which legalizes custom ops to TFL::CustomOp.
std::unique_ptr<OperationPass<FuncOp>> CreateRaiseCustomOpsPass();

// Inserts a TFL::CallOnce op when the tf_saved_model's session initializer is
// given.
std::unique_ptr<OperationPass<ModuleOp>>
CreateInsertCallOnceOpFromSessionInitializerPass();
}  // namespace TFL

}  // namespace mlir

@ -139,6 +139,30 @@ struct RemoveVolatileOps : public OpRewritePattern<DequantizeOp> {
  }
};

// Removes LSTMs that have dangling outputs.
// LSTMs are not removed automatically because they are stateful ops.
template <typename LstmOpTy>
struct PruneUnusedLstm : public OpRewritePattern<LstmOpTy> {
 public:
  explicit PruneUnusedLstm(MLIRContext* context)
      : OpRewritePattern<LstmOpTy>(context) {}

  LogicalResult matchAndRewrite(LstmOpTy lstm_op,
                                PatternRewriter& rewriter) const override {
    Operation* op = lstm_op.getOperation();
    if (op->isKnownTerminator()) {
      return failure();
    }
    for (auto result : op->getOpResults()) {
      if (!result.use_empty()) {
        return failure();
      }
    }
    rewriter.eraseOp(op);
    return success();
  }
};

#include "tensorflow/compiler/mlir/lite/transforms/generated_post_quantize.inc"

void PostQuantizePass::runOnFunction() {
@ -147,6 +171,7 @@ void PostQuantizePass::runOnFunction() {
  auto* ctx = func.getContext();
  TFL::populateWithGenerated(ctx, patterns);
  patterns.insert<quant::FoldTrivalRequantizeOp<QuantizeOp>>(ctx);
  patterns.insert<PruneUnusedLstm<TFL::UnidirectionalSequenceLSTMOp>>(ctx);
  applyPatternsAndFoldGreedily(func, std::move(patterns));

  if (!emit_quant_adaptor_ops_) {

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/

// This transformation pass applies quantization propagation on TFLite dialect.
#include <cmath>
#include <iterator>
#include <string>

@ -21,10 +22,13 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/IR/Function.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
@ -305,6 +309,52 @@ bool PrepareQuantizePass::ContainsQuantizeOps(FuncOp func) {
using PrepareQuantStats =
    quant::ConvertStatsToQDQs<quant::QuantizeCastOp, quant::DequantizeCastOp>;

// Calculates the minimum power of two that is not less than the value.
double power_of_two_bound(double value) {
  return std::pow(2, std::ceil(std::log2(value)));
}

// Quantizes the recurrent input of an LSTM with 16 bits.
template <typename SourceOp, typename Q, typename DQ>
struct ConvertLstmStatsToQDQs : public OpRewritePattern<SourceOp> {
 public:
  explicit ConvertLstmStatsToQDQs(MLIRContext* context)
      : OpRewritePattern<SourceOp>(context, /*benefit=*/2) {}
  LogicalResult matchAndRewrite(SourceOp op,
                                PatternRewriter& rewriter) const override {
    quant::StatisticsOp stats_op = llvm::dyn_cast_or_null<quant::StatisticsOp>(
        op.input_cell_state().getDefiningOp());
    // The recurrent input is used within an LSTM, and thus should have one use.
    if (!stats_op || !stats_op.getResult().hasOneUse()) {
      return failure();
    }
    auto stats = stats_op.layerStats().dyn_cast<DenseFPElementsAttr>();
    if (!stats) {
      return failure();
    }

    double max = std::max(
        std::abs(FloatAttr::getValueAsDouble(stats.getValue<APFloat>({0}))),
        std::abs(FloatAttr::getValueAsDouble(stats.getValue<APFloat>({1}))));
    double bound = power_of_two_bound(max);
    Type expressed = stats_op.getType().cast<ShapedType>().getElementType();
    // The maximum value is adjusted to get a scale of power_of_two(max)/32768.
    quant::QuantizedType quant_type = quant::fakeQuantAttrsToType(
        stats_op.getLoc(), 16, -bound, bound * 32767.0 / 32768.0,
        /*narrow_range*/ false, expressed, /*is_signed*/ true);

    rewriter.setInsertionPointAfter(stats_op);
    Type result_type = quant_type.castFromExpressedType(stats_op.getType());
    auto q = rewriter.create<Q>(stats_op.getLoc(), result_type, stats_op.arg());
    rewriter.replaceOpWithNewOp<DQ>(stats_op, stats_op.getType(), q);
    return success();
  }
};

using PrepareLstmQuantStats =
    ConvertLstmStatsToQDQs<TFL::UnidirectionalSequenceLSTMOp,
                           quant::QuantizeCastOp, quant::DequantizeCastOp>;
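To make the range/scale derivation above concrete, here is a small standalone sketch. The calibration numbers are made up for illustration, and the pass itself goes through `quant::fakeQuantAttrsToType` rather than computing the scale directly.

```cpp
// Standalone illustration of the 16-bit range/scale derivation used above.
// The calibration statistics below are hypothetical.
#include <algorithm>
#include <cmath>
#include <cstdio>

// Smallest power of two that is not less than `value` (mirrors
// power_of_two_bound above).
double PowerOfTwoBound(double value) {
  return std::pow(2, std::ceil(std::log2(value)));
}

int main() {
  // Hypothetical calibration statistics for the LSTM cell state.
  const double stats_min = -3.1;
  const double stats_max = 2.4;

  const double max_abs = std::max(std::abs(stats_min), std::abs(stats_max));
  const double bound = PowerOfTwoBound(max_abs);  // 4.0 for this example.

  // Range handed to the quantizer: [-bound, bound * 32767/32768]. For a
  // signed 16-bit type this yields a scale of exactly bound / 32768.
  const double qmin = -bound;
  const double qmax = bound * 32767.0 / 32768.0;
  const double scale = (qmax - qmin) / 65535.0;  // == bound / 32768

  std::printf("bound=%g range=[%g, %g] scale=%g\n", bound, qmin, qmax, scale);
  return 0;
}
```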

void PrepareQuantizePass::runOnFunction() {
  FuncOp func = getFunction();
  MLIRContext* ctx = func.getContext();
@ -326,7 +376,14 @@ void PrepareQuantizePass::runOnFunction() {
  OwningRewritePatternList patterns;
  bool is_signed = quant_specs_.IsSignedInferenceType();
  int bit_width = quant_specs_.GetQuantizationTypeWidth();
  bool enforce_fixed_output_range = ContainsQuantizeOps(func);
  bool quantization_aware_training_mode = ContainsQuantizeOps(func);
  // Enforce fixed output range for post-training quantization and
  // when the model has quantization emulation ops, unless it was disabled
  // explicitly by the flag.
  bool enforced_output_range =
      (quant_specs_.post_training_quantization ||
       quantization_aware_training_mode) &&
      !quant_specs_.disable_enforced_fixed_output_range;
  if (is_signed) {
    patterns.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(ctx);
    // Convert quant stats to int8 quantization parameters.
@ -337,6 +394,7 @@ void PrepareQuantizePass::runOnFunction() {
    // Currently, only activation stats are imported, so narrow_range = false.
    patterns.insert<PrepareQuantStats>(bit_width, false, false, ctx);
  }
  patterns.insert<PrepareLstmQuantStats>(ctx);
  applyPatternsAndFoldGreedily(func, std::move(patterns));

  SanityCheckAndAdjustment(func);
@ -345,8 +403,7 @@ void PrepareQuantizePass::runOnFunction() {
  // values (tensors).
  ApplyQuantizationParamsPropagation(
      func, is_signed, disable_per_channel || quant_specs_.disable_per_channel,
      GetOpQuantSpec,
      enforce_fixed_output_range || quant_specs_.post_training_quantization);
      GetOpQuantSpec, enforced_output_range);

  ConvertMlirQuantOpsToTFLQuantOps(func);
}

@ -64,6 +64,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"

#define DEBUG_TYPE "tf-tfl-legalization"
@ -518,9 +519,10 @@ struct ConvertTFStridedSlice : public RewritePattern {
  explicit ConvertTFStridedSlice(MLIRContext *context)
      : RewritePattern(TF::StridedSliceOp::getOperationName(), 2, context) {}

  LogicalResult RewriteNewAxisMask(Operation *op, uint64_t new_axis_mask,
  LogicalResult RewriteNewAxisMask(Operation *op,
                                   PatternRewriter &rewriter) const {
    TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
    uint64_t new_axis_mask = strided_slice_op.new_axis_mask();

    // Insert a new reshape op.
    Value original_input = strided_slice_op.input();
@ -528,48 +530,51 @@ struct ConvertTFStridedSlice : public RewritePattern {
        original_input.getType().cast<RankedTensorType>();
    const ArrayRef<int64_t> &original_input_shape =
        original_input_type.getShape();
    SmallVector<int64_t, 4> new_shape;
    SmallVector<int64_t, 4> revised_shape;
    int index = 0;
    const int original_input_rank = original_input_shape.size();
    while (index < original_input_rank || new_axis_mask) {
      if (new_axis_mask & 1) {
        new_shape.emplace_back(1);
        revised_shape.emplace_back(1);
      } else {
        new_shape.emplace_back(original_input_shape[index++]);
        revised_shape.emplace_back(original_input_shape[index++]);
      }
      new_axis_mask >>= 1;
    }

    const int dim_size = new_shape.size();
    if (failed(TF::VerifyShapeOfReshapeOp(revised_shape))) return failure();

    const int dim_size = revised_shape.size();
    Location loc = strided_slice_op.getLoc();
    auto shape_type =
        RankedTensorType::get({dim_size}, rewriter.getIntegerType(32));
    SmallVector<Attribute, 4> result_shape_data(dim_size);
    for (int i = 0; i < dim_size; ++i) {
      result_shape_data[i] =
          rewriter.getI32IntegerAttr(static_cast<int32_t>(new_shape[i]));
          rewriter.getI32IntegerAttr(static_cast<int32_t>(revised_shape[i]));
    }

    auto shape_attr = DenseElementsAttr::get(shape_type, result_shape_data);
    auto shape = rewriter.create<ConstantOp>(loc, shape_type, shape_attr);
    auto new_output_type =
        RankedTensorType::get(new_shape, original_input_type.getElementType());
    auto revised_output_type = RankedTensorType::get(
        revised_shape, original_input_type.getElementType());
    TF::ReshapeOp reshape = rewriter.create<TF::ReshapeOp>(
        loc, new_output_type, original_input, shape);
        loc, revised_output_type, original_input, shape);

    // Replace the original strided_slice.
    uint64_t new_begin_mask = strided_slice_op.begin_mask();
    uint64_t new_end_mask = strided_slice_op.end_mask();
    uint64_t revised_begin_mask = strided_slice_op.begin_mask();
    uint64_t revised_end_mask = strided_slice_op.end_mask();
    // Since we expand the dims, we need to apply them to the begin_mask &
    // end_mask.
    new_begin_mask |= strided_slice_op.new_axis_mask();
    new_end_mask |= strided_slice_op.new_axis_mask();
    revised_begin_mask |= strided_slice_op.new_axis_mask();
    revised_end_mask |= strided_slice_op.new_axis_mask();

    auto attribute_type = rewriter.getIntegerType(64);
    rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
        op, strided_slice_op.getType(), reshape, strided_slice_op.begin(),
        strided_slice_op.end(), strided_slice_op.strides(),
        rewriter.getIntegerAttr(attribute_type, new_begin_mask),
        rewriter.getIntegerAttr(attribute_type, new_end_mask),
        rewriter.getIntegerAttr(attribute_type, revised_begin_mask),
        rewriter.getIntegerAttr(attribute_type, revised_end_mask),
        rewriter.getIntegerAttr(attribute_type,
                                strided_slice_op.ellipsis_mask()),
        rewriter.getI64IntegerAttr(0),
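A quick standalone illustration of the shape expansion performed by the loop above; the input shape and mask value are hypothetical examples, not taken from the pass.

```cpp
// Standalone sketch of how new_axis_mask expands a shape before the reshape
// is inserted. The input shape and mask below are made-up examples.
#include <cstdint>
#include <cstdio>
#include <vector>

std::vector<int64_t> ExpandShapeForNewAxisMask(std::vector<int64_t> shape,
                                               uint64_t new_axis_mask) {
  std::vector<int64_t> revised_shape;
  size_t index = 0;
  // Same walk as the pattern above: a set mask bit inserts a 1-sized dim,
  // otherwise the next original dimension is copied. Assumes mask bits only
  // appear at valid insertion positions, as the op guarantees.
  while (index < shape.size() || new_axis_mask) {
    if (new_axis_mask & 1) {
      revised_shape.push_back(1);
    } else {
      revised_shape.push_back(shape[index++]);
    }
    new_axis_mask >>= 1;
  }
  return revised_shape;
}

int main() {
  // A strided slice on a [2, 3] tensor with new_axis_mask = 0b010 inserts a
  // new axis at position 1, so the reshape target becomes [2, 1, 3].
  for (int64_t d : ExpandShapeForNewAxisMask({2, 3}, /*new_axis_mask=*/0b010)) {
    std::printf("%lld ", static_cast<long long>(d));
  }
  std::printf("\n");  // prints: 2 1 3
  return 0;
}
```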
@ -578,10 +583,16 @@ struct ConvertTFStridedSlice : public RewritePattern {
    return success();
  }

  LogicalResult RewriteEllipsisMask(Operation *op, uint64_t ellipsis_mask,
  LogicalResult RewriteEllipsisMask(Operation *op,
                                    PatternRewriter &rewriter) const {
    TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);

    uint64_t ellipsis_mask = strided_slice_op.ellipsis_mask();
    uint64_t shrink_axis_mask = strided_slice_op.shrink_axis_mask();

    // Enforce operator precedence.
    shrink_axis_mask &= ~ellipsis_mask;

    DenseIntElementsAttr begin_dense_elem_attr;
    Value begin = strided_slice_op.begin();
    auto begin_ranked_attr_type = begin.getType().dyn_cast<RankedTensorType>();
@ -623,8 +634,9 @@ struct ConvertTFStridedSlice : public RewritePattern {

    int64_t begin_mask = strided_slice_op.begin_mask();
    int64_t end_mask = strided_slice_op.end_mask();
    int64_t new_begin_mask = 0;
    int64_t new_end_mask = 0;
    int64_t revised_begin_mask = 0;
    int64_t revised_end_mask = 0;
    int64_t revised_shrink_axis_mask = 0;

    SmallVector<int32_t, 4> padded_begin;
    SmallVector<int32_t, 4> padded_end;
@ -637,16 +649,18 @@ struct ConvertTFStridedSlice : public RewritePattern {
      padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(index));
      padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(index));
      padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(index));
      if ((begin_mask >> index) & 1) new_begin_mask |= (1 << new_index);
      if ((end_mask >> index) & 1) new_end_mask |= (1 << new_index);
      if ((begin_mask >> index) & 1) revised_begin_mask |= (1 << new_index);
      if ((end_mask >> index) & 1) revised_end_mask |= (1 << new_index);
      if ((shrink_axis_mask >> index) & 1)
        revised_shrink_axis_mask |= (1 << new_index);
      ++index;
      ++new_index;
    }

    // Ellipsis.
    for (; new_index < index + ellipsis_filled_dim_size; ++new_index) {
      new_begin_mask |= (1 << new_index);
      new_end_mask |= (1 << new_index);
      revised_begin_mask |= (1 << new_index);
      revised_end_mask |= (1 << new_index);

      // Mimic the begin/end/strides mask behavior.
      padded_begin.push_back(0);
@ -663,8 +677,10 @@ struct ConvertTFStridedSlice : public RewritePattern {
      padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(index));
      padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(index));

      if ((begin_mask >> index) & 1) new_begin_mask |= (1 << new_index);
      if ((end_mask >> index) & 1) new_end_mask |= (1 << new_index);
      if ((begin_mask >> index) & 1) revised_begin_mask |= (1 << new_index);
      if ((end_mask >> index) & 1) revised_end_mask |= (1 << new_index);
      if ((shrink_axis_mask >> index) & 1)
        revised_shrink_axis_mask |= (1 << new_index);

      ++index;
      ++new_index;
@ -687,13 +703,12 @@ struct ConvertTFStridedSlice : public RewritePattern {
    rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
        op, strided_slice_op.getType(), input, begin_op.getResult(),
        end_op.getResult(), stride_op.getResult(),
        rewriter.getIntegerAttr(attribute_type, new_begin_mask),
        rewriter.getIntegerAttr(attribute_type, new_end_mask),
        /*ellipsis_maks=*/rewriter.getI64IntegerAttr(0),
        rewriter.getIntegerAttr(attribute_type, revised_begin_mask),
        rewriter.getIntegerAttr(attribute_type, revised_end_mask),
        /*ellipsis_mask=*/rewriter.getI64IntegerAttr(0),
        rewriter.getIntegerAttr(attribute_type,
                                strided_slice_op.new_axis_mask()),
        rewriter.getIntegerAttr(attribute_type,
                                strided_slice_op.shrink_axis_mask()));
        rewriter.getIntegerAttr(attribute_type, revised_shrink_axis_mask));
    return success();
  }

@ -701,20 +716,18 @@ struct ConvertTFStridedSlice : public RewritePattern {
                                PatternRewriter &rewriter) const override {
    TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);

    // TODO(renjieliu): Consider expand the transformation for shrink mask as
    // well.
    if (strided_slice_op.shrink_axis_mask()) return failure();

    // Handle new axis mask.
    uint64_t new_axis_mask = strided_slice_op.new_axis_mask();
    if (new_axis_mask != 0) {
      return RewriteNewAxisMask(strided_slice_op, new_axis_mask, rewriter);
    if (strided_slice_op.new_axis_mask() != 0) {
      // We currently don't handle simultaneous shrink_ and new_axis masks.
      if (strided_slice_op.shrink_axis_mask()) {
        return failure();
      }
      return RewriteNewAxisMask(strided_slice_op, rewriter);
    }

    // Handle ellipsis mask.
    uint64_t ellipsis_mask = strided_slice_op.ellipsis_mask();
    if (ellipsis_mask != 0) {
      return RewriteEllipsisMask(strided_slice_op, ellipsis_mask, rewriter);
    if (strided_slice_op.ellipsis_mask() != 0) {
      return RewriteEllipsisMask(strided_slice_op, rewriter);
    }
    return failure();
  }

@ -182,7 +182,7 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
    b.create<ReturnOp>(yield_op->getLoc(), args);
    yield_op->erase();
    symbol_table.insert(outlined_func);
    outlined_func.setVisibility(FuncOp::Visibility::Private);
    outlined_func.setPrivate();
    return outlined_func;
  };

@ -57,6 +57,8 @@ mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) {
      return mlir::ComplexType::get(builder.getF64Type());
    case tflite::TensorType_INT8:
      return builder.getIntegerType(8);
    case tflite::TensorType_UINT64:
      return builder.getIntegerType(64, /*isSigned=*/false);
  }
}

@ -86,6 +88,8 @@ tensorflow::DataType TflTypeToTfType(tflite::TensorType type) {
      return tensorflow::DT_STRING;
    case tflite::TensorType_UINT8:
      return tensorflow::DT_UINT8;
    case tflite::TensorType_UINT64:
      return tensorflow::DT_UINT64;
  }
}

@ -51,7 +51,7 @@ static ConfigProto::Experimental::MlirBridgeRollout GetUserRequest(
}

MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy(
    absl::optional<ConfigProto> config_proto) {
    const tensorflow::Graph& graph, absl::optional<ConfigProto> config_proto) {
  switch (GetUserRequest(config_proto)) {
    case ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED:
      return MlirBridgeRolloutPolicy::kEnabledByUser;

@ -17,6 +17,7 @@ limitations under the License.
#define THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_MLIR_BRIDGE_ROLLOUT_POLICY_H_

#include "absl/types/optional.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/protobuf/config.pb.h"

namespace tensorflow {
@ -46,6 +47,7 @@ enum class MlirBridgeRolloutPolicy {
// The config_proto param is a required input for all TF1 graphs but it is
// redundant for TF2 graphs.
MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy(
    const tensorflow::Graph& graph,
    absl::optional<tensorflow::ConfigProto> config_proto);

}  // namespace tensorflow

@ -15,6 +15,7 @@ limitations under the License.

#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"

#include <memory>
#include <string>

#include "absl/container/flat_hash_set.h"
@ -32,10 +33,20 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

auto* shadow_run_success =
    monitoring::Counter<0>::New("/tensorflow/core/mlir_shadow_run_success",
                                "Success count of MLIR shadow runs");

auto* shadow_run_failure = monitoring::Counter<2>::New(
    "/tensorflow/core/mlir_shadow_run_failure",
    "Failure count of MLIR shadow runs", "kind", "name");

static inline absl::string_view StringRefToView(llvm::StringRef ref) {
  return {ref.data(), ref.size()};
}
@ -109,7 +120,7 @@ Status MlirFunctionOptimizationPass::Run(
  // Skip conversion from Graph to MLIR if none of the passes are enabled.
  const bool is_enabled =
      llvm::any_of(registry_->passes(), [&](auto& pass_registration) -> bool {
        return pass_registration.pass->IsEnabled(config_proto);
        return pass_registration.pass->IsEnabled(config_proto, **graph);
      });

  if (!is_enabled) {
@ -123,6 +134,17 @@ Status MlirFunctionOptimizationPass::Run(
      << "(registered " << registry_->passes().size()
      << " passes)";

  // For scenarios when the new bridge is enabled by analysis, we need to make
  // sure that MLIR transformations are executed in a shadow mode.
  // In this case, no changes should be done to the original `graph`
  // and no failures propagated to the user.
  bool enabled_by_analysis =
      mlir_rollout_policy_(**graph, config_proto) ==
      MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysis;
  if (enabled_by_analysis) {
    LOG_FIRST_N(INFO, 1) << "Shadow run of MLIR enabled after graph analysis";
  }

  GraphDebugInfo debug_info;
  mlir::MLIRContext context;
  RegisterDialects(context.getDialectRegistry());
@ -130,10 +152,21 @@ Status MlirFunctionOptimizationPass::Run(
  import_config.graph_as_function = true;
  import_config.control_outputs = *control_ret_node_names;
  import_config.upgrade_legacy = true;
  TF_ASSIGN_OR_RETURN(auto module_ref,
                      ConvertGraphToMlir(**graph, debug_info, *flib_def,
                                         import_config, &context));

  auto module_ref_status = ConvertGraphToMlir(**graph, debug_info, *flib_def,
                                              import_config, &context);
  if (!module_ref_status.ok()) {
    if (enabled_by_analysis) {
      shadow_run_failure->GetCell("graph_to_mlir", "")->IncrementBy(1);

      // Do not fail; let the old bridge run on the original `graph`.
      return Status::OK();
    }

    return module_ref_status.status();
  }

  auto module_ref = std::move(module_ref_status.ValueOrDie());
  AddDevicesToOp(*module_ref, &device_set);

  for (auto& pass_registration : registry_->passes()) {
@ -144,7 +177,17 @@ Status MlirFunctionOptimizationPass::Run(
      DumpModule(*module_ref, llvm::formatv("mlir_{0}_before_", name));
    }

    TF_RETURN_IF_ERROR(pass_registration.pass->Run(config_proto, *module_ref));
    auto pass_status =
        pass_registration.pass->Run(config_proto, *module_ref, **graph);
    if (!pass_status.ok()) {
      if (enabled_by_analysis) {
        shadow_run_failure->GetCell("pass", name.str())->IncrementBy(1);
        // Do not fail; let the old bridge run on the original `graph`.
        return Status::OK();
      }

      return pass_status;
    }

    if (VLOG_IS_ON(1)) {
      DumpModule(*module_ref, llvm::formatv("mlir_{0}_after_", name));
@ -153,6 +196,25 @@ Status MlirFunctionOptimizationPass::Run(

  GraphExportConfig export_config;
  absl::flat_hash_set<Node*> control_ret_nodes;

  // In case MLIR is enabled by analysis, verify that MLIR could be converted
  // back to a TF graph. The original `graph` must stay the same.
  if (enabled_by_analysis) {
    auto empty_graph = std::make_unique<Graph>(OpRegistry::Global());
    FunctionLibraryDefinition empty_flib = empty_graph->flib_def();

    auto mlir_to_graph_status =
        ConvertMlirToGraph(*module_ref, export_config, &empty_graph,
                           &empty_flib, &control_ret_nodes);
    if (mlir_to_graph_status.ok()) {
      shadow_run_success->GetCell()->IncrementBy(1);
    } else {
      shadow_run_failure->GetCell("mlir_to_graph", "")->IncrementBy(1);
    }

    return Status::OK();
  }

  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      ConvertMlirToGraph(*module_ref, export_config, graph, flib_def,
                         &control_ret_nodes),
@ -183,7 +245,7 @@ Status MlirV1CompatGraphOptimizationPass::Run(
  const bool is_enabled =
      absl::c_any_of(registry_->passes(), [&](auto& pass_registration) -> bool {
        return pass_registration.pass->IsEnabled(
            options.session_options->config);
            options.session_options->config, **options.graph);
      });

  if (!is_enabled) {