Merge branch 'master' into dlpack
This commit is contained in:
commit
0dad831803
10
.bazelrc
10
.bazelrc
@ -319,11 +319,13 @@ build:xla --define=with_xla_support=true
|
||||
# BEGIN TF REMOTE BUILD EXECUTION OPTIONS
|
||||
# Options when using remote execution
|
||||
# WARNING: THESE OPTIONS WONT WORK IF YOU DO NOT HAVE PROPER AUTHENTICATION AND PERMISSIONS
|
||||
|
||||
# Flag to enable remote config
|
||||
common --experimental_repo_remote_exec
|
||||
|
||||
build:rbe --action_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1
|
||||
build:rbe --auth_enabled=true
|
||||
build:rbe --auth_scope=https://www.googleapis.com/auth/cloud-source-tools
|
||||
build:rbe --google_default_credentials
|
||||
build:rbe --bes_backend=buildeventservice.googleapis.com
|
||||
build:rbe --bes_best_effort=false
|
||||
build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations"
|
||||
build:rbe --bes_timeout=600s
|
||||
build:rbe --define=EXECUTOR=remote
|
||||
@ -336,7 +338,7 @@ build:rbe --spawn_strategy=remote,worker,standalone,local
|
||||
test:rbe --test_env=USER=anon
|
||||
# Attempt to minimize the amount of data transfer between bazel and the remote
|
||||
# workers:
|
||||
build:rbe --experimental_inmemory_jdeps_files --experimental_inmemory_dotd_files --experimental_remote_download_outputs=toplevel
|
||||
build:rbe --remote_download_toplevel
|
||||
|
||||
build:rbe_linux --config=rbe
|
||||
build:rbe_linux --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin"
|
||||
|
26
.github/ISSUE_TEMPLATE/00-bug-issue.md
vendored
26
.github/ISSUE_TEMPLATE/00-bug-issue.md
vendored
@ -10,13 +10,20 @@ labels: 'type:bug'
|
||||
we only address code/doc bugs, performance issues, feature requests and
|
||||
build/installation issues on GitHub. tag:bug_template</em>
|
||||
|
||||
**System information** - Have I written custom code (as opposed to using a stock
|
||||
example script provided in TensorFlow): - OS Platform and Distribution (e.g.,
|
||||
Linux Ubuntu 16.04): - Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
|
||||
the issue happens on mobile device: - TensorFlow installed from (source or
|
||||
binary): - TensorFlow version (use command below): - Python version: - Bazel
|
||||
version (if compiling from source): - GCC/Compiler version (if compiling from
|
||||
source): - CUDA/cuDNN version: - GPU model and memory:
|
||||
**System information**
|
||||
- Have I written custom code (as opposed to using a stock
|
||||
example script provided in TensorFlow):
|
||||
- OS Platform and Distribution (e.g.,
|
||||
Linux Ubuntu 16.04):
|
||||
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
|
||||
the issue happens on mobile device:
|
||||
- TensorFlow installed from (source or
|
||||
binary): - TensorFlow version (use command below):
|
||||
- Python version: - Bazel
|
||||
version (if compiling from source):
|
||||
- GCC/Compiler version (if compiling from
|
||||
source):
|
||||
- CUDA/cuDNN version: - GPU model and memory:
|
||||
|
||||
You can collect some of this information using our environment capture
|
||||
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
|
||||
@ -28,8 +35,9 @@ tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
|
||||
|
||||
**Describe the expected behavior**
|
||||
|
||||
**Code to reproduce the issue** Provide a reproducible test case that is the
|
||||
bare minimum necessary to generate the problem.
|
||||
**Standalone code to reproduce the issue**
|
||||
Provide a reproducible test case that is the bare minimum necessary to generate
|
||||
the problem. If possible, please share a link to Colab/Jupyter/any notebook.
|
||||
|
||||
**Other info / logs** Include any logs or source code that would be helpful to
|
||||
diagnose the problem. If including tracebacks, please include the full
|
||||
|
@ -17,8 +17,14 @@ labels: 'comp:lite'
|
||||
# Copy and paste here
|
||||
```
|
||||
|
||||
**Standalone code to reproduce the issue**
|
||||
Provide a reproducible test case that is the bare minimum necessary to generate
|
||||
the problem. If possible, please share a link to Colab/Jupyter/any notebook.
|
||||
|
||||
Also, please include a link to a GraphDef or the model if possible.
|
||||
|
||||
**Any other info / logs**
|
||||
|
||||
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.
|
||||
Include any logs or source code that would be helpful to diagnose the problem.
|
||||
If including tracebacks, please include the full traceback. Large logs and files
|
||||
should be attached.
|
||||
|
@ -1,6 +1,7 @@
|
||||
---
|
||||
name: TensorFlow Lite New Converter Issue
|
||||
about: Use this template for reporting issues during model conversion to TFLite
|
||||
labels: 'TFLiteConverter'
|
||||
|
||||
---
|
||||
|
||||
@ -12,6 +13,7 @@ about: Use this template for reporting issues during model conversion to TFLite
|
||||
|
||||
|
||||
**Command used to run the converter or code if you’re using the Python API**
|
||||
If possible, please share a link to Colab/Jupyter/any notebook.
|
||||
|
||||
```
|
||||
# Copy and paste here the exact command
|
||||
|
26
.github/ISSUE_TEMPLATE/80-performance-issue.md
vendored
26
.github/ISSUE_TEMPLATE/80-performance-issue.md
vendored
@ -11,13 +11,20 @@ As per our
|
||||
we only address code/doc bugs, performance issues, feature requests and
|
||||
build/installation issues on GitHub. tag:performance_template</em>
|
||||
|
||||
**System information** - Have I written custom code (as opposed to using a stock
|
||||
example script provided in TensorFlow): - OS Platform and Distribution (e.g.,
|
||||
Linux Ubuntu 16.04): - Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
|
||||
the issue happens on mobile device: - TensorFlow installed from (source or
|
||||
binary): - TensorFlow version (use command below): - Python version: - Bazel
|
||||
version (if compiling from source): - GCC/Compiler version (if compiling from
|
||||
source): - CUDA/cuDNN version: - GPU model and memory:
|
||||
**System information**
|
||||
- Have I written custom code (as opposed to using a stock
|
||||
example script provided in TensorFlow):
|
||||
- OS Platform and Distribution (e.g.,
|
||||
Linux Ubuntu 16.04):
|
||||
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
|
||||
the issue happens on mobile device:
|
||||
- TensorFlow installed from (source or
|
||||
binary): - TensorFlow version (use command below):
|
||||
- Python version: - Bazel
|
||||
version (if compiling from source):
|
||||
- GCC/Compiler version (if compiling from
|
||||
source):
|
||||
- CUDA/cuDNN version: - GPU model and memory:
|
||||
|
||||
You can collect some of this information using our environment capture
|
||||
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
|
||||
@ -29,8 +36,9 @@ tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
|
||||
|
||||
**Describe the expected behavior**
|
||||
|
||||
**Code to reproduce the issue** Provide a reproducible test case that is the
|
||||
bare minimum necessary to generate the problem.
|
||||
**Standalone code to reproduce the issue**
|
||||
Provide a reproducible test case that is the bare minimum necessary to generate
|
||||
the problem. If possible, please share a link to Colab/Jupyter/any notebook.
|
||||
|
||||
**Other info / logs** Include any logs or source code that would be helpful to
|
||||
diagnose the problem. If including tracebacks, please include the full
|
||||
|
25
WORKSPACE
25
WORKSPACE
@ -113,3 +113,28 @@ http_archive(
|
||||
"https://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip",
|
||||
],
|
||||
)
|
||||
|
||||
# Required for dependency @com_github_grpc_grpc
|
||||
|
||||
load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
|
||||
|
||||
grpc_deps()
|
||||
|
||||
load(
|
||||
"@build_bazel_rules_apple//apple:repositories.bzl",
|
||||
"apple_rules_dependencies",
|
||||
)
|
||||
|
||||
apple_rules_dependencies()
|
||||
|
||||
load(
|
||||
"@build_bazel_apple_support//lib:repositories.bzl",
|
||||
"apple_support_dependencies",
|
||||
)
|
||||
|
||||
apple_support_dependencies()
|
||||
|
||||
load("@upb//bazel:repository_defs.bzl", "bazel_version_repository")
|
||||
|
||||
bazel_version_repository(name = "bazel_version")
|
||||
|
||||
|
@ -49,7 +49,7 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
|
||||
_TF_WORKSPACE_ROOT = ''
|
||||
_TF_BAZELRC = ''
|
||||
_TF_CURRENT_BAZEL_VERSION = None
|
||||
_TF_MIN_BAZEL_VERSION = '1.2.1'
|
||||
_TF_MIN_BAZEL_VERSION = '2.0.0'
|
||||
_TF_MAX_BAZEL_VERSION = '2.0.0'
|
||||
|
||||
NCCL_LIB_PATHS = [
|
||||
|
@ -505,13 +505,15 @@ selects.config_setting_group(
|
||||
package_group(
|
||||
name = "internal",
|
||||
packages = [
|
||||
# To pass open source testing in the pip Kokoros.
|
||||
"//bazel_pip/tensorflow/...",
|
||||
"//learning/brain/swift/x10/...",
|
||||
"//perftools/accelerators/xprof/api/...",
|
||||
"//third_party/py/autograph/...",
|
||||
"//third_party/swift/tensorflow/x10/...",
|
||||
"//tensorflow/...",
|
||||
"//tensorflow_estimator/python/estimator/...",
|
||||
"//tensorflow_models/official/...",
|
||||
"//third_party/py/autograph/...",
|
||||
"//third_party/swift/tensorflow/x10/...",
|
||||
],
|
||||
)
|
||||
|
||||
@ -545,8 +547,8 @@ cc_library(
|
||||
name = "grpc",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
":linux_s390x": ["@grpc//:grpc_unsecure"],
|
||||
"//conditions:default": ["@grpc"],
|
||||
":linux_s390x": ["@com_github_grpc_grpc//:grpc_unsecure"],
|
||||
"//conditions:default": ["@com_github_grpc_grpc//:grpc"],
|
||||
}),
|
||||
)
|
||||
|
||||
@ -554,8 +556,8 @@ cc_library(
|
||||
name = "grpc++",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
":linux_s390x": ["@grpc//:grpc++_unsecure"],
|
||||
"//conditions:default": ["@grpc//:grpc++"],
|
||||
":linux_s390x": ["@com_github_grpc_grpc//:grpc++_unsecure"],
|
||||
"//conditions:default": ["@com_github_grpc_grpc//:grpc++"],
|
||||
}),
|
||||
)
|
||||
|
||||
|
@ -31,6 +31,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/platform/net.h"
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
@ -816,12 +817,15 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
|
||||
|
||||
const int num_inputs = input_shapes->num_items;
|
||||
NodeDef node_def;
|
||||
node_def.set_name(tfe_op->operation.Name());
|
||||
node_def.set_op(tfe_op->operation.Name());
|
||||
node_def.set_name(tfe_op->operation->Name());
|
||||
node_def.set_op(tfe_op->operation->Name());
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
node_def.add_input("dummy_input");
|
||||
}
|
||||
tfe_op->operation.Attrs().FillAttrValueMap(node_def.mutable_attr());
|
||||
tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||
tfe_op->operation.get())
|
||||
->Attrs()
|
||||
.FillAttrValueMap(node_def.mutable_attr());
|
||||
|
||||
const tensorflow::OpRegistrationData* op_reg_data;
|
||||
status->status =
|
||||
|
@ -28,6 +28,8 @@ tf_cuda_library(
|
||||
"c_api_debug.cc",
|
||||
"c_api_experimental.h",
|
||||
"c_api_internal.h",
|
||||
"operation_interface.cc",
|
||||
"operation_interface.h",
|
||||
"tensor_handle_interface.h",
|
||||
],
|
||||
hdrs = ["c_api.h"],
|
||||
@ -56,6 +58,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/platform:casts",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
@ -93,6 +96,7 @@ filegroup(
|
||||
"c_api_experimental.h",
|
||||
"c_api_internal.h",
|
||||
"dlpack.h",
|
||||
"operation_interface.h",
|
||||
"tensor_handle_interface.h",
|
||||
],
|
||||
visibility = [
|
||||
@ -105,6 +109,7 @@ tf_cuda_library(
|
||||
name = "c_api_internal",
|
||||
srcs = [
|
||||
"c_api_experimental.h",
|
||||
"operation_interface.h",
|
||||
"tensor_handle_interface.h",
|
||||
],
|
||||
hdrs = ["c_api_internal.h"],
|
||||
@ -129,6 +134,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core/common_runtime/eager:eager_operation",
|
||||
"//tensorflow/core/common_runtime/eager:kernel_and_device",
|
||||
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||
"@com_google_absl//absl/container:fixed_array",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -27,7 +27,6 @@ limitations under the License.
|
||||
// clang-format on
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/container/fixed_array.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
@ -95,14 +94,6 @@ using tensorflow::string;
|
||||
|
||||
namespace {
|
||||
|
||||
const tensorflow::OpDef* GetOpDef(TFE_Op* op, TF_Status* status) {
|
||||
const tensorflow::OpDef* op_def = op->operation.OpDef();
|
||||
if (op_def) return op_def;
|
||||
status->status =
|
||||
tensorflow::OpDefForOp(op->operation.Name().c_str(), &op_def);
|
||||
return op_def;
|
||||
}
|
||||
|
||||
bool IsCPU(
|
||||
absl::variant<tensorflow::Device*, tensorflow::CustomDevice*> variant) {
|
||||
if (VariantDeviceIsCustom(variant)) {
|
||||
@ -1125,9 +1116,13 @@ TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
|
||||
return retval;
|
||||
} else {
|
||||
tensorflow::Tensor tensor;
|
||||
if (IsCPU(handle_->device())) {
|
||||
if (IsCPU(handle_->device()) || handle_->HasLocalMirror(nullptr)) {
|
||||
const tensorflow::Tensor* src = nullptr;
|
||||
if (handle_->HasLocalMirror(nullptr)) {
|
||||
*status = handle_->TensorFromDevice(nullptr, &src);
|
||||
} else {
|
||||
*status = handle_->Tensor(&src);
|
||||
}
|
||||
if (!status->ok()) return nullptr;
|
||||
tensor = *src;
|
||||
} else {
|
||||
@ -1135,6 +1130,13 @@ TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
|
||||
CHECK_NE(ctx, nullptr);
|
||||
*status = handle_->CopyToDevice(*ctx, ctx->HostCPU(), &tensor);
|
||||
if (!status->ok()) return nullptr;
|
||||
if (handle_->ImplicitMirroring()) {
|
||||
*status = handle_->AddEmptyLocalMirror(nullptr);
|
||||
if (!status->ok()) return nullptr;
|
||||
Tensor mirror = tensor;
|
||||
*status = handle_->SetTensor(std::move(mirror), nullptr);
|
||||
if (!status->ok()) return nullptr;
|
||||
}
|
||||
}
|
||||
return tensorflow::TF_TensorFromTensor(tensor, status);
|
||||
}
|
||||
@ -1199,18 +1201,11 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
||||
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
|
||||
}
|
||||
|
||||
if (dtype == TF_STRING || dtype == TF_RESOURCE ||
|
||||
!tensorflow::DataTypeCanUseMemcpy(
|
||||
static_cast<tensorflow::DataType>(dtype))) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Trying to create a tensor with a pointer to non-pod memory.");
|
||||
deallocator(data, len, deallocator_arg);
|
||||
return nullptr;
|
||||
}
|
||||
// TODO(apassos) do we need to wrap the deallocator here to make sure to sync
|
||||
// the device?
|
||||
TF_ManagedBuffer* buf =
|
||||
new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
|
||||
new TF_ManagedBuffer(data, len, deallocator, deallocator_arg,
|
||||
/*owns_memory=*/false);
|
||||
|
||||
tensorflow::Tensor t(static_cast<tensorflow::DataType>(dtype),
|
||||
tensorflow::TensorShape(dimvec), buf);
|
||||
@ -1261,9 +1256,8 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
|
||||
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op> new_op(
|
||||
new TFE_Op{tensorflow::EagerOperation(ctx->context)});
|
||||
status->status =
|
||||
new_op->operation.Reset(op_or_function_name, nullptr, false, nullptr);
|
||||
new TFE_Op{std::make_unique<tensorflow::OperationInterface>(ctx)});
|
||||
status->status = new_op->operation->Reset(op_or_function_name, nullptr);
|
||||
if (!status->status.ok()) {
|
||||
new_op.reset();
|
||||
}
|
||||
@ -1273,49 +1267,51 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
void TFE_DeleteOp(TFE_Op* op) { delete op; }
|
||||
|
||||
void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
|
||||
status->status = op->operation.SetDeviceName(device_name);
|
||||
status->status = op->operation->SetDeviceName(device_name);
|
||||
}
|
||||
|
||||
const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
|
||||
tensorflow::Device* device = (op->operation.Device() == nullptr)
|
||||
? op->operation.EagerContext().HostCPU()
|
||||
: op->operation.Device();
|
||||
return device->name().c_str();
|
||||
return op->operation->DeviceName().c_str();
|
||||
}
|
||||
|
||||
void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
|
||||
op->operation.SetUseXla(enable);
|
||||
#ifndef TENSORFLOW_EAGER_USE_XLA
|
||||
#ifdef TENSORFLOW_EAGER_USE_XLA
|
||||
tensorflow::Status s = op->operation->SetUseXla(enable);
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << "Could not enable XLA compilation for op: " << s;
|
||||
}
|
||||
#else
|
||||
LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not "
|
||||
"built with XLA support.";
|
||||
#endif // TENSORFLOW_EAGER_USE_XLA
|
||||
}
|
||||
|
||||
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
|
||||
tensorflow::TensorHandle* h =
|
||||
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||
input->handle.get())
|
||||
->Handle();
|
||||
op->operation.AddInput(h);
|
||||
status->status = op->operation.MaybeInferSingleInputAttrs(h);
|
||||
status->status = op->operation->AddInput(input->handle);
|
||||
}
|
||||
|
||||
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
|
||||
TF_Status* status) {
|
||||
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>> handles(
|
||||
num_inputs);
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
op->operation.AddInput(
|
||||
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||
inputs[i]->handle.get())
|
||||
->Handle());
|
||||
handles[i].reset(inputs[i]->handle->Copy());
|
||||
}
|
||||
status->status = op->operation.InferInputListAttrs(num_inputs);
|
||||
status->status = op->operation->AddInputList(handles);
|
||||
}
|
||||
|
||||
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
|
||||
unsigned char* is_list, TF_Status* status) {
|
||||
TF_AttrType ret = TF_ATTR_INT;
|
||||
status->status = tensorflow::AttrTypeByName(*op->operation.AttrTypes(),
|
||||
attr_name, &ret, is_list);
|
||||
const tensorflow::AttrTypeMap* attr_types_;
|
||||
bool is_function;
|
||||
status->status = tensorflow::AttrTypeMapForOp(op->operation->Name().c_str(),
|
||||
&attr_types_, &is_function);
|
||||
if (!status->status.ok()) {
|
||||
return ret;
|
||||
}
|
||||
status->status =
|
||||
tensorflow::AttrTypeByName(*attr_types_, attr_name, &ret, is_list);
|
||||
return ret;
|
||||
}
|
||||
|
||||
@ -1336,221 +1332,150 @@ TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
|
||||
|
||||
void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
|
||||
size_t length) {
|
||||
op->operation.MutableAttrs()->Set(
|
||||
attr_name,
|
||||
tensorflow::StringPiece(static_cast<const char*>(value), length));
|
||||
auto s = op->operation->SetAttrString(
|
||||
attr_name, static_cast<const char*>(value), length);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
|
||||
op->operation.MutableAttrs()->Set(attr_name, static_cast<int64>(value));
|
||||
auto s = op->operation->SetAttrInt(attr_name, value);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
|
||||
op->operation.MutableAttrs()->Set(attr_name, value);
|
||||
auto s = op->operation->SetAttrFloat(attr_name, value);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
|
||||
op->operation.MutableAttrs()->Set(attr_name, (value == 0) ? false : true);
|
||||
auto s = op->operation->SetAttrBool(attr_name, (value == 0) ? false : true);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
|
||||
op->operation.MutableAttrs()->Set(attr_name,
|
||||
static_cast<tensorflow::DataType>(value));
|
||||
auto s = op->operation->SetAttrType(attr_name, value);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
|
||||
const int num_dims, TF_Status* out_status) {
|
||||
if (num_dims > tensorflow::TensorShape::MaxDimensions()) {
|
||||
TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
|
||||
tensorflow::strings::StrCat(
|
||||
"Value specified for `", attr_name, "` has ", num_dims,
|
||||
" dimensions which is over the limit of ",
|
||||
tensorflow::TensorShape::MaxDimensions(), ".")
|
||||
.c_str());
|
||||
return;
|
||||
}
|
||||
tensorflow::TensorShapeProto proto;
|
||||
if (num_dims < 0) {
|
||||
proto.set_unknown_rank(true);
|
||||
} else {
|
||||
for (int d = 0; d < num_dims; ++d) {
|
||||
proto.add_dim()->set_size(dims[d]);
|
||||
}
|
||||
}
|
||||
op->operation.MutableAttrs()->Set(attr_name, proto);
|
||||
out_status->status = op->operation->SetAttrShape(attr_name, dims, num_dims);
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
|
||||
const TFE_Op* value) {
|
||||
tensorflow::AttrValue attr_value;
|
||||
tensorflow::NameAttrList* func = attr_value.mutable_func();
|
||||
func->set_name(value->operation.Name());
|
||||
value->operation.Attrs().FillAttrValueMap(func->mutable_attr());
|
||||
op->operation.MutableAttrs()->Set(attr_name, attr_value);
|
||||
auto s = op->operation->SetAttrFunction(attr_name, value->operation);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
|
||||
const char* data, size_t length) {
|
||||
tensorflow::AttrValue attr_value;
|
||||
tensorflow::NameAttrList* func = attr_value.mutable_func();
|
||||
func->set_name(data, length);
|
||||
op->operation.MutableAttrs()->Set(attr_name, attr_value);
|
||||
auto s = op->operation->SetAttrFunctionName(attr_name, data, length);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
|
||||
TF_Status* status) {
|
||||
tensorflow::Tensor t;
|
||||
status->status = TF_TensorToTensor(tensor, &t);
|
||||
if (status->status.ok()) op->operation.MutableAttrs()->Set(attr_name, t);
|
||||
status->status = op->operation->SetAttrTensor(attr_name, tensor);
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
|
||||
const void* const* values, const size_t* lengths,
|
||||
int num_values) {
|
||||
std::vector<tensorflow::StringPiece> v(num_values);
|
||||
for (int i = 0; i < num_values; ++i) {
|
||||
v[i] = tensorflow::StringPiece(static_cast<const char*>(values[i]),
|
||||
lengths[i]);
|
||||
auto s =
|
||||
op->operation->SetAttrStringList(attr_name, values, lengths, num_values);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
op->operation.MutableAttrs()->Set(attr_name, v);
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
|
||||
const float* values, int num_values) {
|
||||
op->operation.MutableAttrs()->Set(
|
||||
attr_name, tensorflow::gtl::ArraySlice<const float>(values, num_values));
|
||||
auto s = op->operation->SetAttrFloatList(attr_name, values, num_values);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
|
||||
const int64_t* values, int num_values) {
|
||||
op->operation.MutableAttrs()->Set(
|
||||
attr_name, tensorflow::gtl::ArraySlice<const int64>(
|
||||
reinterpret_cast<const int64*>(values), num_values));
|
||||
auto s = op->operation->SetAttrIntList(attr_name, values, num_values);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
|
||||
const TF_DataType* values, int num_values) {
|
||||
op->operation.MutableAttrs()->Set(
|
||||
attr_name,
|
||||
tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
|
||||
reinterpret_cast<const tensorflow::DataType*>(values), num_values));
|
||||
auto s = op->operation->SetAttrTypeList(attr_name, values, num_values);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
|
||||
const unsigned char* values, int num_values) {
|
||||
std::unique_ptr<bool[]> b(new bool[num_values]);
|
||||
for (int i = 0; i < num_values; ++i) {
|
||||
b[i] = values[i];
|
||||
auto s = op->operation->SetAttrBoolList(attr_name, values, num_values);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
op->operation.MutableAttrs()->Set(
|
||||
attr_name, tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values));
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
|
||||
const int64_t** dims, const int* num_dims,
|
||||
int num_values, TF_Status* out_status) {
|
||||
std::unique_ptr<tensorflow::TensorShapeProto[]> proto(
|
||||
new tensorflow::TensorShapeProto[num_values]);
|
||||
for (int i = 0; i < num_values; ++i) {
|
||||
const auto num_dims_i = num_dims[i];
|
||||
|
||||
if (num_dims_i > tensorflow::TensorShape::MaxDimensions()) {
|
||||
TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
|
||||
tensorflow::strings::StrCat(
|
||||
"Value specified for `", attr_name, "` has ", num_dims_i,
|
||||
" dimensions which is over the limit of ",
|
||||
tensorflow::TensorShape::MaxDimensions(), ".")
|
||||
.c_str());
|
||||
return;
|
||||
}
|
||||
if (num_dims_i < 0) {
|
||||
proto[i].set_unknown_rank(true);
|
||||
} else {
|
||||
const int64_t* dims_i = dims[i];
|
||||
auto proto_i = &proto[i];
|
||||
for (int d = 0; d < num_dims_i; ++d) {
|
||||
proto_i->add_dim()->set_size(dims_i[d]);
|
||||
}
|
||||
}
|
||||
}
|
||||
op->operation.MutableAttrs()->Set(
|
||||
attr_name, tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>(
|
||||
proto.get(), num_values));
|
||||
out_status->status =
|
||||
op->operation->SetAttrShapeList(attr_name, dims, num_dims, num_values);
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
|
||||
const TFE_Op** value, int num_values) {
|
||||
std::unique_ptr<tensorflow::NameAttrList[]> funcs(
|
||||
new tensorflow::NameAttrList[num_values]);
|
||||
for (int i = 0; i < num_values; i++) {
|
||||
funcs[i].set_name(value[i]->operation.Name());
|
||||
value[i]->operation.Attrs().FillAttrValueMap(funcs[i].mutable_attr());
|
||||
auto s = op->operation->SetAttrFunctionList(attr_name, value, num_values);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
op->operation.MutableAttrs()->Set(
|
||||
attr_name, tensorflow::gtl::ArraySlice<const tensorflow::NameAttrList>(
|
||||
funcs.get(), num_values));
|
||||
}
|
||||
|
||||
TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
|
||||
const char* input_name,
|
||||
TF_Status* status) {
|
||||
const tensorflow::OpDef* op_def = GetOpDef(op, status);
|
||||
if (!status->status.ok()) {
|
||||
return -1;
|
||||
}
|
||||
tensorflow::AttrValueMap attrs;
|
||||
op->operation.Attrs().FillAttrValueMap(&attrs);
|
||||
tensorflow::NameRangeMap name_ranges;
|
||||
status->status = tensorflow::NameRangesForNode(
|
||||
tensorflow::AttrSlice(&attrs), *op_def, &name_ranges, nullptr);
|
||||
if (!status->status.ok()) {
|
||||
return -1;
|
||||
}
|
||||
auto iter = name_ranges.find(input_name);
|
||||
if (iter == name_ranges.end()) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Input '", input_name,
|
||||
"' not found");
|
||||
return -1;
|
||||
}
|
||||
return iter->second.second - iter->second.first;
|
||||
int ret = -1;
|
||||
status->status = op->operation->InputLength(input_name, &ret);
|
||||
return ret;
|
||||
}
|
||||
|
||||
TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
|
||||
const char* output_name,
|
||||
TF_Status* status) {
|
||||
const tensorflow::OpDef* op_def = GetOpDef(op, status);
|
||||
if (!status->status.ok()) {
|
||||
return -1;
|
||||
}
|
||||
tensorflow::AttrValueMap attrs;
|
||||
op->operation.Attrs().FillAttrValueMap(&attrs);
|
||||
tensorflow::NameRangeMap name_ranges;
|
||||
status->status = tensorflow::NameRangesForNode(
|
||||
tensorflow::AttrSlice(&attrs), *op_def, nullptr, &name_ranges);
|
||||
if (!status->status.ok()) {
|
||||
return -1;
|
||||
}
|
||||
auto iter = name_ranges.find(output_name);
|
||||
if (iter == name_ranges.end()) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Output '", output_name, "' not found");
|
||||
return -1;
|
||||
}
|
||||
return iter->second.second - iter->second.first;
|
||||
int ret = -1;
|
||||
status->status = op->operation->OutputLength(output_name, &ret);
|
||||
return ret;
|
||||
}
|
||||
|
||||
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
|
||||
TF_Status* status) {
|
||||
absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
|
||||
VLOG(1) << "Calling TFE_Execute() on op " << op;
|
||||
status->status = tensorflow::EagerExecute(&op->operation,
|
||||
handle_retvals.data(), num_retvals);
|
||||
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>> handles(
|
||||
*num_retvals);
|
||||
status->status = op->operation->Execute(&handles, num_retvals);
|
||||
if (!status->status.ok()) {
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < *num_retvals; ++i) {
|
||||
retvals[i] = new TFE_TensorHandle{
|
||||
std::make_unique<tensorflow::TensorHandleInterface>(handle_retvals[i])};
|
||||
retvals[i] = new TFE_TensorHandle{std::move(handles[i])};
|
||||
}
|
||||
}
|
||||
|
||||
@ -1678,6 +1603,23 @@ void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context->StartStep(); }
|
||||
|
||||
void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context->EndStep(); }
|
||||
|
||||
void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs) {
|
||||
auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||
op->operation.get());
|
||||
*attrs = TFE_OpAttrs(&operation->Attrs());
|
||||
}
|
||||
|
||||
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
|
||||
tensorflow::AttrValueMap m;
|
||||
attrs->attributes->FillAttrValueMap(&m);
|
||||
auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||
op->operation.get());
|
||||
tensorflow::AttrBuilder* destination = operation->MutableAttrs();
|
||||
for (auto attribute : m) {
|
||||
destination->Set(attribute.first, attribute.second);
|
||||
}
|
||||
}
|
||||
|
||||
namespace tensorflow {
|
||||
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
|
||||
const tensorflow::AttrValue& default_value,
|
||||
@ -1797,10 +1739,10 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
|
||||
op->Inputs()[i])});
|
||||
}
|
||||
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
|
||||
// TODO(allenl): figure out how to get attrs from EagerOperation
|
||||
TF_Status status;
|
||||
TFE_OpAttrs attributes(&op->Attrs());
|
||||
device_.execute(inputs.size(), inputs.data(), op->Name().c_str(),
|
||||
num_retvals, outputs.data(), &status, info_);
|
||||
&attributes, num_retvals, outputs.data(), &status, info_);
|
||||
if (status.status.ok()) {
|
||||
for (int i = 0; i < *num_retvals; ++i) {
|
||||
retvals[i] = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||
|
@ -31,20 +31,14 @@ using tensorflow::string;
|
||||
void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
|
||||
const char* raw_device_name, TF_Status* status) {
|
||||
if (op_to_reset) {
|
||||
status->status = op_to_reset->operation.Reset(
|
||||
op_or_function_name, raw_device_name, false, nullptr);
|
||||
status->status =
|
||||
op_to_reset->operation->Reset(op_or_function_name, raw_device_name);
|
||||
} else {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
"op_to_reset should not be nullptr");
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
|
||||
op->operation.ConsumeInput(
|
||||
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
|
||||
->Handle());
|
||||
}
|
||||
|
||||
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
|
||||
ctx->context->SetShouldStoreGraphs(true);
|
||||
}
|
||||
@ -520,8 +514,7 @@ void TFE_DeleteCancellationManager(
|
||||
void TFE_OpSetCancellationManager(TFE_Op* op,
|
||||
TFE_CancellationManager* cancellation_manager,
|
||||
TF_Status* status) {
|
||||
op->operation.SetCancellationManager(
|
||||
&cancellation_manager->cancellation_manager);
|
||||
status->status = op->operation->SetCancellationManager(cancellation_manager);
|
||||
}
|
||||
|
||||
TFE_Executor* TFE_NewExecutor(bool is_async) {
|
||||
@ -569,3 +562,22 @@ void TFE_TensorHandleEnableImplicitMirroring(TFE_TensorHandle* h,
|
||||
h->handle->EnableImplicitMirroring();
|
||||
status->status = tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
|
||||
TF_Buffer* buf, TF_Status* status) {
|
||||
auto* function_def = ctx->context->FindFunctionDef(function_name);
|
||||
if (function_def == nullptr) {
|
||||
status->status = tensorflow::errors::NotFound(
|
||||
"Unable to find FunctionDef with name: ", function_name);
|
||||
return;
|
||||
}
|
||||
string str = function_def->SerializeAsString();
|
||||
void* data = tensorflow::port::Malloc(str.length());
|
||||
str.copy(static_cast<char*>(data), str.length(), 0);
|
||||
buf->data = data;
|
||||
buf->length = str.length();
|
||||
buf->data_deallocator = [](void* data, size_t length) {
|
||||
tensorflow::port::Free(data);
|
||||
};
|
||||
status->status = tensorflow::Status::OK();
|
||||
}
|
||||
|
@ -34,9 +34,6 @@ TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Op* op_to_reset,
|
||||
const char* raw_device_name,
|
||||
TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,
|
||||
TF_Status* status);
|
||||
|
||||
// Enables only graph collection in RunMetadata on the functions executed from
|
||||
// this context.
|
||||
TF_CAPI_EXPORT extern void TFE_ContextEnableGraphCollection(TFE_Context* ctx);
|
||||
@ -424,7 +421,27 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
||||
TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
|
||||
TF_Buffer* buf);
|
||||
|
||||
#define TFE_CUSTOM_DEVICE_VERSION 0
|
||||
// APIs for generically dealing with op attributes (e.g. when forwarding them
|
||||
// through custom device implementations).
|
||||
//
|
||||
// TODO(allenl): Currently these are black boxes, but we should have some way to
|
||||
// inspect values. This would let people e.g. copy over most attributes and then
|
||||
// modify some based on their values.
|
||||
|
||||
// A reference to an op's name -> attribute mapping
|
||||
typedef struct TFE_OpAttrs TFE_OpAttrs;
|
||||
|
||||
// Fetch a struct with a reference to information about attributes of `op`.
|
||||
//
|
||||
// The `attrs` struct does not own any memory, and `op` must outlive it.
|
||||
TF_CAPI_EXPORT extern void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs);
|
||||
|
||||
// Add attributes in `attrs` to `op`.
|
||||
//
|
||||
// Does not overwrite or update existing attributes, but adds new ones.
|
||||
TF_CAPI_EXPORT extern void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs);
|
||||
|
||||
#define TFE_CUSTOM_DEVICE_VERSION 1
|
||||
|
||||
// Struct to be filled in
|
||||
typedef struct TFE_CustomDevice {
|
||||
@ -441,10 +458,10 @@ typedef struct TFE_CustomDevice {
|
||||
void* device_info);
|
||||
|
||||
// Method to execute an operation.
|
||||
// TODO(allenl) figure out a generic way of passing attrs here
|
||||
void (*execute)(int num_inputs, TFE_TensorHandle** inputs,
|
||||
const char* operation_name, int* num_outputs,
|
||||
TFE_TensorHandle** outputs, TF_Status* s, void* device_info);
|
||||
const char* operation_name, const TFE_OpAttrs* attributes,
|
||||
int* num_outputs, TFE_TensorHandle** outputs, TF_Status* s,
|
||||
void* device_info);
|
||||
|
||||
// Method to delete a device.
|
||||
void (*delete_device)(void* device_info);
|
||||
@ -475,6 +492,11 @@ typedef struct TFE_CustomDevice {
|
||||
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
|
||||
const char* device_name, void* device_info);
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx,
|
||||
const char* function_name,
|
||||
TF_Buffer* buf,
|
||||
TF_Status* status);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
@ -27,12 +27,12 @@ limitations under the License.
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
@ -89,7 +89,7 @@ struct TFE_TensorDebugInfo {
|
||||
};
|
||||
|
||||
struct TFE_Op {
|
||||
tensorflow::EagerOperation operation;
|
||||
std::unique_ptr<AbstractOperationInterface> operation;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringCounterCell {
|
||||
@ -236,4 +236,13 @@ struct TFE_Executor {
|
||||
tensorflow::EagerExecutor* unowned_executor;
|
||||
};
|
||||
|
||||
struct TFE_OpAttrs {
|
||||
explicit TFE_OpAttrs() : attributes(nullptr) {}
|
||||
|
||||
explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value)
|
||||
: attributes(value) {}
|
||||
|
||||
const tensorflow::AttrBuilder* attributes;
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
|
||||
|
@ -17,6 +17,8 @@ limitations under the License.
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
@ -367,7 +369,7 @@ TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevicesAsync) {
|
||||
void TensorHandleSilentCopy(bool async,
|
||||
TFE_ContextDevicePlacementPolicy global_policy,
|
||||
TFE_ContextDevicePlacementPolicy thread_policy,
|
||||
bool cpu_op) {
|
||||
bool mirror, bool cpu_op) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
@ -390,6 +392,12 @@ void TensorHandleSilentCopy(bool async,
|
||||
TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
|
||||
hcpu, ctx, gpu_device_name.c_str(), status.get());
|
||||
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
|
||||
if (mirror) {
|
||||
TFE_TensorHandleEnableImplicitMirroring(hcpu, status.get());
|
||||
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
|
||||
TFE_TensorHandleEnableImplicitMirroring(hgpu, status.get());
|
||||
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
|
||||
}
|
||||
|
||||
TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
|
||||
if (cpu_op) {
|
||||
@ -414,10 +422,23 @@ void TensorHandleSilentCopy(bool async,
|
||||
hgpu->handle.get())
|
||||
->Handle();
|
||||
|
||||
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||
matmul->operation.get());
|
||||
if (mirror) {
|
||||
// The input handles should never change since they have been mirrored.
|
||||
ASSERT_EQ(matmul->operation.Inputs()[0], arg0);
|
||||
ASSERT_EQ(matmul->operation.Inputs()[1], arg1);
|
||||
|
||||
ASSERT_EQ(op->GetInput(0), arg0);
|
||||
ASSERT_EQ(op->GetInput(1), arg1);
|
||||
} else {
|
||||
if (cpu_op) {
|
||||
ASSERT_EQ(op->GetInput(0), arg0);
|
||||
// The GPU handle should be replaced with a CPU copy
|
||||
ASSERT_NE(op->GetInput(1), arg1);
|
||||
} else {
|
||||
// The CPU handle should be replaced with a GPU copy
|
||||
ASSERT_NE(op->GetInput(0), arg0);
|
||||
ASSERT_EQ(op->GetInput(1), arg1);
|
||||
}
|
||||
}
|
||||
TFE_DeleteOp(matmul);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
TFE_DeleteTensorHandle(hgpu);
|
||||
@ -433,19 +454,27 @@ void TensorHandleSilentCopy(bool async,
|
||||
}
|
||||
TEST(CAPI, TensorHandleSilentCopy) {
|
||||
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT,
|
||||
TFE_DEVICE_PLACEMENT_SILENT, false);
|
||||
TFE_DEVICE_PLACEMENT_SILENT, false, false);
|
||||
}
|
||||
TEST(CAPI, TensorHandleSilentCopyAsync) {
|
||||
TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_SILENT,
|
||||
TFE_DEVICE_PLACEMENT_SILENT, false);
|
||||
TFE_DEVICE_PLACEMENT_SILENT, false, false);
|
||||
}
|
||||
TEST(CAPI, TensorHandleSilentCopyLocalPolicy) {
|
||||
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_EXPLICIT,
|
||||
TFE_DEVICE_PLACEMENT_SILENT, false);
|
||||
TFE_DEVICE_PLACEMENT_SILENT, false, false);
|
||||
}
|
||||
TEST(CAPI, TensorHandleSilentCopyLocalPolicyAsync) {
|
||||
TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_EXPLICIT,
|
||||
TFE_DEVICE_PLACEMENT_SILENT, false);
|
||||
TFE_DEVICE_PLACEMENT_SILENT, false, false);
|
||||
}
|
||||
TEST(CAPI, TensorHandleMirrorCopy) {
|
||||
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT,
|
||||
TFE_DEVICE_PLACEMENT_SILENT, true, false);
|
||||
}
|
||||
TEST(CAPI, TensorHandleMirrorCopyCpu) {
|
||||
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT,
|
||||
TFE_DEVICE_PLACEMENT_SILENT, true, true);
|
||||
}
|
||||
|
||||
void SetAndGetOpDevices(bool async) {
|
||||
@ -581,6 +610,91 @@ TEST(CAPI, TensorHandleDevices) {
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
void ExecuteAdd(bool async, bool forward_input) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* n = TestMatrixTensorHandle100x100();
|
||||
// If a GPU exists, copy the handle to GPU so that we can exercise
|
||||
// unprotecting a mirror.
|
||||
std::string gpu_device_name;
|
||||
if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
|
||||
TFE_TensorHandle* n_gpu =
|
||||
TFE_TensorHandleCopyToDevice(n, ctx, gpu_device_name.c_str(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_TensorHandleEnableImplicitMirroring(n_gpu, status);
|
||||
TFE_DeleteTensorHandle(n);
|
||||
n = n_gpu;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle100x100();
|
||||
|
||||
// Store pointer to raw buffer for validation of forwarding behaviour.
|
||||
TF_Tensor* orig = TFE_TensorHandleResolve(n, status);
|
||||
void* orig_ptr = TF_TensorData(orig);
|
||||
TF_DeleteTensor(orig);
|
||||
|
||||
TFE_Op* add_op = AddOp(ctx, n, m);
|
||||
std::string cpu_device_name;
|
||||
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
|
||||
TFE_OpSetDevice(add_op, cpu_device_name.c_str(), status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
if (forward_input) {
|
||||
TFE_DeleteTensorHandle(n);
|
||||
}
|
||||
|
||||
int num_retvals = 1;
|
||||
|
||||
if (async) {
|
||||
// Enqueue dummy ops so we backlog async execution & actually test async.
|
||||
for (int i = 0; i < 10000; ++i) {
|
||||
TFE_TensorHandle* dummy = nullptr;
|
||||
TFE_Execute(add_op, &dummy, &num_retvals, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteTensorHandle(dummy);
|
||||
}
|
||||
}
|
||||
|
||||
TFE_TensorHandle* retval = nullptr;
|
||||
TFE_Execute(add_op, &retval, &num_retvals, status);
|
||||
EXPECT_EQ(1, num_retvals);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
if (!forward_input) {
|
||||
TFE_DeleteTensorHandle(n);
|
||||
}
|
||||
TFE_DeleteOp(add_op);
|
||||
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(retval, status);
|
||||
if (forward_input || async) {
|
||||
EXPECT_EQ(orig_ptr, TF_TensorData(t));
|
||||
} else {
|
||||
EXPECT_NE(orig_ptr, TF_TensorData(t));
|
||||
}
|
||||
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteTensorHandle(m);
|
||||
TFE_DeleteTensorHandle(retval);
|
||||
TFE_DeleteContext(ctx);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
float result[100 * 100] = {0};
|
||||
EXPECT_EQ(sizeof(result), TF_TensorByteSize(t));
|
||||
memcpy(&result[0], TF_TensorData(t), TF_TensorByteSize(t));
|
||||
TF_DeleteTensor(t);
|
||||
for (int i = 0; i < 100 * 100; ++i) {
|
||||
EXPECT_EQ(2.0f, result[i]);
|
||||
}
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
TEST(CAPI, ExecuteAdd) { ExecuteAdd(false, false); }
|
||||
TEST(CAPI, ExecuteAddAsync) { ExecuteAdd(true, false); }
|
||||
TEST(CAPI, ExecuteAddForward) { ExecuteAdd(false, true); }
|
||||
TEST(CAPI, ExecuteAddForwardAsync) { ExecuteAdd(true, true); }
|
||||
|
||||
void Execute_MatMul_CPU(bool async) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
@ -1219,6 +1333,14 @@ TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
|
||||
TFE_DeleteTensorHandle(h_shares_tensor);
|
||||
}
|
||||
|
||||
tensorflow::AttrValueMap ExtractAttrs(TFE_Op* op) {
|
||||
tensorflow::AttrValueMap attr_values;
|
||||
tensorflow::down_cast<tensorflow::OperationInterface*>(op->operation.get())
|
||||
->Attrs()
|
||||
.FillAttrValueMap(&attr_values);
|
||||
return attr_values;
|
||||
}
|
||||
|
||||
TEST(CAPI, TestTFE_OpInferSingleInputAttrs) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
@ -1235,8 +1357,7 @@ TEST(CAPI, TestTFE_OpInferSingleInputAttrs) {
|
||||
TFE_OpAddInput(minOp, axis, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
tensorflow::AttrValueMap attr_values;
|
||||
minOp->operation.Attrs().FillAttrValueMap(&attr_values);
|
||||
tensorflow::AttrValueMap attr_values = ExtractAttrs(minOp);
|
||||
tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
|
||||
EXPECT_NE(attr_found, attr_values.cend());
|
||||
EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_FLOAT);
|
||||
@ -1275,8 +1396,7 @@ TEST(CAPI, TestTFE_OpInferSingleTypeInputListAttrs) {
|
||||
TFE_OpAddInputList(concatOp, inputs, 2, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
tensorflow::AttrValueMap attr_values;
|
||||
concatOp->operation.Attrs().FillAttrValueMap(&attr_values);
|
||||
tensorflow::AttrValueMap attr_values = ExtractAttrs(concatOp);
|
||||
tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
|
||||
EXPECT_NE(attr_found, attr_values.cend());
|
||||
EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_FLOAT);
|
||||
@ -1316,8 +1436,7 @@ TEST(CAPI, TestTFE_OpInferMixedTypeInputListAttrs) {
|
||||
TFE_OpAddInputList(assertOp, data, 3, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
tensorflow::AttrValueMap attr_values;
|
||||
assertOp->operation.Attrs().FillAttrValueMap(&attr_values);
|
||||
tensorflow::AttrValueMap attr_values = ExtractAttrs(assertOp);
|
||||
tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
|
||||
EXPECT_NE(attr_found, attr_values.cend());
|
||||
EXPECT_EQ(attr_found->second.list().type(0), tensorflow::DataType::DT_BOOL);
|
||||
@ -1353,16 +1472,15 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) {
|
||||
TFE_TensorHandle* inputs[] = {input1, input2};
|
||||
TFE_OpAddInput(concatOp, dim, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
CHECK(concatOp->operation.OpDef());
|
||||
CHECK(concatOp->operation->OpDef());
|
||||
TFE_OpAddInput(concatOp, inputs[0], status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_FALSE(concatOp->operation.OpDef())
|
||||
EXPECT_FALSE(concatOp->operation->OpDef())
|
||||
<< "Inference context is still present";
|
||||
TFE_OpAddInput(concatOp, inputs[1], status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
tensorflow::AttrValueMap attr_values;
|
||||
concatOp->operation.Attrs().FillAttrValueMap(&attr_values);
|
||||
tensorflow::AttrValueMap attr_values = ExtractAttrs(concatOp);
|
||||
EXPECT_EQ(attr_values.find("T"), attr_values.end());
|
||||
EXPECT_EQ(attr_values.find("N"), attr_values.end());
|
||||
|
||||
@ -1449,4 +1567,40 @@ TEST(CAPI, TestTFE_OpGetInputAndOutputLengthsFailForUnknownArguments) {
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
TEST(CAPI, TestTFE_OpGetAttrs) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
|
||||
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
|
||||
TFE_OpAttrs attributes;
|
||||
TFE_OpGetAttrs(var_op, &attributes);
|
||||
|
||||
TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||
TFE_OpSetAttrType(copy_op, "dtype", TF_FLOAT);
|
||||
TFE_OpAddAttrs(copy_op, &attributes);
|
||||
unsigned char is_list = 0;
|
||||
ASSERT_EQ(TF_ATTR_TYPE,
|
||||
TFE_OpGetAttrType(copy_op, "dtype", &is_list, status));
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
ASSERT_EQ(TF_ATTR_SHAPE,
|
||||
TFE_OpGetAttrType(copy_op, "shape", &is_list, status));
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
tensorflow::AttrValueMap attr_values;
|
||||
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||
copy_op->operation.get());
|
||||
op->Attrs().FillAttrValueMap(&attr_values);
|
||||
EXPECT_EQ(tensorflow::DT_FLOAT, attr_values.find("dtype")->second.type());
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
TFE_DeleteOp(var_op);
|
||||
TFE_DeleteOp(copy_op);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -131,6 +131,21 @@ TFE_TensorHandle* TestMatrixTensorHandle3X2() {
|
||||
return th;
|
||||
}
|
||||
|
||||
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
|
||||
TFE_Op* op = TFE_NewOp(ctx, "AddV2", status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(op, a, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(op, b, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
|
||||
|
||||
return op;
|
||||
}
|
||||
|
||||
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
|
||||
|
@ -42,6 +42,9 @@ TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2();
|
||||
// Return a tensor handle containing a 3x2 matrix of floats
|
||||
TFE_TensorHandle* TestMatrixTensorHandle3X2();
|
||||
|
||||
// Return an add op multiplying `a` by `b`.
|
||||
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
|
||||
|
||||
// Return a matmul op multiplying `a` by `b`.
|
||||
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
|
||||
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace {
|
||||
@ -31,6 +32,8 @@ struct LoggingDevice {
|
||||
tensorflow::string underlying_device;
|
||||
// Set to true whenever a TensorHandle is copied onto the device
|
||||
bool* arrived_flag;
|
||||
// Set to true whenever an operation is executed
|
||||
bool* executed_flag;
|
||||
};
|
||||
|
||||
struct LoggedTensor {
|
||||
@ -81,12 +84,14 @@ TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_TensorHandle* tensor,
|
||||
}
|
||||
|
||||
void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs,
|
||||
const char* operation_name, int* num_outputs,
|
||||
const char* operation_name,
|
||||
const TFE_OpAttrs* attributes, int* num_outputs,
|
||||
TFE_TensorHandle** outputs, TF_Status* s,
|
||||
void* device_info) {
|
||||
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
|
||||
TFE_Op* op(TFE_NewOp(dev->ctx, operation_name, s));
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
TFE_OpAddAttrs(op, attributes);
|
||||
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
|
||||
for (int j = 0; j < num_inputs; ++j) {
|
||||
TFE_TensorHandle* input = inputs[j];
|
||||
@ -115,6 +120,7 @@ void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs,
|
||||
outputs[i] = MakeLoggedTensorHandle(dev->ctx, dev->device_name,
|
||||
std::move(logged_tensor), s);
|
||||
}
|
||||
*(dev->executed_flag) = true;
|
||||
}
|
||||
|
||||
void DeleteLoggingDevice(void* device_info) {
|
||||
@ -122,7 +128,7 @@ void DeleteLoggingDevice(void* device_info) {
|
||||
}
|
||||
|
||||
void RegisterLoggingDevice(TFE_Context* context, const char* name,
|
||||
bool* arrived_flag) {
|
||||
bool* arrived_flag, bool* executed_flag) {
|
||||
TFE_CustomDevice custom_device;
|
||||
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
|
||||
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
|
||||
@ -131,6 +137,7 @@ void RegisterLoggingDevice(TFE_Context* context, const char* name,
|
||||
LoggingDevice* device = new LoggingDevice;
|
||||
device->ctx = context;
|
||||
device->arrived_flag = arrived_flag;
|
||||
device->executed_flag = executed_flag;
|
||||
device->device_name = name;
|
||||
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
TFE_RegisterCustomDevice(context, custom_device, name, device);
|
||||
@ -144,13 +151,15 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
|
||||
TFE_DeleteContextOptions(opts);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
bool arrived = false;
|
||||
bool executed = false;
|
||||
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
RegisterLoggingDevice(context, name, &arrived);
|
||||
RegisterLoggingDevice(context, name, &arrived, &executed);
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
|
||||
ASSERT_FALSE(arrived);
|
||||
TFE_TensorHandle* hdevice =
|
||||
TFE_TensorHandleCopyToDevice(hcpu, context, name, status.get());
|
||||
ASSERT_TRUE(arrived);
|
||||
ASSERT_FALSE(executed);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> matmul(
|
||||
MatMulOp(context, hcpu, hdevice), TFE_DeleteOp);
|
||||
@ -160,6 +169,7 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_TRUE(executed);
|
||||
|
||||
TFE_DeleteTensorHandle(retval);
|
||||
TFE_DeleteTensorHandle(hcpu);
|
||||
@ -167,4 +177,118 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
|
||||
TFE_DeleteContext(context);
|
||||
}
|
||||
|
||||
TEST(CUSTOM_DEVICE, ResetOperation) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts, status.get()), TFE_DeleteContext);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
bool arrived = false;
|
||||
bool executed = false;
|
||||
const char* custom_device_name =
|
||||
"/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
RegisterLoggingDevice(context.get(), custom_device_name, &arrived, &executed);
|
||||
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> reused_op(
|
||||
TFE_NewOp(context.get(), "Identity", status.get()), TFE_DeleteOp);
|
||||
TFE_OpReset(reused_op.get(), "Identity", custom_device_name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_EQ(tensorflow::string(TFE_OpGetDevice(reused_op.get(), status.get())),
|
||||
tensorflow::string(custom_device_name));
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_OpReset(reused_op.get(), "Identity",
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0", status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_EQ(tensorflow::string(TFE_OpGetDevice(reused_op.get(), status.get())),
|
||||
tensorflow::string("/job:localhost/replica:0/task:0/device:CPU:0"));
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
}
|
||||
|
||||
TEST(CUSTOM_DEVICE, MakeVariable) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
bool arrived = false;
|
||||
bool executed = false;
|
||||
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
RegisterLoggingDevice(context.get(), name, &arrived, &executed);
|
||||
|
||||
// Create a variable handle placed on the custom device.
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||
TFE_OpSetAttrShape(op.get(), "shape", {}, 0, status.get());
|
||||
TFE_OpSetAttrString(op.get(), "container", "", 0);
|
||||
TFE_OpSetAttrString(op.get(), "shared_name", "", 0);
|
||||
TFE_OpSetDevice(op.get(), name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_TensorHandle* var_handle = nullptr;
|
||||
int num_retvals = 1;
|
||||
executed = false;
|
||||
TFE_Execute(op.get(), &var_handle, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_TRUE(executed);
|
||||
auto handle_cleaner = tensorflow::gtl::MakeCleanup(
|
||||
[var_handle]() { TFE_DeleteTensorHandle(var_handle); });
|
||||
|
||||
// Assign to the variable, copying to the custom device.
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
|
||||
TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle);
|
||||
op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
|
||||
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||
TFE_OpAddInput(op.get(), one.get(), status.get());
|
||||
TFE_OpSetDevice(op.get(), name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
executed = false;
|
||||
num_retvals = 0;
|
||||
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_TRUE(executed);
|
||||
|
||||
// Read the variable's value.
|
||||
op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
|
||||
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||
TFE_OpSetDevice(op.get(), name, status.get());
|
||||
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
executed = false;
|
||||
num_retvals = 1;
|
||||
TFE_TensorHandle* var_value = nullptr;
|
||||
TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_TRUE(executed);
|
||||
auto value_cleaner = tensorflow::gtl::MakeCleanup(
|
||||
[var_value]() { TFE_DeleteTensorHandle(var_value); });
|
||||
ASSERT_EQ(tensorflow::string(name),
|
||||
tensorflow::string(
|
||||
TFE_TensorHandleBackingDeviceName(var_value, status.get())));
|
||||
TFE_TensorHandle* var_value_unpacked =
|
||||
reinterpret_cast<LoggedTensor*>(
|
||||
TFE_TensorHandleDevicePointer(var_value, status.get()))
|
||||
->tensor;
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> resolved_value(
|
||||
TFE_TensorHandleResolve(var_value_unpacked, status.get()),
|
||||
TF_DeleteTensor);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_EQ(111., *static_cast<float*>(TF_TensorData(resolved_value.get())));
|
||||
|
||||
// Free the backing buffer for the variable.
|
||||
op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
|
||||
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||
TFE_OpSetDevice(op.get(), name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
num_retvals = 0;
|
||||
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
312
tensorflow/c/eager/operation_interface.cc
Normal file
312
tensorflow/c/eager/operation_interface.cc
Normal file
@ -0,0 +1,312 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
|
||||
#include "absl/container/fixed_array.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/core/common_runtime/eager/execute.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
OperationInterface::OperationInterface(TFE_Context* ctx)
|
||||
: operation_(ctx->context) {}
|
||||
|
||||
const string& OperationInterface::DeviceName() const {
|
||||
absl::variant<Device*, CustomDevice*> variant_device =
|
||||
(operation_.Device() == kVariantDeviceNull)
|
||||
? operation_.EagerContext().HostCPU()
|
||||
: operation_.Device();
|
||||
return absl::visit([](auto* d) -> const string& { return d->name(); },
|
||||
variant_device);
|
||||
}
|
||||
|
||||
Status OperationInterface::SetDeviceName(const char* name) {
|
||||
return operation_.SetDeviceName(name);
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrString(const char* attr_name,
|
||||
const char* data, size_t length) {
|
||||
operation_.MutableAttrs()->Set(attr_name, StringPiece(data, length));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrInt(const char* attr_name, int64_t value) {
|
||||
operation_.MutableAttrs()->Set(attr_name, static_cast<int64>(value));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrFloat(const char* attr_name, float value) {
|
||||
operation_.MutableAttrs()->Set(attr_name, value);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrBool(const char* attr_name, bool value) {
|
||||
operation_.MutableAttrs()->Set(attr_name, value);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrType(const char* attr_name,
|
||||
TF_DataType value) {
|
||||
operation_.MutableAttrs()->Set(attr_name, static_cast<DataType>(value));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrShape(const char* attr_name,
|
||||
const int64_t* dims,
|
||||
const int num_dims) {
|
||||
if (num_dims > TensorShape::MaxDimensions()) {
|
||||
return errors::InvalidArgument("Value specified for `", attr_name, "` has ",
|
||||
num_dims,
|
||||
" dimensions which is over the limit of ",
|
||||
TensorShape::MaxDimensions(), ".");
|
||||
}
|
||||
|
||||
TensorShapeProto proto;
|
||||
if (num_dims < 0) {
|
||||
proto.set_unknown_rank(true);
|
||||
} else {
|
||||
for (int d = 0; d < num_dims; ++d) {
|
||||
proto.add_dim()->set_size(dims[d]);
|
||||
}
|
||||
}
|
||||
|
||||
operation_.MutableAttrs()->Set(attr_name, proto);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrFunction(
|
||||
const char* attr_name,
|
||||
const std::unique_ptr<AbstractOperationInterface>& value) {
|
||||
AttrValue attr_value;
|
||||
NameAttrList* func = attr_value.mutable_func();
|
||||
func->set_name(value->Name());
|
||||
OperationInterface* value_operation =
|
||||
tensorflow::down_cast<OperationInterface*>(value.get());
|
||||
value_operation->operation_.Attrs().FillAttrValueMap(func->mutable_attr());
|
||||
operation_.MutableAttrs()->Set(attr_name, attr_value);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrFunctionName(const char* attr_name,
|
||||
const char* data,
|
||||
size_t length) {
|
||||
AttrValue attr_value;
|
||||
NameAttrList* func = attr_value.mutable_func();
|
||||
func->set_name(data, length);
|
||||
operation_.MutableAttrs()->Set(attr_name, attr_value);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrTensor(const char* attr_name,
|
||||
TF_Tensor* tensor) {
|
||||
Tensor t;
|
||||
TF_RETURN_IF_ERROR(TF_TensorToTensor(tensor, &t));
|
||||
operation_.MutableAttrs()->Set(attr_name, t);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrStringList(const char* attr_name,
|
||||
const void* const* values,
|
||||
const size_t* lengths,
|
||||
int num_values) {
|
||||
std::vector<StringPiece> v(num_values);
|
||||
for (int i = 0; i < num_values; ++i) {
|
||||
v[i] = StringPiece(static_cast<const char*>(values[i]), lengths[i]);
|
||||
}
|
||||
operation_.MutableAttrs()->Set(attr_name, v);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrFloatList(const char* attr_name,
|
||||
const float* values,
|
||||
int num_values) {
|
||||
operation_.MutableAttrs()->Set(
|
||||
attr_name, gtl::ArraySlice<const float>(values, num_values));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrIntList(const char* attr_name,
|
||||
const int64_t* values,
|
||||
int num_values) {
|
||||
operation_.MutableAttrs()->Set(
|
||||
attr_name, gtl::ArraySlice<const int64>(
|
||||
reinterpret_cast<const int64*>(values), num_values));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrTypeList(const char* attr_name,
|
||||
const TF_DataType* values,
|
||||
int num_values) {
|
||||
operation_.MutableAttrs()->Set(
|
||||
attr_name, gtl::ArraySlice<const DataType>(
|
||||
reinterpret_cast<const DataType*>(values), num_values));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrBoolList(const char* attr_name,
|
||||
const unsigned char* values,
|
||||
int num_values) {
|
||||
std::unique_ptr<bool[]> b(new bool[num_values]);
|
||||
for (int i = 0; i < num_values; ++i) {
|
||||
b[i] = values[i];
|
||||
}
|
||||
operation_.MutableAttrs()->Set(
|
||||
attr_name, gtl::ArraySlice<const bool>(b.get(), num_values));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrShapeList(const char* attr_name,
|
||||
const int64_t** dims,
|
||||
const int* num_dims,
|
||||
int num_values) {
|
||||
std::unique_ptr<TensorShapeProto[]> proto(new TensorShapeProto[num_values]);
|
||||
for (int i = 0; i < num_values; ++i) {
|
||||
const auto num_dims_i = num_dims[i];
|
||||
|
||||
if (num_dims_i > TensorShape::MaxDimensions()) {
|
||||
return errors::InvalidArgument(
|
||||
strings::StrCat("Value specified for `", attr_name, "` has ",
|
||||
num_dims_i, " dimensions which is over the limit of ",
|
||||
TensorShape::MaxDimensions(), "."));
|
||||
}
|
||||
if (num_dims_i < 0) {
|
||||
proto[i].set_unknown_rank(true);
|
||||
} else {
|
||||
const int64_t* dims_i = dims[i];
|
||||
auto proto_i = &proto[i];
|
||||
for (int d = 0; d < num_dims_i; ++d) {
|
||||
proto_i->add_dim()->set_size(dims_i[d]);
|
||||
}
|
||||
}
|
||||
}
|
||||
operation_.MutableAttrs()->Set(
|
||||
attr_name, gtl::ArraySlice<TensorShapeProto>(proto.get(), num_values));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrFunctionList(const char* attr_name,
|
||||
const TFE_Op** value,
|
||||
int num_values) {
|
||||
std::unique_ptr<NameAttrList[]> funcs(new NameAttrList[num_values]);
|
||||
for (int i = 0; i < num_values; i++) {
|
||||
auto value_operation =
|
||||
tensorflow::down_cast<OperationInterface*>(value[i]->operation.get());
|
||||
funcs[i].set_name(value_operation->operation_.Name());
|
||||
value_operation->operation_.Attrs().FillAttrValueMap(
|
||||
funcs[i].mutable_attr());
|
||||
}
|
||||
operation_.MutableAttrs()->Set(
|
||||
attr_name, gtl::ArraySlice<const NameAttrList>(funcs.get(), num_values));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const OpDef* OperationInterface::GetOpDef(Status* status) {
|
||||
const tensorflow::OpDef* op_def = operation_.OpDef();
|
||||
if (op_def) return op_def;
|
||||
*status = OpDefForOp(Name(), &op_def);
|
||||
return op_def;
|
||||
}
|
||||
|
||||
Status OperationInterface::InputLength(const char* input_name, int* length) {
|
||||
Status status;
|
||||
const tensorflow::OpDef* op_def = GetOpDef(&status);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
AttrValueMap attrs;
|
||||
operation_.Attrs().FillAttrValueMap(&attrs);
|
||||
NameRangeMap name_ranges;
|
||||
TF_RETURN_IF_ERROR(
|
||||
NameRangesForNode(AttrSlice(&attrs), *op_def, &name_ranges, nullptr));
|
||||
auto iter = name_ranges.find(input_name);
|
||||
if (iter == name_ranges.end()) {
|
||||
return errors::InvalidArgument("Input '", input_name, "' not found");
|
||||
}
|
||||
*length = iter->second.second - iter->second.first;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::OutputLength(const char* output_name, int* length) {
|
||||
Status status;
|
||||
const tensorflow::OpDef* op_def = GetOpDef(&status);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
AttrValueMap attrs;
|
||||
operation_.Attrs().FillAttrValueMap(&attrs);
|
||||
NameRangeMap name_ranges;
|
||||
TF_RETURN_IF_ERROR(
|
||||
NameRangesForNode(AttrSlice(&attrs), *op_def, nullptr, &name_ranges));
|
||||
auto iter = name_ranges.find(output_name);
|
||||
if (iter == name_ranges.end()) {
|
||||
return errors::InvalidArgument("Output '", output_name, "' not found");
|
||||
}
|
||||
*length = iter->second.second - iter->second.first;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::AddInput(
|
||||
const std::unique_ptr<AbstractTensorHandleInterface>& input) {
|
||||
TensorHandle* h =
|
||||
tensorflow::down_cast<TensorHandleInterface*>(input.get())->Handle();
|
||||
operation_.AddInput(h);
|
||||
return operation_.MaybeInferSingleInputAttrs(h);
|
||||
}
|
||||
|
||||
Status OperationInterface::AddInputList(
|
||||
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
|
||||
inputs) {
|
||||
for (auto& input : inputs) {
|
||||
TensorHandle* h =
|
||||
tensorflow::down_cast<TensorHandleInterface*>(input.get())->Handle();
|
||||
operation_.AddInput(h);
|
||||
}
|
||||
return operation_.InferInputListAttrs(inputs.size());
|
||||
}
|
||||
|
||||
Status OperationInterface::Execute(
|
||||
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
|
||||
int* num_retvals) {
|
||||
absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
|
||||
TF_RETURN_IF_ERROR(
|
||||
EagerExecute(&operation_, handle_retvals.data(), num_retvals));
|
||||
for (int i = 0; i < *num_retvals; ++i) {
|
||||
retvals->at(i).reset(
|
||||
new tensorflow::TensorHandleInterface(handle_retvals[i]));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetCancellationManager(
|
||||
TFE_CancellationManager* cancellation_manager) {
|
||||
operation_.SetCancellationManager(
|
||||
&cancellation_manager->cancellation_manager);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetUseXla(bool enable) {
|
||||
operation_.SetUseXla(enable);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
188
tensorflow/c/eager/operation_interface.h
Normal file
188
tensorflow/c/eager/operation_interface.h
Normal file
@ -0,0 +1,188 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
|
||||
#define TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/container/fixed_array.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
|
||||
// Abstract interface to an operation.
|
||||
class AbstractOperationInterface {
|
||||
public:
|
||||
virtual ~AbstractOperationInterface() {}
|
||||
|
||||
virtual void Clear() = 0;
|
||||
virtual tensorflow::Status Reset(const char* op,
|
||||
const char* raw_device_name) = 0;
|
||||
|
||||
virtual const tensorflow::string& Name() const = 0;
|
||||
virtual const tensorflow::string& DeviceName() const = 0;
|
||||
virtual tensorflow::Status SetDeviceName(const char* name) = 0;
|
||||
|
||||
virtual tensorflow::Status AddInput(
|
||||
const std::unique_ptr<AbstractTensorHandleInterface>& input) = 0;
|
||||
virtual tensorflow::Status AddInputList(
|
||||
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
|
||||
inputs) = 0;
|
||||
virtual tensorflow::Status Execute(
|
||||
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
|
||||
int* num_retvals) = 0;
|
||||
virtual const tensorflow::OpDef* OpDef() const = 0;
|
||||
|
||||
virtual tensorflow::Status SetAttrString(const char* attr_name,
|
||||
const char* data, size_t length) = 0;
|
||||
virtual tensorflow::Status SetAttrInt(const char* attr_name,
|
||||
int64_t value) = 0;
|
||||
virtual tensorflow::Status SetAttrFloat(const char* attr_name,
|
||||
float value) = 0;
|
||||
virtual tensorflow::Status SetAttrBool(const char* attr_name, bool value) = 0;
|
||||
virtual tensorflow::Status SetAttrType(const char* attr_name,
|
||||
TF_DataType value) = 0;
|
||||
virtual tensorflow::Status SetAttrShape(const char* attr_name,
|
||||
const int64_t* dims,
|
||||
const int num_dims) = 0;
|
||||
virtual tensorflow::Status SetAttrFunction(
|
||||
const char* attr_name,
|
||||
const std::unique_ptr<AbstractOperationInterface>& value) = 0;
|
||||
virtual tensorflow::Status SetAttrFunctionName(const char* attr_name,
|
||||
const char* value,
|
||||
size_t length) = 0;
|
||||
virtual tensorflow::Status SetAttrTensor(const char* attr_name,
|
||||
TF_Tensor* tensor) = 0;
|
||||
virtual tensorflow::Status SetAttrStringList(const char* attr_name,
|
||||
const void* const* values,
|
||||
const size_t* lengths,
|
||||
int num_values) = 0;
|
||||
virtual tensorflow::Status SetAttrFloatList(const char* attr_name,
|
||||
const float* values,
|
||||
int num_values) = 0;
|
||||
virtual tensorflow::Status SetAttrIntList(const char* attr_name,
|
||||
const int64_t* values,
|
||||
int num_values) = 0;
|
||||
virtual tensorflow::Status SetAttrTypeList(const char* attr_name,
|
||||
const TF_DataType* values,
|
||||
int num_values) = 0;
|
||||
virtual tensorflow::Status SetAttrBoolList(const char* attr_name,
|
||||
const unsigned char* values,
|
||||
int num_values) = 0;
|
||||
virtual tensorflow::Status SetAttrShapeList(const char* attr_name,
|
||||
const int64_t** dims,
|
||||
const int* num_dims,
|
||||
int num_values) = 0;
|
||||
virtual tensorflow::Status SetAttrFunctionList(const char* attr_name,
|
||||
const TFE_Op** value,
|
||||
int num_values) = 0;
|
||||
|
||||
virtual tensorflow::Status InputLength(const char* input_name,
|
||||
int* length) = 0;
|
||||
virtual tensorflow::Status OutputLength(const char* output_name,
|
||||
int* length) = 0;
|
||||
|
||||
// Experimental
|
||||
virtual tensorflow::Status SetUseXla(bool enable) {
|
||||
return tensorflow::errors::Unimplemented("SetUseXla not implemented");
|
||||
}
|
||||
virtual tensorflow::Status SetCancellationManager(
|
||||
TFE_CancellationManager* cancellation_manager) {
|
||||
return tensorflow::errors::Unimplemented(
|
||||
"SetCancellationManager not implemented");
|
||||
}
|
||||
};
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class OpDef;
|
||||
|
||||
class OperationInterface : public AbstractOperationInterface {
|
||||
public:
|
||||
explicit OperationInterface(TFE_Context* ctx);
|
||||
~OperationInterface() override{};
|
||||
|
||||
void Clear() override { operation_.Clear(); }
|
||||
Status Reset(const char* op, const char* raw_device_name) override {
|
||||
return operation_.Reset(op, raw_device_name, false, nullptr);
|
||||
}
|
||||
|
||||
const string& Name() const override { return operation_.Name(); }
|
||||
const string& DeviceName() const override;
|
||||
Status SetDeviceName(const char* name) override;
|
||||
|
||||
Status AddInput(
|
||||
const std::unique_ptr<AbstractTensorHandleInterface>& input) override;
|
||||
Status AddInputList(
|
||||
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
|
||||
inputs) override;
|
||||
Status Execute(
|
||||
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
|
||||
int* num_retvals) override;
|
||||
const tensorflow::OpDef* OpDef() const override {
|
||||
return operation_.OpDef();
|
||||
};
|
||||
|
||||
Status SetAttrString(const char* attr_name, const char* data,
|
||||
size_t length) override;
|
||||
Status SetAttrInt(const char* attr_name, int64_t value) override;
|
||||
Status SetAttrFloat(const char* attr_name, float value) override;
|
||||
Status SetAttrBool(const char* attr_name, bool value) override;
|
||||
Status SetAttrType(const char* attr_name, TF_DataType value) override;
|
||||
Status SetAttrShape(const char* attr_name, const int64_t* dims,
|
||||
const int num_dims) override;
|
||||
Status SetAttrFunction(
|
||||
const char* attr_name,
|
||||
const std::unique_ptr<AbstractOperationInterface>& value) override;
|
||||
Status SetAttrFunctionName(const char* attr_name, const char* data,
|
||||
size_t length) override;
|
||||
Status SetAttrTensor(const char* attr_name, TF_Tensor* tensor) override;
|
||||
Status SetAttrStringList(const char* attr_name, const void* const* values,
|
||||
const size_t* lengths, int num_values) override;
|
||||
Status SetAttrFloatList(const char* attr_name, const float* values,
|
||||
int num_values) override;
|
||||
Status SetAttrIntList(const char* attr_name, const int64_t* values,
|
||||
int num_values) override;
|
||||
Status SetAttrTypeList(const char* attr_name, const TF_DataType* values,
|
||||
int num_values) override;
|
||||
Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
|
||||
int num_values) override;
|
||||
Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
|
||||
const int* num_dims, int num_values) override;
|
||||
Status SetAttrFunctionList(const char* attr_name, const TFE_Op** value,
|
||||
int num_values) override;
|
||||
|
||||
Status InputLength(const char* input_name, int* length) override;
|
||||
Status OutputLength(const char* output_name, int* length) override;
|
||||
|
||||
Status SetUseXla(bool enable) override;
|
||||
Status SetCancellationManager(
|
||||
TFE_CancellationManager* cancellation_manager) override;
|
||||
|
||||
// TODO(gjn): Remove once TFE_InferShapes is removed
|
||||
const tensorflow::AttrBuilder& Attrs() const { return operation_.Attrs(); }
|
||||
tensorflow::AttrBuilder* MutableAttrs() { return operation_.MutableAttrs(); }
|
||||
|
||||
const TensorHandle* GetInput(int i) const { return operation_.Inputs()[i]; }
|
||||
|
||||
private:
|
||||
const tensorflow::OpDef* GetOpDef(Status* status);
|
||||
EagerOperation operation_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
@ -64,25 +65,41 @@ void deallocate_buffer(void* data, size_t len, void* arg) {
|
||||
}
|
||||
} // namespace tensorflow
|
||||
|
||||
namespace {
|
||||
TF_Tensor* CreateTensor(TF_ManagedBuffer* buf, TF_DataType dtype,
|
||||
const int64_t* dims, int num_dims, size_t len) {
|
||||
std::vector<tensorflow::int64> dimvec(num_dims);
|
||||
for (int i = 0; i < num_dims; ++i) {
|
||||
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
|
||||
}
|
||||
|
||||
// TODO(gjn): Make the choice of interface a compile-time configuration.
|
||||
tensorflow::TensorInterface ret(
|
||||
Tensor(static_cast<tensorflow::DataType>(dtype),
|
||||
tensorflow::TensorShape(dimvec), buf));
|
||||
buf->Unref();
|
||||
size_t elem_size = TF_DataTypeSize(dtype);
|
||||
if (elem_size > 0 && len < (elem_size * ret.NumElements())) {
|
||||
return nullptr;
|
||||
}
|
||||
return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(ret)};
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims,
|
||||
int num_dims, size_t len) {
|
||||
void* data = tensorflow::allocate_tensor("TF_AllocateTensor", len,
|
||||
tensorflow::cpu_allocator());
|
||||
return TF_NewTensor(dtype, dims, num_dims, data, len,
|
||||
tensorflow::deallocate_buffer,
|
||||
tensorflow::cpu_allocator());
|
||||
TF_ManagedBuffer* buf =
|
||||
new TF_ManagedBuffer(data, len, tensorflow::deallocate_buffer,
|
||||
tensorflow::cpu_allocator(), /*owns_memory=*/true);
|
||||
return CreateTensor(buf, dtype, dims, num_dims, len);
|
||||
}
|
||||
|
||||
TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
|
||||
void* data, size_t len,
|
||||
void (*deallocator)(void* data, size_t len, void* arg),
|
||||
void* deallocator_arg) {
|
||||
std::vector<tensorflow::int64> dimvec(num_dims);
|
||||
for (int i = 0; i < num_dims; ++i) {
|
||||
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
|
||||
}
|
||||
|
||||
TF_ManagedBuffer* buf = nullptr;
|
||||
if (dtype != TF_STRING && dtype != TF_RESOURCE &&
|
||||
tensorflow::DataTypeCanUseMemcpy(
|
||||
@ -97,24 +114,17 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
|
||||
// Other types have the same representation, so copy only if it is safe to
|
||||
// do so.
|
||||
buf = new TF_ManagedBuffer(tensorflow::allocate_tensor("TF_NewTensor", len),
|
||||
len, tensorflow::deallocate_buffer, nullptr);
|
||||
len, tensorflow::deallocate_buffer, nullptr,
|
||||
/*owns_memory=*/true);
|
||||
std::memcpy(buf->data(), data, len);
|
||||
// Free the original buffer.
|
||||
deallocator(data, len, deallocator_arg);
|
||||
} else {
|
||||
buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
|
||||
buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg,
|
||||
/*owns_memory=*/false);
|
||||
}
|
||||
|
||||
// TODO(gjn): Make the choice of interface a compile-time configuration.
|
||||
tensorflow::TensorInterface ret(
|
||||
Tensor(static_cast<tensorflow::DataType>(dtype),
|
||||
tensorflow::TensorShape(dimvec), buf));
|
||||
buf->Unref();
|
||||
size_t elem_size = TF_DataTypeSize(dtype);
|
||||
if (elem_size > 0 && len < (elem_size * ret.NumElements())) {
|
||||
return nullptr;
|
||||
}
|
||||
return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(ret)};
|
||||
return CreateTensor(buf, dtype, dims, num_dims, len);
|
||||
}
|
||||
|
||||
TF_Tensor* TF_TensorMaybeMove(TF_Tensor* t) {
|
||||
|
@ -38,11 +38,12 @@ class TF_ManagedBuffer : public tensorflow::TensorBuffer {
|
||||
public:
|
||||
TF_ManagedBuffer(void* data, size_t len,
|
||||
void (*deallocator)(void* data, size_t len, void* arg),
|
||||
void* deallocator_arg)
|
||||
void* deallocator_arg, bool owns_memory)
|
||||
: TensorBuffer(data),
|
||||
len_(len),
|
||||
deallocator_(deallocator),
|
||||
deallocator_arg_(deallocator_arg) {}
|
||||
deallocator_arg_(deallocator_arg),
|
||||
owns_memory_(owns_memory) {}
|
||||
|
||||
~TF_ManagedBuffer() override {
|
||||
(*deallocator_)(data(), len_, deallocator_arg_);
|
||||
@ -57,13 +58,13 @@ class TF_ManagedBuffer : public tensorflow::TensorBuffer {
|
||||
proto->set_allocator_name(tensorflow::cpu_allocator()->Name());
|
||||
}
|
||||
|
||||
// Prevents input forwarding from mutating this buffer.
|
||||
bool OwnsMemory() const override { return false; }
|
||||
bool OwnsMemory() const override { return owns_memory_; }
|
||||
|
||||
private:
|
||||
const size_t len_;
|
||||
void (*const deallocator_)(void* data, size_t len, void* arg);
|
||||
void* const deallocator_arg_;
|
||||
bool owns_memory_;
|
||||
};
|
||||
|
||||
namespace tensorflow {
|
||||
|
@ -15,13 +15,12 @@ limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/cc/framework/grad_op_registry.h"
|
||||
#include "tensorflow/cc/framework/gradients.h"
|
||||
#include "tensorflow/cc/ops/array_ops_internal.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
|
||||
#include "tensorflow/cc/framework/grad_op_registry.h"
|
||||
#include "tensorflow/cc/framework/gradients.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace ops {
|
||||
namespace {
|
||||
@ -90,15 +89,25 @@ Status QuantizeAndDequantizeGrad(const Scope& scope, const Operation& op,
|
||||
}
|
||||
REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad);
|
||||
|
||||
Status QuantizeAndDequantizeV2Grad(const Scope& scope, const Operation& op,
|
||||
Status QuantizeAndDequantizeV2GradHelper(const Scope& scope,
|
||||
const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
grad_outputs->push_back(Identity(scope, grad_inputs[0]));
|
||||
grad_outputs->push_back(NoGradient());
|
||||
grad_outputs->push_back(NoGradient());
|
||||
Input input = Shape(scope, op.input(0));
|
||||
Input input_min = op.input(1);
|
||||
Input input_max = op.input(2);
|
||||
int64 axis;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
|
||||
auto qdq_v2_grad = QuantizeAndDequantizeV2Grad(
|
||||
scope, grad_inputs[0], input, input_min, input_max,
|
||||
QuantizeAndDequantizeV2Grad::Axis(axis));
|
||||
grad_outputs->push_back(qdq_v2_grad.input_backprop);
|
||||
grad_outputs->push_back(qdq_v2_grad.input_min_backprop);
|
||||
grad_outputs->push_back(qdq_v2_grad.input_max_backprop);
|
||||
return scope.status();
|
||||
}
|
||||
REGISTER_GRADIENT_OP("QuantizeAndDequantizeV2", QuantizeAndDequantizeV2Grad);
|
||||
REGISTER_GRADIENT_OP("QuantizeAndDequantizeV2",
|
||||
QuantizeAndDequantizeV2GradHelper);
|
||||
|
||||
Status QuantizeAndDequantizeV3Grad(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
|
@ -68,6 +68,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/platform:resource_loader",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -21,15 +21,22 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/path.h"
|
||||
#include "tensorflow/core/platform/resource_loader.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
constexpr char kTestDataPbTxt[] =
|
||||
"cc/saved_model/testdata/half_plus_two_pbtxt/00000123";
|
||||
constexpr char kTestDataSharded[] =
|
||||
"cc/saved_model/testdata/half_plus_two/00000123";
|
||||
string TestDataPbTxt() {
|
||||
return io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
|
||||
"half_plus_two_pbtxt", "00000123");
|
||||
}
|
||||
|
||||
string TestDataSharded() {
|
||||
return io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
|
||||
"half_plus_two", "00000123");
|
||||
}
|
||||
|
||||
class ReaderTest : public ::testing::Test {
|
||||
protected:
|
||||
@ -49,8 +56,7 @@ class ReaderTest : public ::testing::Test {
|
||||
TEST_F(ReaderTest, TagMatch) {
|
||||
MetaGraphDef meta_graph_def;
|
||||
|
||||
const string export_dir =
|
||||
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
|
||||
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
|
||||
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
|
||||
&meta_graph_def));
|
||||
CheckMetaGraphDef(meta_graph_def);
|
||||
@ -59,8 +65,7 @@ TEST_F(ReaderTest, TagMatch) {
|
||||
TEST_F(ReaderTest, NoTagMatch) {
|
||||
MetaGraphDef meta_graph_def;
|
||||
|
||||
const string export_dir =
|
||||
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
|
||||
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
|
||||
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"missing-tag"},
|
||||
&meta_graph_def);
|
||||
EXPECT_FALSE(st.ok());
|
||||
@ -73,8 +78,7 @@ TEST_F(ReaderTest, NoTagMatch) {
|
||||
TEST_F(ReaderTest, NoTagMatchMultiple) {
|
||||
MetaGraphDef meta_graph_def;
|
||||
|
||||
const string export_dir =
|
||||
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
|
||||
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
|
||||
Status st = ReadMetaGraphDefFromSavedModel(
|
||||
export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def);
|
||||
EXPECT_FALSE(st.ok());
|
||||
@ -87,8 +91,7 @@ TEST_F(ReaderTest, NoTagMatchMultiple) {
|
||||
TEST_F(ReaderTest, PbtxtFormat) {
|
||||
MetaGraphDef meta_graph_def;
|
||||
|
||||
const string export_dir =
|
||||
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPbTxt);
|
||||
const string export_dir = GetDataDependencyFilepath(TestDataPbTxt());
|
||||
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
|
||||
&meta_graph_def));
|
||||
CheckMetaGraphDef(meta_graph_def);
|
||||
@ -97,8 +100,7 @@ TEST_F(ReaderTest, PbtxtFormat) {
|
||||
TEST_F(ReaderTest, InvalidExportPath) {
|
||||
MetaGraphDef meta_graph_def;
|
||||
|
||||
const string export_dir =
|
||||
io::JoinPath(testing::TensorFlowSrcRoot(), "missing-path");
|
||||
const string export_dir = GetDataDependencyFilepath("missing-path");
|
||||
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
|
||||
&meta_graph_def);
|
||||
EXPECT_FALSE(st.ok());
|
||||
|
@ -84,6 +84,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/platform:resource_loader",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:support", # fixdeps: keep
|
||||
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/aot/codegen.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@ -29,6 +30,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/resource_loader.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -139,23 +141,40 @@ TEST_F(ParseCppClassTest, ParseFail) {
|
||||
|
||||
static void CompareWithGoldenFile(
|
||||
const string& tensorflow_relative_golden_file_name,
|
||||
const string& expected_contents) {
|
||||
const string& expected_contents, bool ignore_cr) {
|
||||
// Get rid of all CR characters, we may be running under windows.
|
||||
string sanitized_expected_contents(expected_contents);
|
||||
if (ignore_cr) {
|
||||
sanitized_expected_contents.erase(
|
||||
std::remove(sanitized_expected_contents.begin(),
|
||||
sanitized_expected_contents.end(), '\r'),
|
||||
sanitized_expected_contents.end());
|
||||
}
|
||||
|
||||
// To update the golden file, flip update_golden to true and run the
|
||||
// following:
|
||||
// bazel test --test_strategy=local \
|
||||
// third_party/tensorflow/compiler/aot:codegen_test
|
||||
const bool update_golden = false;
|
||||
const string golden_file_name = io::JoinPath(
|
||||
testing::TensorFlowSrcRoot(), tensorflow_relative_golden_file_name);
|
||||
string golden_file_name;
|
||||
|
||||
if (update_golden) {
|
||||
golden_file_name = io::JoinPath(testing::TensorFlowSrcRoot(),
|
||||
tensorflow_relative_golden_file_name);
|
||||
TF_EXPECT_OK(
|
||||
WriteStringToFile(Env::Default(), golden_file_name, expected_contents));
|
||||
}
|
||||
|
||||
golden_file_name =
|
||||
GetDataDependencyFilepath(tensorflow_relative_golden_file_name);
|
||||
string golden_file_contents;
|
||||
TF_ASSERT_OK(ReadFileToString(Env::Default(), golden_file_name,
|
||||
&golden_file_contents));
|
||||
if (ignore_cr) {
|
||||
golden_file_contents.erase(std::remove(golden_file_contents.begin(),
|
||||
golden_file_contents.end(), '\r'),
|
||||
golden_file_contents.end());
|
||||
}
|
||||
EXPECT_EQ(golden_file_contents, expected_contents);
|
||||
}
|
||||
|
||||
@ -229,14 +248,18 @@ TEST(CodegenTest, Golden) {
|
||||
// The other fields in metadata_result are tested as part of the generated
|
||||
// header test.
|
||||
|
||||
CompareWithGoldenFile("compiler/aot/codegen_test_o.golden",
|
||||
metadata_result.object_file_data);
|
||||
// This specific golden test checks a binary file. It can potentially run into
|
||||
// issues due to ABIs not being stable, but has not so far.
|
||||
// If we see any ABI issues, we should reconsider this specific test case.
|
||||
CompareWithGoldenFile("tensorflow/compiler/aot/codegen_test_o.golden",
|
||||
metadata_result.object_file_data, false);
|
||||
|
||||
string header;
|
||||
TF_ASSERT_OK(
|
||||
GenerateHeader(opts, config, compile_result, metadata_result, &header));
|
||||
|
||||
CompareWithGoldenFile("compiler/aot/codegen_test_h.golden", header);
|
||||
CompareWithGoldenFile("tensorflow/compiler/aot/codegen_test_h.golden", header,
|
||||
true);
|
||||
}
|
||||
} // namespace
|
||||
} // namespace tfcompile
|
||||
|
@ -1883,6 +1883,8 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
||||
"EmptyTensorList",
|
||||
"ExtractImagePatches",
|
||||
"Igamma",
|
||||
"IgammaGradA",
|
||||
"RandomGammaGrad",
|
||||
"Igammac",
|
||||
"FFT",
|
||||
"FFT2D",
|
||||
@ -1909,7 +1911,6 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
||||
"LinSpace",
|
||||
"ListDiff",
|
||||
"LogMatrixDeterminant",
|
||||
"LowerBound",
|
||||
"MatMul",
|
||||
"MatrixBandPart",
|
||||
"MatrixDiag",
|
||||
@ -2036,7 +2037,6 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
||||
"TensorScatterUpdate",
|
||||
"TridiagonalSolve",
|
||||
"TruncatedNormal",
|
||||
"UpperBound",
|
||||
"UnsortedSegmentMax",
|
||||
"UnsortedSegmentMin",
|
||||
"UnsortedSegmentProd",
|
||||
|
@ -20,15 +20,17 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
bool XlaKernelCreator::CanCreateKernel(const FunctionLibraryRuntime& flr,
|
||||
const NodeDef& node_def) const {
|
||||
return CanCreateXlaKernel(node_def);
|
||||
bool XlaKernelCreator::CanCreateKernel(
|
||||
const FunctionLibraryRuntime& flr,
|
||||
const std::shared_ptr<const NodeProperties>& props) const {
|
||||
return CanCreateXlaKernel(props->node_def);
|
||||
}
|
||||
|
||||
Status XlaKernelCreator::CreateKernel(FunctionLibraryRuntime* flr,
|
||||
const NodeDef& node_def,
|
||||
Status XlaKernelCreator::CreateKernel(
|
||||
FunctionLibraryRuntime* flr,
|
||||
const std::shared_ptr<const NodeProperties>& props,
|
||||
std::unique_ptr<OpKernel>* kernel) const {
|
||||
return CreateXlaKernel(flr, node_def, kernel);
|
||||
return CreateXlaKernel(flr, props->node_def, kernel);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -29,11 +29,13 @@ class XlaKernelCreator : public CustomKernelCreator {
|
||||
// Given a NodeDef 'node_def' and the function library runtime 'flr', returns
|
||||
// true if 'node_def' is a call to a compilable function defined in 'flr',
|
||||
// with the kXlaCompileAttr set.
|
||||
bool CanCreateKernel(const FunctionLibraryRuntime& flr,
|
||||
const NodeDef& node_def) const override;
|
||||
bool CanCreateKernel(
|
||||
const FunctionLibraryRuntime& flr,
|
||||
const std::shared_ptr<const NodeProperties>& props) const override;
|
||||
|
||||
// Given a supported NodeDef, returns a XlaLaunchOp that computes the node.
|
||||
Status CreateKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
|
||||
Status CreateKernel(FunctionLibraryRuntime* flr,
|
||||
const std::shared_ptr<const NodeProperties>& props,
|
||||
std::unique_ptr<OpKernel>* kernel) const override;
|
||||
};
|
||||
|
||||
|
@ -30,10 +30,12 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
NodeDef ToNodeDef(const string& text) {
|
||||
std::shared_ptr<NodeProperties> ToNodeProperties(const string& text) {
|
||||
NodeDef node_def;
|
||||
DataTypeVector dummy;
|
||||
EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
|
||||
return node_def;
|
||||
return std::make_shared<NodeProperties>(nullptr, std::move(node_def), dummy,
|
||||
dummy);
|
||||
}
|
||||
|
||||
// Create a FunctionDef that takes one resource and one regular param
|
||||
@ -98,11 +100,11 @@ TEST_F(XlaKernelCreatorTest, OneFloatOneResourceArgument) {
|
||||
(*fdef.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
|
||||
Init({fdef});
|
||||
XlaKernelCreator xla_kernel_creator;
|
||||
NodeDef callsite =
|
||||
ToNodeDef(R"pb(
|
||||
auto callsite =
|
||||
ToNodeProperties(R"pb(
|
||||
name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b'
|
||||
)pb");
|
||||
(*callsite.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
|
||||
(*(callsite->node_def.mutable_attr()))["_XlaMustCompile"] = BoolAttr(true);
|
||||
|
||||
// Note: need to set attribute on the created node.
|
||||
Status status = xla_kernel_creator.CreateKernel(flr_, callsite, &kernel_);
|
||||
@ -127,7 +129,8 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrNotSet) {
|
||||
Init({fdef});
|
||||
XlaKernelCreator xla_kernel_creator;
|
||||
|
||||
Status status = xla_kernel_creator.CreateKernel(flr_, ToNodeDef(R"proto(
|
||||
Status status =
|
||||
xla_kernel_creator.CreateKernel(flr_, ToNodeProperties(R"proto(
|
||||
name: 'XTimesY'
|
||||
op: 'XTimesY'
|
||||
input: 'a'
|
||||
@ -143,7 +146,8 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrIsSetToFalse) {
|
||||
Init({fdef});
|
||||
XlaKernelCreator xla_kernel_creator;
|
||||
|
||||
Status status = xla_kernel_creator.CreateKernel(flr_, ToNodeDef(R"proto(
|
||||
Status status =
|
||||
xla_kernel_creator.CreateKernel(flr_, ToNodeProperties(R"proto(
|
||||
name: 'XTimesY'
|
||||
op: 'XTimesY'
|
||||
input: 'a'
|
||||
|
@ -218,11 +218,12 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
|
||||
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
|
||||
Device* dev = flr->device();
|
||||
Status s;
|
||||
OpKernelConstruction construction(
|
||||
DeviceType(dev->device_type()), dev,
|
||||
dev->GetAllocator(AllocatorAttributes()), &node_def,
|
||||
&fbody->fdef.signature(), flr, dev->resource_manager(), fbody->arg_types,
|
||||
input_memory_types, fbody->ret_types, output_memory_types,
|
||||
auto props = std::make_shared<NodeProperties>(
|
||||
&fbody->fdef.signature(), node_def, fbody->arg_types, fbody->ret_types);
|
||||
OpKernelConstruction construction(DeviceType(dev->device_type()), dev,
|
||||
dev->GetAllocator(AllocatorAttributes()),
|
||||
flr, dev->resource_manager(), props,
|
||||
input_memory_types, output_memory_types,
|
||||
flr->graph_def_version(), &s);
|
||||
|
||||
*kernel = absl::make_unique<XlaLocalLaunchBase>(
|
||||
|
3
tensorflow/compiler/mlir/g3doc/README.md
Normal file
3
tensorflow/compiler/mlir/g3doc/README.md
Normal file
@ -0,0 +1,3 @@
|
||||
# TensorFlow MLIR
|
||||
|
||||
These are the docs for: https://www.tensorflow.org/mlir
|
26
tensorflow/compiler/mlir/g3doc/_book.yaml
Normal file
26
tensorflow/compiler/mlir/g3doc/_book.yaml
Normal file
@ -0,0 +1,26 @@
|
||||
upper_tabs:
|
||||
# Tabs left of dropdown menu
|
||||
- include: /_upper_tabs_left.yaml
|
||||
- include: /api_docs/_upper_tabs_api.yaml
|
||||
# Dropdown menu
|
||||
- name: Resources
|
||||
path: /resources
|
||||
is_default: true
|
||||
menu:
|
||||
- include: /resources/_menu_toc.yaml
|
||||
lower_tabs:
|
||||
# Subsite tabs
|
||||
other:
|
||||
- name: Guide
|
||||
contents:
|
||||
- title: Overview
|
||||
path: /mlir/overview
|
||||
- heading: Dialects
|
||||
- title: Overview
|
||||
path: /mlir/dialects
|
||||
- title: TensorFlow
|
||||
path: /mlir/tf_ops
|
||||
- title: TensorFlow Lite
|
||||
path: /mlir/tfl_ops
|
||||
|
||||
- include: /_upper_tabs_right.yaml
|
54
tensorflow/compiler/mlir/g3doc/_index.yaml
Normal file
54
tensorflow/compiler/mlir/g3doc/_index.yaml
Normal file
@ -0,0 +1,54 @@
|
||||
book_path: /mlir/_book.yaml
|
||||
project_path: /mlir/_project.yaml
|
||||
description: <!--no description-->
|
||||
landing_page:
|
||||
custom_css_path: /site-assets/css/style.css
|
||||
rows:
|
||||
- heading: MLIR unifies the infrastructure for high-performance ML models in TensorFlow.
|
||||
items:
|
||||
- description: >
|
||||
The <a href="https://mlir.llvm.org/" class="external">MLIR</a> project defines a common
|
||||
intermediate representation (IR) that unifies the infrastructure required to execute high
|
||||
performance machine learning models in TensorFlow and similar ML frameworks. This project
|
||||
will include the application of HPC techniques, along with integration of
|
||||
search algorithms like reinforcement learning. MLIR aims to reduce the
|
||||
cost to bring up new hardware, and improve usability for existing
|
||||
TensorFlow users.
|
||||
|
||||
- code_block: |
|
||||
<pre class = "prettyprint">
|
||||
// Syntactically similar to LLVM:
|
||||
func @testFunction(%arg0: i32) {
|
||||
%x = call @thingToCall(%arg0) : (i32) -> i32
|
||||
br ^bb1
|
||||
^bb1:
|
||||
%y = addi %x, %x : i32
|
||||
return %y : i32
|
||||
}
|
||||
</pre>
|
||||
|
||||
- classname: devsite-landing-row-cards
|
||||
items:
|
||||
- heading: "Multi-Level Intermediate Representation for Compiler Infrastructure"
|
||||
youtube_id: qzljG6DKgic
|
||||
buttons:
|
||||
- label: Watch the video
|
||||
path: https://www.youtube.com/watch?v=qzljG6DKgic
|
||||
- heading: "A new intermediate representation and compiler framework"
|
||||
image_path: /resources/images/tf-logo-card-16x9.png
|
||||
path: https://blog.tensorflow.org/2019/04/mlir-new-intermediate-representation.html
|
||||
buttons:
|
||||
- label: Read on TensorFlow blog
|
||||
path: https://blog.tensorflow.org/2019/04/mlir-new-intermediate-representation.html
|
||||
- heading: MLIR on GitHub
|
||||
image_path: /resources/images/github-card-16x9.png
|
||||
path: https://github.com/llvm/llvm-project/tree/master/mlir
|
||||
buttons:
|
||||
- label: View on GitHub
|
||||
path: https://github.com/llvm/llvm-project/tree/master/mlir
|
||||
- heading: TensorFlow MLIR on GitHub
|
||||
image_path: /resources/images/github-card-16x9.png
|
||||
path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir
|
||||
buttons:
|
||||
- label: View on GitHub
|
||||
path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir
|
37
tensorflow/compiler/mlir/g3doc/dialects.md
Normal file
37
tensorflow/compiler/mlir/g3doc/dialects.md
Normal file
@ -0,0 +1,37 @@
|
||||
# MLIR dialects
|
||||
|
||||
## Overview
|
||||
|
||||
|
||||
To separate different hardware and software targets, MLIR has “dialects”,
|
||||
including:
|
||||
|
||||
* TensorFlow IR, which represents all things possible in TensorFlow graphs.
|
||||
* XLA HLO IR, which is designed to take advantage of XLA’s compilation
|
||||
abilities (with output to, among other things, TPUs).
|
||||
* An experimental affine dialect, which focuses on
|
||||
[polyhedral representations](https://en.wikipedia.org/wiki/Polytope_model)
|
||||
and optimizations.
|
||||
* LLVM IR, which has a 1:1 mapping between it and LLVM’s own representation,
|
||||
allowing MLIR to emit GPU and CPU code through LLVM.
|
||||
* TensorFlow Lite, which will translate to running code on mobile platforms.
|
||||
|
||||
Each dialect consists of a set of defined operations which have invariants
|
||||
placed on them, like: “This is a binary operator, and the inputs and outputs
|
||||
have the same types.”
|
||||
|
||||
## Adding to MLIR
|
||||
|
||||
MLIR has no fixed/built-in list of globally known operations (no “intrinsics”).
|
||||
Dialects can define entirely custom types, which is how MLIR can model things
|
||||
like the LLVM IR type system (which has first class aggregates), domain
|
||||
abstractions important for ML-optimized accelerators like quantized types, and
|
||||
even the Swift or Clang type systems (which are built around Swift/Clang
|
||||
declaration nodes) in the future.
|
||||
|
||||
If you want to connect a new low-level compiler, you would create a new dialect
|
||||
and the lowerings between the TensorFlow Graph dialect and your dialect.
|
||||
This smooths the path for hardware and compiler makers. You can even target
|
||||
dialects at different levels in the same model; the higher-level optimizers
|
||||
will respect the unfamiliar parts of the IR and wait for a lower level to handle
|
||||
it.
|
1
tensorflow/compiler/mlir/g3doc/images/mlir-infra.svg
Normal file
1
tensorflow/compiler/mlir/g3doc/images/mlir-infra.svg
Normal file
File diff suppressed because one or more lines are too long
After Width: | Height: | Size: 148 KiB |
36
tensorflow/compiler/mlir/g3doc/overview.md
Normal file
36
tensorflow/compiler/mlir/g3doc/overview.md
Normal file
@ -0,0 +1,36 @@
|
||||
# MLIR
|
||||
|
||||
## Overview
|
||||
|
||||
MLIR, or Multi-Level Intermediate Representation, is a representation format
|
||||
and library of compiler utilities that sits between the model representation
|
||||
and low-level compilers/executors that generate hardware-specific code.
|
||||
|
||||
MLIR is, at its heart, a flexible infrastructure for modern optimizing
|
||||
compilers. This means it consists of a specification for intermediate
|
||||
representations (IR) and a code toolkit to perform transformations on that
|
||||
representation. (In compiler parlance, as you move from higher-level
|
||||
representations to lower-level representations, these transformations can be
|
||||
called “lowerings”)
|
||||
|
||||
MLIR is highly influenced by [LLVM](https://llvm.org/) and unabashedly reuses
|
||||
many great ideas from it. It has a flexible type system, and allows
|
||||
representing, analyzing and transforming graphs combining multiple levels of
|
||||
abstraction in the same compilation unit. These abstractions include TensorFlow
|
||||
operations, nested polyhedral loop regions, and even LLVM instructions and fixed
|
||||
hardware operations and types.
|
||||
|
||||
We expect MLIR to be of interest to many groups, including:
|
||||
|
||||
* Compiler researchers and implementers looking to optimize performance and
|
||||
memory consumption of machine learning models
|
||||
* Hardware makers looking for a way to connect their hardware to TensorFlow,
|
||||
such as TPUs, portable neural hardware in phones, and other custom ASICs
|
||||
* People writing language bindings that want to take advantage of optimizing
|
||||
compilers and hardware acceleration.
|
||||
|
||||
The TensorFlow ecosystem contains a number of compilers and optimizers that
|
||||
operate at multiple levels of the software and hardware stack. We expect the
|
||||
gradual adoption of MLIR to simplify every aspect of this stack.
|
||||
|
||||
<img alt="MLIR overview diagram" src="./images/mlir-infra.svg"/>
|
@ -208,6 +208,7 @@ cc_library(
|
||||
"ir/tfl_ops.h.inc",
|
||||
"ir/tfl_ops_interface.cc.inc",
|
||||
"ir/tfl_ops_interface.h.inc",
|
||||
"runtime_verifiers.inc",
|
||||
"utils/attribute_utils.cc",
|
||||
],
|
||||
hdrs = [
|
||||
@ -303,12 +304,14 @@ cc_library(
|
||||
"transforms/optimize_functional_ops.cc",
|
||||
"transforms/prepare_composite_functions_tf.cc",
|
||||
"transforms/prepare_tf.cc",
|
||||
"transforms/runtime_type_verify.cc",
|
||||
"transforms/split_merged_operands.cc",
|
||||
"transforms/trim_functions_tf.cc",
|
||||
"transforms/unroll_batch_matmul.cc",
|
||||
"transforms/while_loop_outline.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"ir/tfl_ops_interface.h.inc",
|
||||
"transforms/dilated_conv.h",
|
||||
"transforms/passes.h",
|
||||
"transforms/unroll_batch_matmul.h",
|
||||
@ -461,9 +464,9 @@ cc_library(
|
||||
)
|
||||
|
||||
tf_native_cc_binary(
|
||||
name = "operator-converter-gen",
|
||||
name = "converter-gen",
|
||||
srcs = [
|
||||
"operator_converter_gen.cc",
|
||||
"converter_gen.cc",
|
||||
],
|
||||
deps = [
|
||||
"@llvm-project//llvm:support",
|
||||
@ -473,14 +476,18 @@ tf_native_cc_binary(
|
||||
)
|
||||
|
||||
gentbl(
|
||||
name = "operator_converter_inc",
|
||||
name = "converter_inc",
|
||||
tbl_outs = [
|
||||
(
|
||||
"", # This driver has no options.
|
||||
"--gen-operator-converters",
|
||||
"operator_converters.inc",
|
||||
),
|
||||
(
|
||||
"--gen-runtime-verifiers",
|
||||
"runtime_verifiers.inc",
|
||||
),
|
||||
],
|
||||
tblgen = ":operator-converter-gen",
|
||||
tblgen = ":converter-gen",
|
||||
td_file = "ir/tfl_ops.td",
|
||||
td_srcs = [
|
||||
":tensorflow_lite_ops_td_files",
|
||||
@ -582,7 +589,6 @@ cc_library(
|
||||
"@com_google_absl//absl/strings",
|
||||
"@flatbuffers",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
@ -645,12 +651,14 @@ tf_cc_binary(
|
||||
"//tensorflow/compiler/mlir:init_mlir",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
@ -694,7 +702,6 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
@ -727,7 +734,6 @@ cc_library(
|
||||
"//tensorflow/lite/tools/optimize:quantize_weights",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Parser",
|
||||
|
@ -28,6 +28,9 @@ limitations under the License.
|
||||
#include "llvm/TableGen/Record.h"
|
||||
#include "llvm/TableGen/TableGenBackend.h"
|
||||
#include "mlir/TableGen/Attribute.h" // TF:llvm-project
|
||||
#include "mlir/TableGen/Format.h" // TF:llvm-project
|
||||
#include "mlir/TableGen/Operator.h" // TF:llvm-project
|
||||
#include "mlir/TableGen/Predicate.h" // TF:llvm-project
|
||||
|
||||
using llvm::DefInit;
|
||||
using llvm::dyn_cast;
|
||||
@ -41,6 +44,19 @@ using llvm::SmallVector;
|
||||
using llvm::StringInit;
|
||||
using llvm::StringRef;
|
||||
|
||||
enum ActionType {
|
||||
OpConv,
|
||||
RuntimeVerify,
|
||||
};
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
llvm::cl::opt<ActionType> action(
|
||||
llvm::cl::desc("Action to perform:"),
|
||||
llvm::cl::values(clEnumValN(OpConv, "gen-operator-converters",
|
||||
"Generate operator converters"),
|
||||
clEnumValN(RuntimeVerify, "gen-runtime-verifiers",
|
||||
"Generate TFLite runtime verifiers")));
|
||||
|
||||
// Returns the associated option name for the given op definition.
|
||||
static inline std::string GetOperatorOptionName(const Record &def) {
|
||||
assert(def.getName().startswith("TFL_") && "unexpected op prefix");
|
||||
@ -342,8 +358,101 @@ static bool OperatorWritersMain(raw_ostream &os, RecordKeeper &records) {
|
||||
return false;
|
||||
}
|
||||
|
||||
static void GenOperandResultVerifier(raw_ostream &os,
|
||||
llvm::ArrayRef<llvm::Init *> values,
|
||||
StringRef valueKind) {
|
||||
mlir::tblgen::FmtContext fctx;
|
||||
|
||||
bool first = true;
|
||||
for (auto static_value : llvm::enumerate(values)) {
|
||||
auto *definit = llvm::cast<llvm::DefInit>(static_value.value());
|
||||
auto *val = definit->getDef()->getValue("tflRuntimeTypePredicate");
|
||||
if (!val) continue;
|
||||
|
||||
// Create code block on first type to verify.
|
||||
if (first) {
|
||||
os << " {\n";
|
||||
os << " unsigned index = " << static_value.index() << ";\n";
|
||||
first = false;
|
||||
}
|
||||
|
||||
mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue()));
|
||||
auto desc =
|
||||
definit->getDef()->getValueAsString("tflRuntimeTypeDescription");
|
||||
|
||||
// Emit a loop to check all the dynamic values in the pack.
|
||||
os << formatv(" for (Value v : top.getODS{0}{1}s({2})) {{\n",
|
||||
// Capitalize the first letter to match the function name
|
||||
valueKind.substr(0, 1).upper(), valueKind.substr(1),
|
||||
static_value.index());
|
||||
|
||||
os << " (void)v;\n"
|
||||
<< " if (!("
|
||||
<< tgfmt(pred.getCondition(), &fctx.withSelf("v.getType()")) << ")) {\n"
|
||||
<< formatv(
|
||||
" return op->emitOpError(\"{0} #\") << index "
|
||||
"<< \" must be {1}, but got \" << v.getType();\n",
|
||||
valueKind, desc)
|
||||
<< " }\n" // if
|
||||
<< " ++index;\n"
|
||||
<< " }\n"; // for
|
||||
}
|
||||
|
||||
// Emit closing brace if needed.
|
||||
if (!first) os << " }\n";
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
|
||||
emitSourceFileHeader("MLIR TFLite Runtime Verifiers", os);
|
||||
|
||||
// Retrieve all the definitions derived from TFL_Op and sort by record name.
|
||||
std::vector<Record *> defs = records.getAllDerivedDefinitions("Op");
|
||||
llvm::sort(defs, LessRecord());
|
||||
|
||||
// Iterate through all the ops defined.
|
||||
for (const auto *def : defs) {
|
||||
mlir::tblgen::Operator op(*def);
|
||||
if (!op.getTrait("TflRuntimeVerifyOpInterface::Trait")) continue;
|
||||
|
||||
mlir::tblgen::FmtContext verify_ctx;
|
||||
os << "::mlir::LogicalResult " << op.getCppClassName()
|
||||
<< "::VerifyTflRuntimeTypes(::mlir::Operation *op) {\n";
|
||||
os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
|
||||
verify_ctx.withOp("top");
|
||||
|
||||
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
|
||||
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
|
||||
auto &value = op.getOperand(i);
|
||||
// Skip from from first variadic operands for now. Else getOperand index
|
||||
// used below doesn't match.
|
||||
if (value.isVariadic()) break;
|
||||
if (!value.name.empty())
|
||||
verify_ctx.addSubst(value.name, formatv("op->getOperand({0})", i));
|
||||
}
|
||||
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
|
||||
auto &value = op.getResult(i);
|
||||
// Skip from from first variadic results for now. Else getResult index
|
||||
// used below doesn't match.
|
||||
if (value.isVariadic()) break;
|
||||
if (!value.name.empty())
|
||||
verify_ctx.addSubst(value.name, formatv("op->getResult({0})", i));
|
||||
}
|
||||
}
|
||||
GenOperandResultVerifier(os, def->getValueAsDag("arguments")->getArgs(),
|
||||
"operand");
|
||||
GenOperandResultVerifier(os, def->getValueAsDag("results")->getArgs(),
|
||||
"result");
|
||||
os << " return mlir::success();\n}\n";
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
llvm::InitLLVM y(argc, argv);
|
||||
llvm::cl::ParseCommandLineOptions(argc, argv);
|
||||
if (action == ActionType::OpConv)
|
||||
return TableGenMain(argv[0], &OperatorWritersMain);
|
||||
return TableGenMain(argv[0], &RuntimeVerifierWriterMain);
|
||||
}
|
@ -71,4 +71,23 @@ def TFL_SparseOp : OpInterface<"SparseOpInterface"> {
|
||||
];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TFL runtime type verification of operand/result types.
|
||||
|
||||
def TFL_RuntimeVerification : OpInterface<"TflRuntimeVerifyOpInterface"> {
|
||||
let description = [{
|
||||
Interface to verify TFLite runtime op verification.
|
||||
|
||||
This verifies that the converted TFLite ops has operand/result type
|
||||
supported by the TFLite runtime.
|
||||
}];
|
||||
|
||||
let methods = [
|
||||
StaticInterfaceMethod<
|
||||
[{Returns whether the op's operands/results are supported by runtime.}],
|
||||
"LogicalResult", "VerifyTflRuntimeTypes", (ins "Operation*":$op)
|
||||
>,
|
||||
];
|
||||
}
|
||||
|
||||
#endif // TFL_OP_INTERFACES
|
||||
|
@ -723,12 +723,11 @@ static LogicalResult Verify(PackOp op) {
|
||||
}
|
||||
|
||||
// Make sure all inputs have the same shape and element type.
|
||||
// TODO(rahulsp): Simplify once b/135032064 is fixed.
|
||||
for (Value operand : op.getOperands()) {
|
||||
auto other_type = operand.getType().cast<ShapedType>();
|
||||
if (input_type != other_type)
|
||||
// TODO(b/135032063): Simplify once fixed.
|
||||
for (Type operand_type : op.getOperandTypes()) {
|
||||
if (failed(mlir::verifyCompatibleShape(input_type, operand_type)))
|
||||
return op.emitOpError("operands should be of the same type. got ")
|
||||
<< input_type << ", " << other_type;
|
||||
<< input_type << ", " << operand_type;
|
||||
}
|
||||
|
||||
return success();
|
||||
@ -1872,6 +1871,7 @@ LogicalResult WhileOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) {
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.cc.inc"
|
||||
#define GET_OP_CLASSES
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
|
||||
#include "tensorflow/compiler/mlir/lite/runtime_verifiers.inc"
|
||||
|
||||
Operation *TensorFlowLiteDialect::materializeConstant(OpBuilder &builder,
|
||||
Attribute value,
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -282,6 +282,7 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
if (pass_config.legalize_tf_while) {
|
||||
pm.addPass(mlir::TFL::CreateWhileOutlinePass());
|
||||
}
|
||||
pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
|
||||
|
||||
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
|
||||
|
@ -150,7 +150,8 @@ struct QuantizationPattern : public RewritePattern {
|
||||
|
||||
explicit QuantizationPattern(MLIRContext* context, bool enable_verify,
|
||||
float error_tolerance, bool single_layer_verify)
|
||||
: RewritePattern(DQ::getOperationName(), 1, context),
|
||||
// Set the score to a large number so it is always preferred.
|
||||
: RewritePattern(DQ::getOperationName(), 300, context),
|
||||
enable_verify(enable_verify),
|
||||
error_tolerance(error_tolerance),
|
||||
single_layer_verify(single_layer_verify) {}
|
||||
|
@ -178,15 +178,20 @@ func @inputsAfterOutputs() {
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@+1 {{Found malformed ophint regions: missing inputs or outputs.}}
|
||||
module {
|
||||
func @extractOphintFailure() {
|
||||
func @extractOphintSame() {
|
||||
%0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x16x1xf32>
|
||||
%1 = call @AnotherFunc(%0) : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
|
||||
%2 = "tf.Sigmoid"(%1) {T = "tfdtype$DT_FLOAT", name = "Sigmoid"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
|
||||
%3 = "tf.Mul"(%2, %1) {T = "tfdtype$DT_FLOAT", name = "mul"} : (tensor<1x16x16x1xf32>, tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
|
||||
%4 = "tf.Identity"(%3) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "d4b1eb00b81211e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation-d4b1eb00b81211e99426dc4a3e957995-0-None-None"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
|
||||
return
|
||||
|
||||
// CHECK: [[VAL_0:%.*]] = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x16x1xf32>
|
||||
// CHECK: [[VAL_1:%.*]] = call @AnotherFunc([[VAL_0]]) : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
|
||||
// CHECK: [[VAL_2:%.*]] = "tf.Sigmoid"([[VAL_1]]) {T = "tfdtype$DT_FLOAT", name = "Sigmoid"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
|
||||
// CHECK: [[VAL_3:%.*]] = "tf.Mul"([[VAL_2]], [[VAL_1]]) {T = "tfdtype$DT_FLOAT", name = "mul"} : (tensor<1x16x16x1xf32>, tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
|
||||
// CHECK: [[VAL_4:%.*]] = "tf.Identity"([[VAL_3]]) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "d4b1eb00b81211e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation-d4b1eb00b81211e99426dc4a3e957995-0-None-None"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
|
||||
}
|
||||
|
||||
func @AnotherFunc(%arg0: tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> {
|
||||
|
@ -739,6 +739,15 @@ func @matrix_diag_v3(%arg0: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
|
||||
// CHECK: return [[VAL_6]] : tensor<8x16x16xf32>
|
||||
}
|
||||
|
||||
func @matrix_set_diag(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
|
||||
%0 = "tf.MatrixSetDiag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
|
||||
return %0 : tensor<3x3xi32>
|
||||
|
||||
// CHECK-LABEL: func @matrix_set_diag(
|
||||
// CHECK: [[VAL_0:%.*]] = "tfl.matrix_set_diag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
|
||||
// CHECK: return [[VAL_0]]
|
||||
}
|
||||
|
||||
func @maximum(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> {
|
||||
%0 = "tf.Maximum"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
return %0 : tensor<8x16xf32>
|
||||
@ -1364,3 +1373,83 @@ func @reciprocal_i64(%arg0: tensor<8xi64>) -> tensor<8xi64> {
|
||||
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<1xi64>, tensor<8xi64>) -> tensor<8xi64>
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
func @random_uniform() -> tensor<2x5xf32> {
|
||||
%0 = "tf.Const"() { value = dense<[2, 5]> : tensor<2xi32> } : () -> tensor<2xi32>
|
||||
%1 = "tf.RandomUniform"(%0) { seed = 1, seed2 = 0} : (tensor<2xi32>) -> tensor<2x5xf32>
|
||||
return %1 : tensor<2x5xf32>
|
||||
|
||||
// CHECK-LABEL: random_uniform
|
||||
// CHECK: %[[CST:.*]] = constant dense
|
||||
// CHECK: return %[[CST:.*]] : tensor<2x5xf32>
|
||||
}
|
||||
|
||||
func @random_uniform_no_fold(%arg0: tensor<2xi32>) -> tensor<2x5xf32> {
|
||||
%1 = "tf.RandomUniform"(%arg0) { seed = 0, seed2 = 0} : (tensor<2xi32>) -> tensor<2x5xf32>
|
||||
return %1 : tensor<2x5xf32>
|
||||
|
||||
// CHECK-LABEL: random_uniform_no_fold
|
||||
// CHECK: %[[RANDOM:.*]] = "tf.RandomUniform"
|
||||
}
|
||||
|
||||
func @random_uniform_no_fold2(%arg0: tensor<2xi32>) -> tensor<*xf32> {
|
||||
%1 = "tf.RandomUniform"(%arg0) { seed = 1, seed2 = 2} : (tensor<2xi32>) -> tensor<*xf32>
|
||||
return %1 : tensor<*xf32>
|
||||
|
||||
// CHECK-LABEL: random_uniform_no_fold2
|
||||
// CHECK: %[[RANDOM:.*]] = "tf.RandomUniform"
|
||||
}
|
||||
|
||||
func @random_uniform_no_fold3() -> tensor<2x5xf64> {
|
||||
%0 = "tf.Const"() { value = dense<[2, 5]> : tensor<2xi32> } : () -> tensor<2xi32>
|
||||
%1 = "tf.RandomUniform"(%0) { seed = 1, seed2 = 0} : (tensor<2xi32>) -> tensor<2x5xf64>
|
||||
return %1 : tensor<2x5xf64>
|
||||
|
||||
// CHECK-LABEL: random_uniform_no_fold3
|
||||
// CHECK: %[[RANDOM:.*]] = "tf.RandomUniform"
|
||||
}
|
||||
|
||||
func @LstmWithoutProjection(%arg: tensor<28x1x28xf32>) -> (tensor<28x1x16xf32>) {
|
||||
%1 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<16x28xf32>} : () -> tensor<16x28xf32>
|
||||
%2 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<16x16xf32>} : () -> tensor<16x16xf32>
|
||||
%3 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<16xf32>} : () -> tensor<16xf32>
|
||||
%4 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<1x16xf32>} : () -> tensor<1x16xf32>
|
||||
%5 = "tf.Const"() {device = "", dtype = f32, value = dense<-1.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
|
||||
%6:3 = "tf.UnidirectionalSequenceLstm"(%arg, %1, %1, %1, %1, %2, %2, %2, %2, %3, %3, %3, %3, %3, %3, %3, %5, %5, %4, %4) {_tflite_input_indices = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 18, 19], device = ""} : (tensor<28x1x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1x16xf32>, tensor<1x16xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<28x1x16xf32>)
|
||||
return %6#2 : tensor<28x1x16xf32>
|
||||
}
|
||||
|
||||
// CHECK: func @LstmWithoutProjection([[VAL_0:%.*]]: tensor<28x1x28xf32>) -> tensor<28x1x16xf32> {
|
||||
// CHECK: [[VAL_1:%.*]] = constant dense<0.000000e+00> : tensor<16x28xf32>
|
||||
// CHECK: [[VAL_2:%.*]] = constant dense<0.000000e+00> : tensor<16x16xf32>
|
||||
// CHECK: [[VAL_3:%.*]] = constant dense<0.000000e+00> : tensor<16xf32>
|
||||
// CHECK: [[VAL_4:%.*]] = constant dense<0.000000e+00> : tensor<1x16xf32>
|
||||
// CHECK: [[VAL_5:%.*]] = constant unit
|
||||
// CHECK: [[VAL_6:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_1]], [[VAL_1]], [[VAL_1]], [[VAL_1]], [[VAL_2]], [[VAL_2]], [[VAL_2]], [[VAL_2]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_5]], [[VAL_5]], [[VAL_4]], [[VAL_4]], [[VAL_5]], [[VAL_5]], [[VAL_5]], [[VAL_5]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<28x1x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, none, none, tensor<1x16xf32>, tensor<1x16xf32>, none, none, none, none) -> tensor<28x1x16xf32>
|
||||
// CHECK: return [[VAL_6]] : tensor<28x1x16xf32>
|
||||
// CHECK: }
|
||||
|
||||
func @LstmWithProjection(%arg: tensor<28x1x16xf32>) -> (tensor<28x1x8xf32>) {
|
||||
%1 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<16x16xf32>} : () -> tensor<16x16xf32>
|
||||
%2 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<16x8xf32>} : () -> tensor<16x8xf32>
|
||||
%3 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<16xf32>} : () -> tensor<16xf32>
|
||||
%4 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<1x16xf32>} : () -> tensor<1x16xf32>
|
||||
%5 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<8x16xf32>} : () -> tensor<8x16xf32>
|
||||
%6 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<1x8xf32>} : () -> tensor<1x8xf32>
|
||||
%7 = "tf.Const"() {device = "", dtype = f32, value = dense<-1.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
|
||||
%8:3 = "tf.UnidirectionalSequenceLstm"(%arg, %1, %1, %1, %1, %2, %2, %2, %2, %7, %7, %7, %3, %3, %3, %3, %5, %7, %6, %4) {_tflite_input_indices = [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 18, 19], device = ""} : (tensor<28x1x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<8x16xf32>, tensor<1xf32>, tensor<1x8xf32>, tensor<1x16xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<28x1x8xf32>)
|
||||
return %8#2 : tensor<28x1x8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @LstmWithProjection(
|
||||
// CHECK-SAME: [[VAL_7:%.*]]: tensor<28x1x16xf32>) -> tensor<28x1x8xf32> {
|
||||
// CHECK: [[VAL_8:%.*]] = constant dense<0.000000e+00> : tensor<16x16xf32>
|
||||
// CHECK: [[VAL_9:%.*]] = constant dense<0.000000e+00> : tensor<16x8xf32>
|
||||
// CHECK: [[VAL_10:%.*]] = constant dense<0.000000e+00> : tensor<16xf32>
|
||||
// CHECK: [[VAL_11:%.*]] = constant dense<0.000000e+00> : tensor<1x16xf32>
|
||||
// CHECK: [[VAL_12:%.*]] = constant dense<0.000000e+00> : tensor<8x16xf32>
|
||||
// CHECK: [[VAL_13:%.*]] = constant dense<0.000000e+00> : tensor<1x8xf32>
|
||||
// CHECK: [[VAL_14:%.*]] = constant unit
|
||||
// CHECK: [[VAL_15:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_8]], [[VAL_8]], [[VAL_8]], [[VAL_8]], [[VAL_9]], [[VAL_9]], [[VAL_9]], [[VAL_9]], [[VAL_14]], [[VAL_14]], [[VAL_14]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_12]], [[VAL_14]], [[VAL_13]], [[VAL_11]], [[VAL_14]], [[VAL_14]], [[VAL_14]], [[VAL_14]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<28x1x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, none, none, none, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<8x16xf32>, none, tensor<1x8xf32>, tensor<1x16xf32>, none, none, none, none) -> tensor<28x1x8xf32>
|
||||
// CHECK: return [[VAL_15]] : tensor<28x1x8xf32>
|
||||
// CHECK: }
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: tf-opt -split-input-file -verify-diagnostics %s | FileCheck %s --dump-input-on-failure
|
||||
// RUN: tf-opt -split-input-file -verify-diagnostics -tfl-runtime-verify %s | FileCheck %s --dump-input-on-failure
|
||||
|
||||
// Unary math ops
|
||||
// -----
|
||||
@ -878,6 +878,14 @@ func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
|
||||
|
||||
// -----
|
||||
|
||||
func @packUnranked(%arg0: tensor<2xi32>, %arg1: tensor<*xi32>) -> tensor<2x2xi32> {
|
||||
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32}
|
||||
%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2xi32>, tensor<*xi32>) -> tensor<2x2xi32>
|
||||
return %0 : tensor<2x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @packInputRank(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<1x4x2xi32> {
|
||||
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 2 : i32, values_count = 2 : i32}
|
||||
%0 = "tfl.pack"(%arg0, %arg1) {axis = 2 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<1x4x2xi32>
|
||||
|
@ -511,3 +511,34 @@ func @PadStridedSliceNewAxisMask2(%arg0: tensor<4x64x64x1xf32>) -> tensor<1x4x64
|
||||
%1 = "tf.StridedSlice"(%0, %cst, %cst, %cst_0) {Index = i32, T = f32, _output_shapes = ["tfshape$dim { size: 1 } dim { size: 4 } dim { size: 64 } dim { size: 64 }"], begin_mask = 6 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 1 : i64, shrink_axis_mask = 0 : i64} : (tensor<4x64x64xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x4x64x64xf32>
|
||||
return %1 : tensor<1x4x64x64xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @MatrixSetDiagV2Conversion
|
||||
func @MatrixSetDiagV2Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
|
||||
%cst = constant dense<0> : tensor<i32>
|
||||
%0 = "tf.MatrixSetDiagV2"(%arg0, %arg1, %cst) : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
|
||||
return %0 : tensor<3x3xi32>
|
||||
|
||||
// CHECK: %[[RES:.*]] = "tf.MatrixSetDiag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
|
||||
// CHECK: return %[[RES]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @MatrixSetDiagV2NonZeroK
|
||||
func @MatrixSetDiagV2NonZeroK(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
|
||||
%cst = constant dense<1> : tensor<i32>
|
||||
%0 = "tf.MatrixSetDiagV2"(%arg0, %arg1, %cst) : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
|
||||
return %0 : tensor<3x3xi32>
|
||||
|
||||
// CHECK: %[[CST:.*]] = constant dense<1> : tensor<i32>
|
||||
// CHECK: %[[RES:.*]] = "tf.MatrixSetDiagV2"(%arg0, %arg1, %[[CST]]) : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
|
||||
// CHECK: return %[[RES]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @MatrixSetDiagV3Conversion
|
||||
func @MatrixSetDiagV3Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
|
||||
%cst = constant dense<0> : tensor<i32>
|
||||
%0 = "tf.MatrixSetDiagV3"(%arg0, %arg1, %cst) : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
|
||||
return %0 : tensor<3x3xi32>
|
||||
|
||||
// CHECK: %[[RES:.*]] = "tf.MatrixSetDiag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
|
||||
// CHECK: return %[[RES]]
|
||||
}
|
||||
|
@ -2,39 +2,44 @@
|
||||
// RUN: tf-opt %s -tfl-prepare-quantize -tfl-quantize -tfl-numeric-verify | FileCheck --check-prefix=DEBUG %s
|
||||
|
||||
// CHECK-LABEL: QuantizeFloatConst
|
||||
func @QuantizeFloatConst() -> tensor<f32> {
|
||||
func @QuantizeFloatConst() -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>> {
|
||||
%0 = constant dense<-0.1> : tensor<2x2xf32>
|
||||
%1 = "tfl.quantize"(%0) {qtype = tensor<!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
|
||||
%2 = "tfl.dequantize"(%1) : (tensor<!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>) -> tensor<f32>
|
||||
return %2 : tensor<f32>
|
||||
%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
|
||||
return %1 : tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
|
||||
|
||||
// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<0> : tensor<2x2xi8>}
|
||||
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[cst]])
|
||||
// CHECK: return %[[dq]] : tensor<f32>
|
||||
// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<0> : tensor<2x2xi8>}
|
||||
// CHECK: return %[[cst]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: QuantizeDenseFloatConst
|
||||
func @QuantizeDenseFloatConst() -> tensor<2x2xf32> {
|
||||
func @QuantizeDenseFloatConst() -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>> {
|
||||
%0 = constant dense<[[-0.1, 1.0], [1.0, 3.0]]> : tensor<2x2xf32>
|
||||
%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
|
||||
%2 = "tfl.dequantize"(%1) : (tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>) -> tensor<2x2xf32>
|
||||
return %2 : tensor<2x2xf32>
|
||||
return %1 : tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
|
||||
|
||||
// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<{{\[\[}}0, -1], {{\[}}-1, -1]]> : tensor<2x2xi8>}
|
||||
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[cst]])
|
||||
// CHECK: return %[[dq]] : tensor<2x2xf32>
|
||||
// CHECK: return %[[cst]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: QuantizeSplatFloatConst
|
||||
func @QuantizeSplatFloatConst() -> tensor<2x2xf32> {
|
||||
func @QuantizeSplatFloatConst() -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>> {
|
||||
%0 = constant dense<3.0> : tensor<2x2xf32>
|
||||
%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
|
||||
return %1 : tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
|
||||
|
||||
// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<-1> : tensor<2x2xi8>}
|
||||
// CHECK: return %[[cst]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: NotQuantizeFloatConst
|
||||
func @NotQuantizeFloatConst() -> tensor<2x2xf32> {
|
||||
%0 = constant dense<-0.1> : tensor<2x2xf32>
|
||||
%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
|
||||
%2 = "tfl.dequantize"(%1) : (tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>) -> tensor<2x2xf32>
|
||||
return %2 : tensor<2x2xf32>
|
||||
|
||||
// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<-1> : tensor<2x2xi8>}
|
||||
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[cst]])
|
||||
// CHECK: return %[[dq]] : tensor<2x2xf32>
|
||||
// CHECK: %[[cst:.*]] = constant dense<-1.000000e-01> : tensor<2x2xf32>
|
||||
// CHECK: return %[[cst]] : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: DequantizeAndQuantize
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/Module.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/Support/FileUtilities.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/init_mlir.h"
|
||||
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
|
||||
@ -32,8 +33,10 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h"
|
||||
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
@ -130,12 +133,24 @@ int main(int argc, char **argv) {
|
||||
llvm::SourceMgr source_mgr;
|
||||
mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context);
|
||||
|
||||
StatusOr<mlir::OwningModuleRef> module =
|
||||
tensorflow::LoadFromGraphdefOrMlirSource(
|
||||
StatusOr<mlir::OwningModuleRef> module;
|
||||
|
||||
// TODO(b/147435528): We need to test the e2e behavior once the graph freezing
|
||||
// inside mlir is done.
|
||||
if (import_saved_model || import_saved_model_v1) {
|
||||
if (input_mlir)
|
||||
module = tensorflow::errors::InvalidArgument(
|
||||
"Importing saved model should not have input_mlir set");
|
||||
module = tensorflow::ImportSavedModel(
|
||||
import_saved_model, import_saved_model_v1, input_file_name,
|
||||
saved_model_tags, saved_model_exported_names, &context);
|
||||
} else {
|
||||
module = tensorflow::LoadFromGraphdefOrMlirSource(
|
||||
input_file_name, input_mlir, use_splatted_constant, custom_opdefs,
|
||||
debug_info_file, input_arrays, input_dtypes, input_shapes,
|
||||
output_arrays,
|
||||
/*prune_unused_nodes=*/true, &source_mgr, &context);
|
||||
}
|
||||
|
||||
// If errors occur, the library call in the above already logged the error
|
||||
// message. So we can just return here.
|
||||
@ -182,6 +197,7 @@ int main(int argc, char **argv) {
|
||||
pass_config.inline_functions = inline_functions;
|
||||
|
||||
tensorflow::AddTFToTFLConversionPasses(pass_config, &pm);
|
||||
pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
|
||||
|
||||
std::string result;
|
||||
auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
|
@ -22,6 +22,33 @@ using llvm::cl::opt;
|
||||
opt<std::string> input_file_name(llvm::cl::Positional,
|
||||
llvm::cl::desc("<input file>"),
|
||||
llvm::cl::init("-"));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
opt<bool> import_saved_model(
|
||||
"savedmodel-to-mlir",
|
||||
llvm::cl::desc("Import a saved model to its MLIR representation"),
|
||||
llvm::cl::value_desc("dir"));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
opt<bool> import_saved_model_v1(
|
||||
"savedmodel-v1-to-mlir",
|
||||
llvm::cl::desc("Import a saved model V1 to its MLIR representation"),
|
||||
llvm::cl::value_desc("dir"));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
opt<std::string> saved_model_tags(
|
||||
"tf-savedmodel-tags",
|
||||
llvm::cl::desc("Tags used to indicate which MetaGraphDef to import, "
|
||||
"separated by ','"),
|
||||
llvm::cl::init("serve"));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
opt<std::string> saved_model_exported_names(
|
||||
"tf-savedmodel-exported-names",
|
||||
llvm::cl::desc("Names to export from SavedModel, separated by ','. Empty "
|
||||
"(the default) means export all."),
|
||||
llvm::cl::init(""));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
opt<std::string> output_file_name("o", llvm::cl::desc("<output file>"),
|
||||
llvm::cl::value_desc("filename"),
|
||||
|
@ -39,4 +39,10 @@ extern llvm::cl::opt<bool> inline_functions;
|
||||
extern llvm::cl::list<std::string> custom_opdefs;
|
||||
extern llvm::cl::opt<bool> emit_quant_adaptor_ops;
|
||||
extern llvm::cl::opt<std::string> quant_stats_file_name;
|
||||
|
||||
// Import saved model.
|
||||
extern llvm::cl::opt<bool> import_saved_model;
|
||||
extern llvm::cl::opt<bool> import_saved_model_v1;
|
||||
extern llvm::cl::opt<std::string> saved_model_tags;
|
||||
extern llvm::cl::opt<std::string> saved_model_exported_names;
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_TRANSLATE_CL_H_
|
||||
|
@ -15,6 +15,10 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
|
||||
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Module.h" // TF:llvm-project
|
||||
#include "mlir/Parser.h" // TF:llvm-project
|
||||
@ -155,4 +159,37 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
StatusOr<mlir::OwningModuleRef> ImportSavedModel(
|
||||
bool import_saved_model, bool import_saved_model_v1,
|
||||
const std::string& input_filename, const std::string& saved_model_tags,
|
||||
const std::string& saved_model_exported_names, mlir::MLIRContext* context) {
|
||||
if (import_saved_model) {
|
||||
std::unordered_set<std::string> tags =
|
||||
absl::StrSplit(saved_model_tags, ',');
|
||||
std::vector<std::string> exported_names =
|
||||
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
|
||||
|
||||
auto module = tensorflow::SavedModelToMlirImport(
|
||||
input_filename, tags, absl::Span<std::string>(exported_names), context);
|
||||
if (!module)
|
||||
return tensorflow::errors::InvalidArgument("fail to open input file");
|
||||
|
||||
return module;
|
||||
} else if (import_saved_model_v1) {
|
||||
std::unordered_set<std::string> tags =
|
||||
absl::StrSplit(saved_model_tags, ',');
|
||||
|
||||
auto module =
|
||||
tensorflow::SavedModelV1ToMlirImport(input_filename, tags, context);
|
||||
|
||||
if (!module)
|
||||
return tensorflow::errors::InvalidArgument("fail to open input file");
|
||||
|
||||
return module;
|
||||
} else {
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
"Should be either saved model v1 or v2");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -40,6 +40,12 @@ LoadFromGraphdefOrMlirSource(
|
||||
absl::string_view output_arrays, bool prune_unused_nodes,
|
||||
llvm::SourceMgr* source_mgr, mlir::MLIRContext* context);
|
||||
|
||||
// Load Saved model (either v1 or v2) into MLIR.
|
||||
stream_executor::port::StatusOr<mlir::OwningModuleRef> ImportSavedModel(
|
||||
bool import_saved_model, bool import_saved_model_v1,
|
||||
const std::string& input_filename, const std::string& saved_model_tags,
|
||||
const std::string& saved_model_exported_names, mlir::MLIRContext* context);
|
||||
|
||||
// Taking a MLIR module in TF executor dialect and a set of parameters,
|
||||
// applies a set of passes to convert the module to TF Lite dialect and
|
||||
// serializes the result to a string. Depending on an attribute in the module
|
||||
|
@ -698,11 +698,10 @@ void ExtractOphintPass::runOnModule() {
|
||||
if (ophint_composite_ops.empty()) continue;
|
||||
|
||||
// Verify: Make sure all ophint_composite_ops are valid.
|
||||
// If not valid, we just don't do anything.
|
||||
for (const auto& kv : ophint_composite_ops) {
|
||||
if (failed(kv.getValue().VerifyOphint())) {
|
||||
module.emitError()
|
||||
<< "Found malformed ophint regions: missing inputs or outputs.";
|
||||
return signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -365,3 +365,7 @@ def : Pat<
|
||||
/*padding=*/ $padding,
|
||||
/*stride_h=*/ ExtractI32At<1>:$strides,
|
||||
/*stride_w=*/ ExtractI32At<2>:$strides)>;
|
||||
|
||||
def : Pat<
|
||||
(TF_MatrixSetDiagOp $input, $diagonal),
|
||||
(TFL_MatrixSetDiagOp $input, $diagonal)>;
|
||||
|
@ -49,6 +49,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/random/philox_random.h"
|
||||
#include "tensorflow/core/lib/random/random_distributions.h"
|
||||
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
||||
|
||||
namespace mlir {
|
||||
@ -61,6 +63,9 @@ namespace {
|
||||
using xla::Status;
|
||||
using xla::StatusOr;
|
||||
|
||||
constexpr char kUnidirectionalSequenceLstm[] = "tf.UnidirectionalSequenceLstm";
|
||||
constexpr char kTfLiteInputIndices[] = "_tflite_input_indices";
|
||||
|
||||
// Legalize operations in functions.
|
||||
struct LegalizeTF : public FunctionPass<LegalizeTF> {
|
||||
void runOnFunction() override;
|
||||
@ -114,9 +119,54 @@ DECL_CONVERT_OP(SplitV);
|
||||
DECL_CONVERT_OP(StridedSlice);
|
||||
DECL_CONVERT_OP(Unpack);
|
||||
DECL_CONVERT_OP(Reciprocal);
|
||||
DECL_CONVERT_OP(RandomUniform);
|
||||
|
||||
#undef DECL_CONVERT_OP
|
||||
|
||||
PatternMatchResult ConvertTFRandomUniformOp::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
auto random_uniform_op = cast<TF::RandomUniformOp>(op);
|
||||
if (random_uniform_op.seed() == 0 && random_uniform_op.seed2() == 0) {
|
||||
return matchFailure();
|
||||
}
|
||||
if (!random_uniform_op.dtype().isF32()) {
|
||||
return matchFailure();
|
||||
}
|
||||
typedef tensorflow::random::UniformDistribution<
|
||||
tensorflow::random::PhiloxRandom, float>
|
||||
Distribution;
|
||||
|
||||
tensorflow::random::PhiloxRandom generator(
|
||||
random_uniform_op.seed().getSExtValue(),
|
||||
random_uniform_op.seed2().getSExtValue());
|
||||
Distribution dist;
|
||||
int num_elements = 0;
|
||||
if (auto output_type =
|
||||
random_uniform_op.output().getType().dyn_cast_or_null<ShapedType>()) {
|
||||
if (auto ranked_output = output_type.dyn_cast_or_null<RankedTensorType>()) {
|
||||
if (!ranked_output.hasRank() || ranked_output.getNumDynamicDims() != 0) {
|
||||
return matchFailure();
|
||||
}
|
||||
num_elements = output_type.getNumElements();
|
||||
size_t offset = 0;
|
||||
size_t num_samples = Distribution::kResultElementCount;
|
||||
llvm::SmallVector<float, 32> data;
|
||||
data.resize(num_elements);
|
||||
while (offset < num_elements) {
|
||||
const typename Distribution::ResultType samples = dist(&generator);
|
||||
std::copy(&samples[0],
|
||||
&samples[0] + std::min(num_samples, data.size() - offset),
|
||||
&data[0] + offset);
|
||||
offset += num_samples;
|
||||
}
|
||||
auto output_data = DenseFPElementsAttr::get(output_type, data);
|
||||
rewriter.replaceOpWithNewOp<ConstantOp>(op, output_type, output_data);
|
||||
return matchSuccess();
|
||||
}
|
||||
}
|
||||
return matchFailure();
|
||||
}
|
||||
|
||||
PatternMatchResult ConvertTFConcatOp::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
auto tf_concat_op = cast<TF::ConcatOp>(op);
|
||||
@ -514,6 +564,74 @@ PatternMatchResult ConvertTFReciprocalOp::matchAndRewrite(
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
// Legalize unidirectional sequence lstm.
|
||||
struct LegalizeUnidirectionalSequenceLstm : public RewritePattern {
|
||||
explicit LegalizeUnidirectionalSequenceLstm(MLIRContext* context)
|
||||
: RewritePattern(kUnidirectionalSequenceLstm, 1, context) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation* op,
|
||||
PatternRewriter& rewriter) const override {
|
||||
auto tflite_indices_attr =
|
||||
op->getAttrOfType<ArrayAttr>(kTfLiteInputIndices);
|
||||
if (!tflite_indices_attr) return matchFailure();
|
||||
|
||||
SmallVector<int64_t, 20> tflite_indices;
|
||||
for (auto index_attr : tflite_indices_attr.getValue()) {
|
||||
IntegerAttr index = index_attr.cast<IntegerAttr>();
|
||||
tflite_indices.push_back(index.getInt());
|
||||
}
|
||||
|
||||
// Optional input placeholder.
|
||||
Value none = rewriter.create<mlir::ConstantOp>(
|
||||
op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());
|
||||
|
||||
// Populate inputs.
|
||||
// UnidirectionalSequenceLstm is expected to have 24 inputs.
|
||||
SmallVector<Value, 24> inputs;
|
||||
int count = 0;
|
||||
int total_ophint_converted_inputs = tflite_indices.size();
|
||||
for (int i = 0; i < 24; ++i) {
|
||||
if (count < total_ophint_converted_inputs && tflite_indices[count] == i) {
|
||||
// specified input.
|
||||
inputs.push_back(op->getOperand(i));
|
||||
count++;
|
||||
} else {
|
||||
// Non specified input.
|
||||
inputs.push_back(none);
|
||||
}
|
||||
}
|
||||
|
||||
// Populate outputs.
|
||||
// UnidirectionalSequenceLstm should only have 1 output, and that is the
|
||||
// original ophint converted node's 3rd output.
|
||||
SmallVector<Type, 4> result_types;
|
||||
result_types.push_back(op->getOpResult(2).getType());
|
||||
|
||||
// Populate attributes.
|
||||
SmallVector<NamedAttribute, 4> attributes;
|
||||
// Activation will always be tanh.
|
||||
attributes.push_back(rewriter.getNamedAttr("fused_activation_function",
|
||||
rewriter.getStringAttr("TANH")));
|
||||
// cell_clip.
|
||||
attributes.push_back(
|
||||
rewriter.getNamedAttr("cell_clip", rewriter.getF32FloatAttr(10.0)));
|
||||
// proj_clip.
|
||||
attributes.push_back(
|
||||
rewriter.getNamedAttr("proj_clip", rewriter.getF32FloatAttr(0.0)));
|
||||
// will always be time_majored.
|
||||
attributes.push_back(
|
||||
rewriter.getNamedAttr("time_major", rewriter.getBoolAttr(true)));
|
||||
|
||||
auto lstm_op = rewriter.create<TFL::UnidirectionalSequenceLSTMOp>(
|
||||
op->getLoc(), result_types, inputs, attributes);
|
||||
|
||||
// Rewire the output.
|
||||
op->getResult(2).replaceAllUsesWith(lstm_op.getResult());
|
||||
op->erase();
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
void LegalizeTF::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
auto* ctx = &getContext();
|
||||
@ -521,11 +639,15 @@ void LegalizeTF::runOnFunction() {
|
||||
|
||||
// Add the generated patterns to the list.
|
||||
populateWithGenerated(ctx, &patterns);
|
||||
patterns.insert<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFMatMulOp,
|
||||
ConvertTFMatrixDiagV2Op, ConvertTFMatrixDiagV3Op,
|
||||
ConvertTFPackOp, ConvertTFReshapeOp, ConvertTFSplitOp,
|
||||
ConvertTFSplitVOp, ConvertTFStridedSliceOp, ConvertTFUnpackOp,
|
||||
ConvertTFAssertOp, ConvertTFReciprocalOp>(ctx);
|
||||
patterns
|
||||
.insert<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFMatMulOp,
|
||||
ConvertTFMatrixDiagV2Op, ConvertTFMatrixDiagV3Op, ConvertTFPackOp,
|
||||
ConvertTFReshapeOp, ConvertTFSplitOp, ConvertTFSplitVOp,
|
||||
ConvertTFStridedSliceOp, ConvertTFUnpackOp, ConvertTFAssertOp,
|
||||
ConvertTFReciprocalOp, ConvertTFRandomUniformOp>(ctx);
|
||||
|
||||
// Ophint python converter converted tf node pattern.
|
||||
patterns.insert<LegalizeUnidirectionalSequenceLstm>(ctx);
|
||||
applyPatternsGreedily(func, patterns);
|
||||
}
|
||||
|
||||
|
@ -199,6 +199,22 @@ def : Pat<
|
||||
(TFL_HardSwishOp $x),
|
||||
[(EqualOperands $x, $y)]>;
|
||||
|
||||
// Matching HardSwish with extra FakeQuant. These FakeQuant ops were due to
|
||||
// incorrect placement in the quantization aware training.
|
||||
// TODO(b/149735743): We should make the placement automatically.
|
||||
def : Pat<
|
||||
(TFL_MulOp (TFL_DequantizeOp (TFL_QuantizeOp
|
||||
(TFL_MulOp
|
||||
$x, (TFL_DequantizeOp (TFL_QuantizeOp (TFL_AddOp
|
||||
$y,
|
||||
(ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
|
||||
TFL_AF_Relu6), $qattr2)),
|
||||
TFL_AF_None), $qattr1)),
|
||||
(ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
|
||||
TFL_AF_None),
|
||||
(TFL_HardSwishOp $x),
|
||||
[(EqualOperands $x, $y)]>;
|
||||
|
||||
// Constraint that the attribute value is less than 'n'
|
||||
class ConstDoubleValueLessThan<string n> : Constraint<
|
||||
CPred<"$0.isa<DenseElementsAttr>() && "
|
||||
|
@ -91,6 +91,9 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFWhilePass();
|
||||
// Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
|
||||
std::unique_ptr<OpPassBase<ModuleOp>> CreateWhileOutlinePass();
|
||||
|
||||
// Verifies runtime supports types used.
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreateRuntimeTypeVerifyPass();
|
||||
|
||||
} // namespace TFL
|
||||
|
||||
} // namespace mlir
|
||||
|
@ -190,3 +190,16 @@ def : Pat<(TF_ReshapeOp:$old_value
|
||||
// parameters of the input, so we can remove the quantization ops.
|
||||
def : Pat<(TF_RankOp (TFL_DequantizeOp (TFL_QuantizeOp $input, $qtype))),
|
||||
(TF_RankOp $input)>;
|
||||
|
||||
// `k` is expected to be 0, other values are not supported currently.
|
||||
def : Pat<(TF_MatrixSetDiagV2Op $input, $diagonal,
|
||||
(ConstantOp ConstantAttr<I32ElementsAttr, "{0}">)),
|
||||
(TF_MatrixSetDiagOp $input, $diagonal)>;
|
||||
|
||||
// `align` attribute can be ignored because we only support converting
|
||||
// `MatrixSetDiagV3` to `MatrixSetDiag` with default `k` inputs.
|
||||
def : Pat<(TF_MatrixSetDiagV3Op $input, $diagonal,
|
||||
(ConstantOp ConstantAttr<I32ElementsAttr, "{0}">),
|
||||
$align),
|
||||
(TF_MatrixSetDiagOp $input, $diagonal)>;
|
||||
|
||||
|
@ -21,12 +21,20 @@ include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
|
||||
|
||||
// Quantize attribute $0 by using quantization parameter from %1.
|
||||
def QuantizeByQuantizedType : NativeCodeCall<"quant::Quantize($0, $1.getValue())">;
|
||||
def F32ElementsAttr : ElementsAttrBase<
|
||||
CPred<"$_self.cast<ElementsAttr>().getType().getElementType().isF32()">, "float constant tensor">;
|
||||
|
||||
// Squash tfl.dequantize and tfl.quantize pairs.
|
||||
// TODO(fengliuai): Compare the scale of input and output. This can also be
|
||||
// squashed to a requantize op if the scales are different.
|
||||
def : Pat<(TFL_QuantizeOp (TFL_DequantizeOp $in), $qt), (replaceWithValue $in)>;
|
||||
|
||||
// If the tfl.dequantize op wasn't fused, we shouldn't quantize the floating
|
||||
// point constant.
|
||||
def : Pat<(TFL_DequantizeOp
|
||||
(TFL_QuantizeOp (ConstantOp F32ElementsAttr:$cst), $qt)),
|
||||
(ConstantOp $cst)>;
|
||||
|
||||
// Quantize the value of a constant op if the quantization parameters have been
|
||||
// propagated to the output.
|
||||
def : Pat<(TFL_QuantizeOp
|
||||
|
@ -0,0 +1,52 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir/IR/OperationSupport.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.h.inc"
|
||||
namespace TFL {
|
||||
namespace {
|
||||
|
||||
// This pass verifies that the operands and results types are supported by
|
||||
// TFLite runtime.
|
||||
class RuntimeTypeVerifyPass : public mlir::FunctionPass<RuntimeTypeVerifyPass> {
|
||||
public:
|
||||
explicit RuntimeTypeVerifyPass() {}
|
||||
|
||||
private:
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
void RuntimeTypeVerifyPass::runOnFunction() {
|
||||
getFunction().walk([&](TflRuntimeVerifyOpInterface op) {
|
||||
if (failed(op.VerifyTflRuntimeTypes(op.getOperation())))
|
||||
signalPassFailure();
|
||||
});
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Verifies runtime supports types used.
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreateRuntimeTypeVerifyPass() {
|
||||
return std::make_unique<RuntimeTypeVerifyPass>();
|
||||
}
|
||||
|
||||
static PassRegistration<RuntimeTypeVerifyPass> pass(
|
||||
"tfl-runtime-verify", "TFLite runtime verification");
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace mlir
|
@ -168,6 +168,10 @@ std::string OpOrArgLocNameMapper::GetName(OpOrVal op_or_val) {
|
||||
result.getResultNumber());
|
||||
return std::string(result.getOwner()->getName().getStringRef());
|
||||
}
|
||||
// Use the ASM syntax for BloackArgument
|
||||
if (auto arg = val.dyn_cast<mlir::BlockArgument>()) {
|
||||
return "arg" + std::to_string(arg.getArgNumber());
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
|
@ -287,6 +287,7 @@ cc_library(
|
||||
"transforms/materialize_mlir_passthrough_op.cc",
|
||||
"transforms/optimize.cc",
|
||||
"transforms/optimize_global_tensors.cc",
|
||||
"transforms/parallel_execute_to_islands.cc",
|
||||
"transforms/promote_resources_to_args.cc",
|
||||
"transforms/raise_control_flow.cc",
|
||||
"transforms/replicate_invariant_op_hoisting.cc",
|
||||
@ -708,7 +709,6 @@ cc_library(
|
||||
deps = [
|
||||
":tensorflow_dialect_registration",
|
||||
":tf_dialect_passes",
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
],
|
||||
)
|
||||
|
||||
@ -913,7 +913,6 @@ cc_library(
|
||||
"//tensorflow/core/platform:logging",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Parser",
|
||||
"@llvm-project//mlir:Pass",
|
||||
|
@ -41,11 +41,52 @@ limitations under the License.
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
|
||||
#include "mlir/Support/STLExtras.h" // TF:llvm-project
|
||||
#include "mlir/Transforms/InliningUtils.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tf_device {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TF Device Dialect Interfaces
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
struct TFInlinerInterface : public DialectInlinerInterface {
|
||||
using DialectInlinerInterface::DialectInlinerInterface;
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Analysis Hooks
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
// Defines the legality of inlining TF Device operations.
|
||||
bool isLegalToInline(Operation*, Region*, BlockAndValueMapping&) const final {
|
||||
// For now, enable inlining all operations.
|
||||
return true;
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Transformation Hooks
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
// Attempts to materialize a conversion for a type mismatch between a call
|
||||
// from this dialect, and a callable region. This method should generate an
|
||||
// operation that takes 'input' as the only operand, and produces a single
|
||||
// result of 'resultType'. If a conversion can not be generated, nullptr
|
||||
// should be returned.
|
||||
// This is just re-using the same logic as the TensorFlow dialect right now.
|
||||
Operation* materializeCallConversion(OpBuilder& builder, Value input,
|
||||
Type result_type,
|
||||
Location conversion_loc) const final {
|
||||
if (!result_type.isa<TensorType>() || !input.getType().isa<TensorType>())
|
||||
return nullptr;
|
||||
return builder.create<TF::CastOp>(conversion_loc, result_type, input,
|
||||
/*truncate=*/builder.getBoolAttr(false));
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context)
|
||||
: Dialect(/*name=*/"tf_device", context) {
|
||||
addOperations<
|
||||
@ -54,6 +95,8 @@ TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context)
|
||||
>();
|
||||
|
||||
addOperations<ParallelExecuteOp>();
|
||||
|
||||
addInterfaces<TFInlinerInterface>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -573,9 +573,9 @@ void Print(SwitchNOp switchn, OpAsmPrinter &p) {
|
||||
|
||||
ParseResult ParseSwitchNOp(OpAsmParser &parser, OperationState &result) {
|
||||
// Parsing:
|
||||
// %2:6 = tf_executor.SwitchN %0, %1 by 5 : tensor<??xf32>
|
||||
// %2:6 = tf_executor.SwitchN %0, %1 of 5 : tensor<??xf32>
|
||||
// Where the first operand is the data to replicate, the second is an i32
|
||||
// indicating which output to populate, followed by the keyword `by` and the
|
||||
// indicating which output to populate, followed by the keyword `of` and the
|
||||
// number of outputs (+1 for the control token).
|
||||
SmallVector<OpAsmParser::OperandType, 2> op_infos;
|
||||
SmallVector<Type, 1> types;
|
||||
|
@ -165,7 +165,7 @@ def TfExecutor_IslandOp : TfExecutor_Op<"island",
|
||||
The `tf_executor.island` operation has a single region with a single block
|
||||
attached (only functional control flow is allowed). The block is terminated
|
||||
by a `tf_executor.yield` operation. The operands of the terminator
|
||||
correspond to the result values of the `tf_executor.graph` operation. An
|
||||
correspond to the result values of the `tf_executor.island` operation. An
|
||||
extra result of type `!tf_executor.control` is always produced by every
|
||||
`tf_executor.island`.
|
||||
Within an island, execution semantics follow standard sequential behavior as
|
||||
@ -299,7 +299,7 @@ def TfExecutor_SwitchNOp : TfExecutor_Op<"SwitchN",
|
||||
.SetShapeFn(SwitchNShape);
|
||||
|
||||
For example:
|
||||
%2:6 = tf_executor.SwitchN %0, %1 by 5 : tensor<??xf32>
|
||||
%2:6 = tf_executor.SwitchN %0, %1 of 5 : tensor<??xf32>
|
||||
|
||||
Note: One additional result corresponds to the control output.
|
||||
}];
|
||||
|
@ -49,7 +49,7 @@ an output element, this operation computes \\(y = |x|\\).
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_AddOp : TF_Op<"Add", [NoSideEffect, ResultsBroadcastableShape]>,
|
||||
def TF_AddOp : TF_Op<"Add", [NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = "Returns x + y element-wise.";
|
||||
|
||||
@ -98,7 +98,7 @@ Inputs must be of same size and shape.
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape]>,
|
||||
def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = "Returns x + y element-wise.";
|
||||
|
||||
@ -508,8 +508,9 @@ Broadcasting is supported, so `value` may have any number of dimensions.
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// TF_LayoutSensitiveInterface:
|
||||
SmallVector<int64_t, 4> GetLayoutDependentArgs() { return {0}; }
|
||||
SmallVector<int64_t, 4> GetLayoutDependentResults() { return {0}; }
|
||||
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
|
||||
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
|
||||
LogicalResult UpdateDataFormat(StringRef data_format);
|
||||
}];
|
||||
}
|
||||
|
||||
@ -980,7 +981,7 @@ tf.conj(input) ==> [-2.25 - 4.75j, 3.25 - 5.75j]
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TF_Conv2DOp : TF_Op<"Conv2D", [NoSideEffect]> {
|
||||
def TF_Conv2DOp : TF_Op<"Conv2D", [NoSideEffect, TF_LayoutSensitiveInterface]> {
|
||||
let summary = [{
|
||||
Computes a 2-D convolution given 4-D `input` and `filter` tensors.
|
||||
}];
|
||||
@ -1030,6 +1031,13 @@ horizontal and vertices strides, `strides = [1, stride, stride, 1]`.
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// TF_LayoutSensitiveInterface:
|
||||
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
|
||||
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
|
||||
LogicalResult UpdateDataFormat(StringRef data_format);
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_Conv2DBackpropFilterOp : TF_Op<"Conv2DBackpropFilter", [NoSideEffect]> {
|
||||
@ -2091,7 +2099,7 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors.
|
||||
TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<3>;
|
||||
}
|
||||
|
||||
def TF_FusedBatchNormV3Op : TF_Op<"FusedBatchNormV3", [NoSideEffect]> {
|
||||
def TF_FusedBatchNormV3Op : TF_Op<"FusedBatchNormV3", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
|
||||
let summary = "Batch normalization.";
|
||||
|
||||
let description = [{
|
||||
@ -2122,6 +2130,13 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors.
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// TF_FoldOperandsTransposeInterface:
|
||||
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
|
||||
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
|
||||
LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_GatherOp : TF_Op<"Gather", [NoSideEffect]> {
|
||||
@ -3392,6 +3407,130 @@ tf.matrix_diag(diagonal, k = -1, num_rows = 3, padding_value = 9)
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_MatrixSetDiagOp : TF_Op<"MatrixSetDiag", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Returns a batched matrix tensor with new batched diagonal values.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
Given `input` and `diagonal`, this operation returns a tensor with the
|
||||
same shape and values as `input`, except for the main diagonal of the
|
||||
innermost matrices. These will be overwritten by the values in `diagonal`.
|
||||
|
||||
The output is computed as follows:
|
||||
|
||||
Assume `input` has `k+1` dimensions `[I, J, K, ..., M, N]` and `diagonal` has
|
||||
`k` dimensions `[I, J, K, ..., min(M, N)]`. Then the output is a
|
||||
tensor of rank `k+1` with dimensions `[I, J, K, ..., M, N]` where:
|
||||
|
||||
* `output[i, j, k, ..., m, n] = diagonal[i, j, k, ..., n]` for `m == n`.
|
||||
* `output[i, j, k, ..., m, n] = input[i, j, k, ..., m, n]` for `m != n`.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_Tensor:$input,
|
||||
TF_Tensor:$diagonal
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_Tensor:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_MatrixSetDiagV2Op : TF_Op<"MatrixSetDiagV2", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Returns a batched matrix tensor with new batched diagonal values.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
Given `input` and `diagonal`, this operation returns a tensor with the
|
||||
same shape and values as `input`, except for the specified diagonals of the
|
||||
innermost matrices. These will be overwritten by the values in `diagonal`.
|
||||
|
||||
`input` has `r+1` dimensions `[I, J, ..., L, M, N]`. When `k` is scalar or
|
||||
`k[0] == k[1]`, `diagonal` has `r` dimensions `[I, J, ..., L, max_diag_len]`.
|
||||
Otherwise, it has `r+1` dimensions `[I, J, ..., L, num_diags, max_diag_len]`.
|
||||
`num_diags` is the number of diagonals, `num_diags = k[1] - k[0] + 1`.
|
||||
`max_diag_len` is the longest diagonal in the range `[k[0], k[1]]`,
|
||||
`max_diag_len = min(M + min(k[1], 0), N + min(-k[0], 0))`
|
||||
|
||||
The output is a tensor of rank `k+1` with dimensions `[I, J, ..., L, M, N]`.
|
||||
If `k` is scalar or `k[0] == k[1]`:
|
||||
|
||||
```
|
||||
output[i, j, ..., l, m, n]
|
||||
= diagonal[i, j, ..., l, n-max(k[1], 0)] ; if n - m == k[1]
|
||||
input[i, j, ..., l, m, n] ; otherwise
|
||||
```
|
||||
|
||||
Otherwise,
|
||||
|
||||
```
|
||||
output[i, j, ..., l, m, n]
|
||||
= diagonal[i, j, ..., l, diag_index, index_in_diag] ; if k[0] <= d <= k[1]
|
||||
input[i, j, ..., l, m, n] ; otherwise
|
||||
```
|
||||
where `d = n - m`, `diag_index = k[1] - d`, and `index_in_diag = n - max(d, 0)`.
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
# The main diagonal.
|
||||
input = np.array([[[7, 7, 7, 7], # Input shape: (2, 3, 4)
|
||||
[7, 7, 7, 7],
|
||||
[7, 7, 7, 7]],
|
||||
[[7, 7, 7, 7],
|
||||
[7, 7, 7, 7],
|
||||
[7, 7, 7, 7]]])
|
||||
diagonal = np.array([[1, 2, 3], # Diagonal shape: (2, 3)
|
||||
[4, 5, 6]])
|
||||
tf.matrix_set_diag(diagonal) ==> [[[1, 7, 7, 7], # Output shape: (2, 3, 4)
|
||||
[7, 2, 7, 7],
|
||||
[7, 7, 3, 7]],
|
||||
[[4, 7, 7, 7],
|
||||
[7, 5, 7, 7],
|
||||
[7, 7, 6, 7]]]
|
||||
|
||||
# A superdiagonal (per batch).
|
||||
tf.matrix_set_diag(diagonal, k = 1)
|
||||
==> [[[7, 1, 7, 7], # Output shape: (2, 3, 4)
|
||||
[7, 7, 2, 7],
|
||||
[7, 7, 7, 3]],
|
||||
[[7, 4, 7, 7],
|
||||
[7, 7, 5, 7],
|
||||
[7, 7, 7, 6]]]
|
||||
|
||||
# A band of diagonals.
|
||||
diagonals = np.array([[[1, 2, 3], # Diagonal shape: (2, 2, 3)
|
||||
[4, 5, 0]],
|
||||
[[6, 1, 2],
|
||||
[3, 4, 0]]])
|
||||
tf.matrix_set_diag(diagonals, k = (-1, 0))
|
||||
==> [[[1, 7, 7, 7], # Output shape: (2, 3, 4)
|
||||
[4, 2, 7, 7],
|
||||
[0, 5, 3, 7]],
|
||||
[[6, 7, 7, 7],
|
||||
[3, 1, 7, 7],
|
||||
[7, 4, 2, 7]]]
|
||||
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_Tensor:$input,
|
||||
TF_Tensor:$diagonal,
|
||||
I32Tensor:$k
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_Tensor:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_MatrixSetDiagV3Op : TF_Op<"MatrixSetDiagV3", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Returns a batched matrix tensor with new batched diagonal values.
|
||||
@ -3551,7 +3690,7 @@ retained with length 1.
|
||||
>];
|
||||
}
|
||||
|
||||
def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect]> {
|
||||
def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
|
||||
let summary = "Performs max pooling on the input.";
|
||||
|
||||
let description = [{
|
||||
@ -3571,6 +3710,13 @@ def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect]> {
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// TF_FoldOperandsTransposeInterface:
|
||||
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
|
||||
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
|
||||
LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_MaxPoolGradOp : TF_Op<"MaxPoolGrad", [NoSideEffect]> {
|
||||
@ -4714,7 +4860,7 @@ I.e., \\(y = 1 / x\\).
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TF_ReluOp : TF_Op<"Relu", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
def TF_ReluOp : TF_Op<"Relu", [NoSideEffect, SameOperandsAndResultType, TF_LayoutAgnostic]> {
|
||||
let summary = "Computes rectified linear: `max(features, 0)`.";
|
||||
|
||||
let description = [{
|
||||
@ -6657,7 +6803,7 @@ variables.
|
||||
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_TanhOp : TF_Op<"Tanh", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
def TF_TanhOp : TF_Op<"Tanh", [NoSideEffect, SameOperandsAndResultType, TF_LayoutAgnostic]> {
|
||||
let summary = "Computes hyperbolic tangent of `x` element-wise.";
|
||||
|
||||
let description = [{
|
||||
|
@ -58,6 +58,10 @@ TODO: Make invariants more structured so that we can reference them in ops.
|
||||
def TF_OperandsSameAsResultsTypeOrRef : NativeOpTrait<
|
||||
"TF::OperandsSameAsResultsTypeOrRef">;
|
||||
|
||||
// Layout agnostic operations do not depend on the operands data layout (data
|
||||
// format), as an example all element wise operations are layout agnostic.
|
||||
def TF_LayoutAgnostic : NativeOpTrait<"TF::LayoutAgnostic">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorFlow op definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -44,11 +44,17 @@ def TF_LayoutSensitiveInterface : OpInterface<"LayoutSensitiveInterface"> {
|
||||
>,
|
||||
InterfaceMethod<
|
||||
[{Returns indices of layout dependent arguments.}],
|
||||
"SmallVector<int64_t, 4>", "GetLayoutDependentArgs", (ins)
|
||||
"SmallVector<unsigned, 4>", "GetLayoutDependentArgs", (ins)
|
||||
>,
|
||||
InterfaceMethod<
|
||||
[{Returns indices of layout dependent results.}],
|
||||
"SmallVector<int64_t, 4>", "GetLayoutDependentResults", (ins)
|
||||
"SmallVector<unsigned, 4>", "GetLayoutDependentResults", (ins)
|
||||
>,
|
||||
InterfaceMethod<
|
||||
[{Updates operation attributes and operands to account for the updated
|
||||
data format. If data format is not supported, must return failure.}],
|
||||
"LogicalResult", "UpdateDataFormat",
|
||||
(ins "StringRef":$data_format)
|
||||
>,
|
||||
];
|
||||
|
||||
@ -57,4 +63,42 @@ def TF_LayoutSensitiveInterface : OpInterface<"LayoutSensitiveInterface"> {
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_FoldOperandsTransposeInterface : OpInterface<"FoldOperandsTransposeInterface"> {
|
||||
let description = [{
|
||||
Operation supports folding operand(s) transposes into the operation itself.
|
||||
|
||||
(1) Operation might have layout dependent operands and results...
|
||||
|
||||
Example: MaxPool(Transpose($arg, $perm))
|
||||
-> Transpose(MaxPool($arg, $perm))
|
||||
|
||||
(2) ... or it might have only layout dependent operands:
|
||||
|
||||
Example: Mean(Transpose($arg, $reduction_dims))
|
||||
-> Mean($arg, Transpose($reduction_dims))
|
||||
}];
|
||||
|
||||
let methods = [
|
||||
InterfaceMethod<
|
||||
[{Returns indices of layout dependent arguments.}],
|
||||
"SmallVector<unsigned, 4>", "GetLayoutDependentArgs", (ins)
|
||||
>,
|
||||
InterfaceMethod<
|
||||
[{Returns indices of layout dependent results.}],
|
||||
"SmallVector<unsigned, 4>", "GetLayoutDependentResults", (ins)
|
||||
>,
|
||||
InterfaceMethod<
|
||||
[{Updates operation attributes and operands to account for the folded
|
||||
permutation. If folding of permutation is not possible, must return
|
||||
failure.}],
|
||||
"LogicalResult", "FoldOperandsPermutation",
|
||||
(ins "ArrayRef<int64_t>":$permutation)
|
||||
>,
|
||||
];
|
||||
|
||||
let verify = [{
|
||||
return VerifyFoldOperandsTransposeInterface($_op);
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // TF_OP_INTERFACES
|
||||
|
@ -292,6 +292,156 @@ static LogicalResult VerifyTypesCompatibility(
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TF op helper functions to work with layout transformation.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
SmallVector<int64_t, 4> GetDataFormatPermutation(StringRef from, StringRef to) {
|
||||
if (from == "NHWC" && to == "NCHW") {
|
||||
return {0, 3, 1, 2};
|
||||
} else if (from == "NCHW" && to == "NHWC") {
|
||||
return {0, 2, 3, 1};
|
||||
} else {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
// Shuffle elements in the `attr` according to the permutation. Optional
|
||||
// `inner_size` allows to shuffle array attributes created from rank 2 tensors
|
||||
// on outer dimension only.
|
||||
ArrayAttr ShuffleArrayAttr(ArrayAttr attr, ArrayRef<int64_t> permutation,
|
||||
int inner_size = 1) {
|
||||
if (attr.size() == 0) return attr;
|
||||
|
||||
assert(attr.size() % inner_size == 0);
|
||||
assert(attr.size() / inner_size == permutation.size());
|
||||
|
||||
SmallVector<Attribute, 8> values{attr.begin(), attr.end()};
|
||||
SmallVector<Attribute, 8> shuffled(values.size());
|
||||
|
||||
for (size_t i = 0; i < permutation.size(); ++i) {
|
||||
for (size_t j = 0; j < inner_size; ++j) {
|
||||
shuffled[i * inner_size + j] = values[permutation[i] * inner_size + j];
|
||||
}
|
||||
}
|
||||
|
||||
return ArrayAttr::get(shuffled, attr.getContext());
|
||||
}
|
||||
|
||||
// Shuffle ranked tensor dimensions according to the permutation.
|
||||
Type ShuffleRankedTensorType(Type type, ArrayRef<int64_t> permutation) {
|
||||
if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
|
||||
ArrayRef<int64_t> shape = ranked_type.getShape();
|
||||
assert(permutation.size() == shape.size());
|
||||
|
||||
SmallVector<int64_t, 4> new_shape(permutation.size());
|
||||
for (size_t i = 0; i < permutation.size(); ++i)
|
||||
new_shape[i] = shape[permutation[i]];
|
||||
|
||||
return RankedTensorType::get(new_shape, ranked_type.getElementType());
|
||||
}
|
||||
|
||||
return type;
|
||||
}
|
||||
|
||||
static bool AreCancellablePermutations(DenseIntElementsAttr perm0,
|
||||
DenseIntElementsAttr perm1) {
|
||||
if (perm0.getNumElements() == 0 || perm1.getNumElements() == 0) return false;
|
||||
if (perm0.getNumElements() != perm1.getNumElements()) return false;
|
||||
|
||||
SmallVector<int64_t, 8> perm0_values;
|
||||
for (auto value : perm0.getIntValues())
|
||||
perm0_values.push_back(value.getSExtValue());
|
||||
|
||||
SmallVector<int64_t, 8> perm1_values;
|
||||
for (auto value : perm1.getIntValues())
|
||||
perm1_values.push_back(value.getSExtValue());
|
||||
|
||||
for (int i = 0; i < perm0_values.size(); ++i) {
|
||||
if (perm0_values[perm1_values[i]] != i) return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Default implementation of `LayoutSensitiveInterface::UpdateDataFormat` for
|
||||
// layout sensitive operations that do not have any additional layout dependent
|
||||
// attributes besides `data_format` string.
|
||||
template <typename Op>
|
||||
LogicalResult UpdateDataFormat(StringRef data_format, Op *op) {
|
||||
auto perm = GetDataFormatPermutation(op->data_format(), data_format);
|
||||
if (perm.empty()) return failure();
|
||||
|
||||
// Update data format attribute.
|
||||
op->setAttr("data_format", StringAttr::get(data_format, op->getContext()));
|
||||
|
||||
// Update types for all layout sensitive results.
|
||||
auto layout_sensitive = cast<LayoutSensitiveInterface>(op->getOperation());
|
||||
for (unsigned idx : layout_sensitive.GetLayoutDependentResults()) {
|
||||
OpResult result = op->getOperation()->getResult(idx);
|
||||
result.setType(ShuffleRankedTensorType(result.getType(), perm));
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
// Default implementation for folding operand transpose into the operation.
|
||||
// See `FoldOperandsTransposeInterface::FoldOperandsPermutation`.
|
||||
template <typename Op>
|
||||
LogicalResult FoldOperandsPermutation(
|
||||
ArrayRef<int64_t> permutation, Op *op,
|
||||
ArrayRef<std::pair<StringRef, ArrayAttr>> shuffle_attrs = {}) {
|
||||
MLIRContext *context = op->template getParentOfType<ModuleOp>().getContext();
|
||||
|
||||
// We only support NHWC <-> NCHW permutations.
|
||||
static constexpr std::array<int64_t, 4> kNchwToNhwc = {0, 2, 3, 1};
|
||||
static constexpr std::array<int64_t, 4> kNhwcToNchw = {0, 3, 1, 2};
|
||||
|
||||
// Operation data format after folding `permutation`.
|
||||
StringRef target_data_format = [&]() -> StringRef {
|
||||
if (op->data_format() == "NHWC" && permutation.equals(kNchwToNhwc)) {
|
||||
return "NCHW"; // cancel NCHW->NHWC operand permutation
|
||||
} else if (op->data_format() == "NCHW" && permutation.equals(kNhwcToNchw)) {
|
||||
return "NHWC"; // cancel NHWC->NCHW operand permutation
|
||||
} else {
|
||||
return "";
|
||||
}
|
||||
}();
|
||||
if (target_data_format.empty()) return failure();
|
||||
|
||||
// To fold operand `permutation` into the `op` we need shuffle all layout
|
||||
// dependent attributes and types with a reverse permutation, and change
|
||||
// operation data format to `target_data_format`.
|
||||
//
|
||||
// Example:
|
||||
// %1 = SomeOp(...) {data_format = NHWC}
|
||||
// %2 = Transpose(%1) {permutation = NHWC->NCHW}
|
||||
// %3 = Op(%2) {data_format = NCHW}
|
||||
//
|
||||
// To bypass %2 we have to change data format to shuffle data format from NCHW
|
||||
// to NHWC, which is the reverse of operand permutation (function argument).
|
||||
auto reverse_permutation =
|
||||
GetDataFormatPermutation(op->data_format(), target_data_format);
|
||||
if (reverse_permutation.empty()) return failure();
|
||||
|
||||
op->setAttr("data_format", StringAttr::get(target_data_format, context));
|
||||
|
||||
for (auto pair : shuffle_attrs) {
|
||||
StringRef attr_name = pair.first;
|
||||
ArrayAttr attr_value = pair.second;
|
||||
op->setAttr(attr_name, ShuffleArrayAttr(attr_value, reverse_permutation));
|
||||
}
|
||||
|
||||
auto fold = cast<FoldOperandsTransposeInterface>(op->getOperation());
|
||||
for (unsigned idx : fold.GetLayoutDependentResults()) {
|
||||
OpResult result = op->getOperation()->getResult(idx);
|
||||
result.setType(
|
||||
ShuffleRankedTensorType(result.getType(), reverse_permutation));
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
|
||||
} // namespace
|
||||
@ -459,6 +609,15 @@ static LogicalResult Verify(BiasAddOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
// TODO(ezhulenev): BiasAddOp is not really layout sensitive, it must only
|
||||
// support folding operand transposes.
|
||||
LogicalResult BiasAddOp::UpdateDataFormat(StringRef data_format) {
|
||||
auto ranked = value().getType().dyn_cast<RankedTensorType>();
|
||||
if (!ranked || ranked.getRank() != 4) return failure();
|
||||
|
||||
return ::mlir::TF::UpdateDataFormat(data_format, this);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BiasAddGradOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -817,6 +976,21 @@ static LogicalResult Verify(OpT op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult Conv2DOp::UpdateDataFormat(StringRef data_format) {
|
||||
auto perm = GetDataFormatPermutation(this->data_format(), data_format);
|
||||
if (perm.empty()) return failure();
|
||||
|
||||
// Update data_format attribute and result types.
|
||||
if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure();
|
||||
|
||||
// Update convolution attributes.
|
||||
setAttr("dilations", ShuffleArrayAttr(dilations(), perm));
|
||||
setAttr("strides", ShuffleArrayAttr(strides(), perm));
|
||||
setAttr("explicit_paddings", ShuffleArrayAttr(explicit_paddings(), perm, 2));
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Conv2dBackpropInputOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1138,6 +1312,11 @@ static LogicalResult Verify(FusedBatchNormOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult FusedBatchNormV3Op::FoldOperandsPermutation(
|
||||
ArrayRef<int64_t> permutation) {
|
||||
return ::mlir::TF::FoldOperandsPermutation(permutation, this);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GatherV2Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1330,6 +1509,16 @@ void MaxOp::build(Builder *builder, OperationState &result, Value input,
|
||||
build(builder, result, out_ty, input, reduction_indices, keep_dims);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MaxPoolOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult MaxPoolOp::FoldOperandsPermutation(
|
||||
ArrayRef<int64_t> permutation) {
|
||||
return ::mlir::TF::FoldOperandsPermutation(
|
||||
permutation, this, {{"strides", strides()}, {"ksize", ksize()}});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MaxPoolGradOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1347,6 +1536,38 @@ static LogicalResult Verify(MaxPoolGradOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MeanOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult MeanOp::FoldOperandsPermutation(ArrayRef<int64_t> permutation) {
|
||||
// Reduction indices must be defined by a constant operation.
|
||||
auto reduction_op =
|
||||
dyn_cast_or_null<TF::ConstOp>(reduction_indices().getDefiningOp());
|
||||
if (!reduction_op) return failure();
|
||||
|
||||
auto reductions_value = reduction_op.value().dyn_cast<DenseElementsAttr>();
|
||||
if (!reductions_value) return failure();
|
||||
|
||||
// Prepare new reduction indices according to operand permutation.
|
||||
SmallVector<int64_t, 4> shuffled_reduction;
|
||||
llvm::transform(reductions_value.getIntValues(),
|
||||
std::back_inserter(shuffled_reduction),
|
||||
[&](APInt idx) { return permutation[idx.getSExtValue()]; });
|
||||
|
||||
// Add constant operation with a new reduction indices.
|
||||
OpBuilder builder(getOperation());
|
||||
auto type = mlir::RankedTensorType::get(shuffled_reduction.size(),
|
||||
builder.getIntegerType(64));
|
||||
auto values = mlir::DenseIntElementsAttr::get(type, shuffled_reduction);
|
||||
auto shuffled_reduction_op = builder.create<TF::ConstOp>(getLoc(), values);
|
||||
|
||||
// Use new reduction indices.
|
||||
setOperand(1, shuffled_reduction_op);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// NegOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -2723,23 +2944,46 @@ void TransposeOp::build(Builder *builder, OperationState &result, Value x,
|
||||
perm);
|
||||
}
|
||||
|
||||
OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto const_perm = dyn_cast_or_null<TF::ConstOp>(perm().getDefiningOp());
|
||||
namespace {
|
||||
|
||||
if (!const_perm) {
|
||||
return {};
|
||||
}
|
||||
OpFoldResult FoldIdentityTranspose(TransposeOp op) {
|
||||
auto const_perm = dyn_cast_or_null<TF::ConstOp>(op.perm().getDefiningOp());
|
||||
if (!const_perm) return {};
|
||||
|
||||
auto const_value = const_perm.value();
|
||||
|
||||
const auto &elements = const_value.getValues<APInt>();
|
||||
|
||||
for (auto it : llvm::enumerate(elements)) {
|
||||
if (it.index() != it.value()) {
|
||||
return {};
|
||||
}
|
||||
if (it.index() != it.value()) return {};
|
||||
}
|
||||
|
||||
return x();
|
||||
return op.x();
|
||||
}
|
||||
|
||||
OpFoldResult FoldCancellableTranspose(TransposeOp op) {
|
||||
// Operand is a TransposeOp.
|
||||
auto transpose = dyn_cast_or_null<TF::TransposeOp>(op.x().getDefiningOp());
|
||||
if (!transpose) return {};
|
||||
|
||||
// Permutations defined by constant operations.
|
||||
auto perm0 = dyn_cast_or_null<TF::ConstOp>(op.perm().getDefiningOp());
|
||||
auto perm1 = dyn_cast_or_null<TF::ConstOp>(transpose.perm().getDefiningOp());
|
||||
if (!perm0 || !perm1) return {};
|
||||
|
||||
// With permutation indices that cancel each other
|
||||
auto perm0_value = perm0.value().cast<DenseIntElementsAttr>();
|
||||
auto perm1_value = perm1.value().cast<DenseIntElementsAttr>();
|
||||
if (!AreCancellablePermutations(perm0_value, perm1_value)) return {};
|
||||
|
||||
return transpose.x();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (auto folded = FoldIdentityTranspose(*this)) return folded;
|
||||
if (auto folded = FoldCancellableTranspose(*this)) return folded;
|
||||
return {};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -172,7 +172,7 @@ else_branch: A function that takes 'inputs' and returns a list of
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_MeanOp : TF_Op<"Mean", [NoSideEffect]> {
|
||||
def TF_MeanOp : TF_Op<"Mean", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
|
||||
let summary = "Computes the mean of elements across dimensions of a tensor.";
|
||||
|
||||
let description = [{
|
||||
@ -195,6 +195,13 @@ retained with length 1.
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// TF_FoldOperandsTransposeInterface:
|
||||
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
|
||||
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {}; }
|
||||
LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_LegacyCallOp : TF_Op<"LegacyCall",
|
||||
|
@ -112,12 +112,26 @@ static LogicalResult VerifyIndexPath(Operation *op, NamedAttribute named_attr) {
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
// Return true if `type` is a tensor of `!tf.resource`. This is the type that is
|
||||
// used to represent mutable variables on exported functions' bound inputs.
|
||||
static bool IsResourceVarType(Type type) {
|
||||
TensorType tensor_type = type.dyn_cast<TensorType>();
|
||||
if (!tensor_type) return false;
|
||||
return tensor_type.getElementType().isa<TF::ResourceType>();
|
||||
static LogicalResult VerifyBoundInputArgType(Operation *op_for_diagnostics,
|
||||
Type arg_type,
|
||||
GlobalTensorOp global_tensor) {
|
||||
if (global_tensor.is_mutable()) {
|
||||
auto expected_type = RankedTensorType::get(
|
||||
{}, TF::ResourceType::get({global_tensor.type().cast<TensorType>()},
|
||||
arg_type.getContext()));
|
||||
if (arg_type != expected_type) {
|
||||
return op_for_diagnostics->emitError()
|
||||
<< "mutable bound input with type " << arg_type
|
||||
<< " expected to have type " << expected_type;
|
||||
}
|
||||
} else {
|
||||
if (arg_type != global_tensor.type()) {
|
||||
return op_for_diagnostics->emitError()
|
||||
<< "bound input for immutable 'tf_saved_model.global_tensor' must "
|
||||
"match the global tensor's type";
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult TensorFlowSavedModelDialect::verifyRegionArgAttribute(
|
||||
@ -137,20 +151,7 @@ LogicalResult TensorFlowSavedModelDialect::verifyRegionArgAttribute(
|
||||
<< symbol_name << "'";
|
||||
}
|
||||
auto arg_type = cast<FuncOp>(op).getArgument(arg_index).getType();
|
||||
if (global_tensor.is_mutable()) {
|
||||
if (!IsResourceVarType(arg_type)) {
|
||||
return op->emitError()
|
||||
<< "bound inputs for mutable 'tf_saved_model.global_tensor's "
|
||||
"must be tensors of '!tf.resource'";
|
||||
}
|
||||
} else {
|
||||
if (arg_type != global_tensor.type()) {
|
||||
return op->emitError() << "bound input for immutable "
|
||||
"'tf_saved_model.global_tensor' must "
|
||||
"match the global tensor's type";
|
||||
}
|
||||
}
|
||||
return success();
|
||||
return VerifyBoundInputArgType(op, arg_type, global_tensor);
|
||||
}
|
||||
if (named_attr.first == "tf_saved_model.index_path") {
|
||||
return VerifyIndexPath(op, named_attr);
|
||||
|
@ -68,6 +68,11 @@ class OperandsSameAsResultsTypeOrRef
|
||||
}
|
||||
};
|
||||
|
||||
// Layout agnostic operations do not depend on the operands data layout (data
|
||||
// format), as and example all element wise operations are layout agnostic.
|
||||
template <typename ConcreteType>
|
||||
class LayoutAgnostic : public TraitBase<ConcreteType, LayoutAgnostic> {};
|
||||
|
||||
} // namespace TF
|
||||
} // namespace OpTrait
|
||||
} // namespace mlir
|
||||
|
@ -21,23 +21,35 @@ limitations under the License.
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
|
||||
LogicalResult VerifyLayoutSensitiveInterface(Operation* op) {
|
||||
auto layout_sensitive_interface = cast<LayoutSensitiveInterface>(op);
|
||||
namespace {
|
||||
|
||||
if (!llvm::all_of(
|
||||
layout_sensitive_interface.GetLayoutDependentArgs(),
|
||||
[&](int64_t index) { return index < op->getNumOperands(); })) {
|
||||
template <typename Interface>
|
||||
LogicalResult VerifyLayoutDependentArgsAndResults(Operation* op,
|
||||
Interface interface) {
|
||||
auto valid_operand = [&](int64_t idx) { return idx < op->getNumOperands(); };
|
||||
if (!llvm::all_of(interface.GetLayoutDependentArgs(), valid_operand)) {
|
||||
return op->emitOpError("layout dependent argument index is out of bound");
|
||||
}
|
||||
|
||||
if (!llvm::all_of(
|
||||
layout_sensitive_interface.GetLayoutDependentResults(),
|
||||
[&](int64_t index) { return index < op->getNumResults(); })) {
|
||||
auto valid_result = [&](int64_t idx) { return idx < op->getNumResults(); };
|
||||
if (!llvm::all_of(interface.GetLayoutDependentResults(), valid_result)) {
|
||||
return op->emitOpError("layout dependent result index is out of bound");
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult VerifyLayoutSensitiveInterface(Operation* op) {
|
||||
auto layout_sensitive_interface = cast<LayoutSensitiveInterface>(op);
|
||||
return VerifyLayoutDependentArgsAndResults(op, layout_sensitive_interface);
|
||||
}
|
||||
|
||||
LogicalResult VerifyFoldOperandsTransposeInterface(Operation* op) {
|
||||
auto fold_operands_transpose = cast<FoldOperandsTransposeInterface>(op);
|
||||
return VerifyLayoutDependentArgsAndResults(op, fold_operands_transpose);
|
||||
}
|
||||
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
||||
|
@ -29,6 +29,12 @@ namespace TF {
|
||||
// [0, getNumOperands/getNumResults) range.
|
||||
LogicalResult VerifyLayoutSensitiveInterface(Operation* op);
|
||||
|
||||
// Verifies correctness of ops implementing FoldOperandsTransposeInterface (see
|
||||
// definition in tf_op_base.td):
|
||||
// (1) Layout dependent arguments and results indices must be in
|
||||
// [0, getNumOperands/getNumResults) range.
|
||||
LogicalResult VerifyFoldOperandsTransposeInterface(Operation* op);
|
||||
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
||||
|
||||
|
@ -3,6 +3,10 @@
|
||||
|
||||
// All tests also test for idempotence.
|
||||
|
||||
// Test that external functions aren't processed (used to crash).
|
||||
// CHECK-LABEL: func @unused_external_func
|
||||
func @unused_external_func()
|
||||
|
||||
func @multiple_return(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> (tensor<*xi32>, tensor<*xi32>) {
|
||||
%graph:2 = tf_executor.graph {
|
||||
%island:3 = tf_executor.island {
|
||||
@ -276,3 +280,67 @@ func @empty_island_multiple_data_results(%arg0: tensor<*xf32>, %arg1: tensor<*xi
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// The following tests check that certain control dependencies between islands
|
||||
// and certain tf_executor ops are added correctly.
|
||||
|
||||
// CHECK: %[[CONTROL:[^ ,]*]] = tf_executor.island wraps "tf.Print"
|
||||
// CHECK: tf_executor.NextIteration.Sink [{{.*}}] {{.*}}, %[[CONTROL]]
|
||||
func @next_iteration_sink_control_input() {
|
||||
tf_executor.graph {
|
||||
%source:3 = tf_executor.NextIteration.Source : tensor<*xi32>
|
||||
%island:2 = tf_executor.island {
|
||||
%const = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<*xi32>
|
||||
%print = "tf.Print"(%const) : (tensor<*xi32>) -> (tensor<*xi32>)
|
||||
tf_executor.yield %const : tensor<*xi32>
|
||||
}
|
||||
tf_executor.NextIteration.Sink[%source#1] %island#0 : tensor<*xi32>
|
||||
tf_executor.fetch %island#0 : tensor<*xi32>
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: %[[CONTROL:[^ ,]*]] = tf_executor.island wraps "tf.Print"
|
||||
// CHECK: tf_executor.LoopCond {{.*}}, %[[CONTROL]]
|
||||
func @loop_cond_control_input() {
|
||||
tf_executor.graph {
|
||||
%island:2 = tf_executor.island {
|
||||
%const = "tf.Const"() {value = dense<1> : tensor<i1>} : () -> tensor<*xi1>
|
||||
%print = "tf.Print"(%const) : (tensor<*xi1>) -> (tensor<*xi1>)
|
||||
tf_executor.yield %const : tensor<*xi1>
|
||||
}
|
||||
%loop_cond:2 = tf_executor.LoopCond %island#0 : tensor<*xi1>
|
||||
tf_executor.fetch %loop_cond#0 : tensor<*xi1>
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: %[[CONTROL:[^ ,]*]] = tf_executor.island wraps "tf.Print"
|
||||
// CHECK: tf_executor.Enter {{.*}}, %[[CONTROL]]
|
||||
func @enter_control_input() {
|
||||
tf_executor.graph {
|
||||
%island:2 = tf_executor.island {
|
||||
%const = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<*xi32>
|
||||
%print = "tf.Print"(%const) : (tensor<*xi32>) -> (tensor<*xi32>)
|
||||
tf_executor.yield %const : tensor<*xi32>
|
||||
}
|
||||
%enter:2 = tf_executor.Enter %island#0 frame "some/frame" : tensor<*xi32>
|
||||
tf_executor.fetch %enter#0 : tensor<*xi32>
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: %[[CONTROL:[^ ,]*]] = tf_executor.island wraps "tf.Print"
|
||||
// CHECK: tf_executor.SwitchN {{.*}}, {{.*}} of {{[0-9]*}} (%[[CONTROL]])
|
||||
func @switchn_control_input(%arg1: tensor<i32>) {
|
||||
tf_executor.graph {
|
||||
%island:2 = tf_executor.island {
|
||||
%const = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<*xi32>
|
||||
%print = "tf.Print"(%const) : (tensor<*xi32>) -> (tensor<*xi32>)
|
||||
tf_executor.yield %const : tensor<*xi32>
|
||||
}
|
||||
%switchn:4 = tf_executor.SwitchN %island#0, %arg1 of 3: tensor<*xi32>
|
||||
tf_executor.fetch %switchn#0 : tensor<*xi32>
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -383,6 +383,28 @@ func @nonIdentityTranspose(%arg0: tensor<2x3x4x5x6xf32>) -> tensor<2x3x4x6x5xf32
|
||||
// CHECK: return %1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @cancellableTranspose
|
||||
func @cancellableTranspose(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32> {
|
||||
%0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
%1 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
%2 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
|
||||
%3 = "tf.Transpose"(%2, %1) : (tensor<1x8x4x4xf32>, tensor<4xi32>) -> tensor<1x4x4x8xf32>
|
||||
|
||||
return %3 : tensor<1x4x4x8xf32>
|
||||
// CHECK: return %arg0
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @nonCancellableTranspose
|
||||
func @nonCancellableTranspose(%arg0: tensor<1x4x4x8xf32>) -> tensor<4x1x4x8xf32> {
|
||||
%0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
%1 = "tf.Const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
%2 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
|
||||
%3 = "tf.Transpose"(%2, %1) : (tensor<1x8x4x4xf32>, tensor<4xi32>) -> tensor<4x1x4x8xf32>
|
||||
|
||||
return %3 : tensor<4x1x4x8xf32>
|
||||
// CHECK: return %3
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @addN
|
||||
func @addN(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: return %arg0
|
||||
|
@ -9,5 +9,7 @@ func @device_test(%arg0: tensor<3x1xf32>) -> (tensor<3x3xf32>) {
|
||||
%1 = "tf.MatMul"(%arg0, %0) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"], device = "", transpose_a = false, transpose_b = false} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
|
||||
// CHECK: device = "cpu"
|
||||
%2 = "tf.Relu"(%1) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"], device = "cpu"} : (tensor<3x3xf32>) -> tensor<3x3xf32>
|
||||
return %2 : tensor<3x3xf32>
|
||||
// CHECK: device = "gpu"
|
||||
%3 = "tf.Relu"(%2) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"]} : (tensor<3x3xf32>) -> tensor<3x3xf32>
|
||||
return %3 : tensor<3x3xf32>
|
||||
}
|
||||
|
@ -0,0 +1,57 @@
|
||||
// RUN: tf-opt %s -tf-executor-tpu-v1-island-coarsening | FileCheck %s --dump-input=fail
|
||||
|
||||
|
||||
// Test that islands with a function call are merged if the call is to a function
|
||||
// that contains ops with the same attribute.
|
||||
// CHECK-LABEL: func @control_input
|
||||
func @control_input(%arg0 : tensor<i1>) -> tensor<i32> {
|
||||
%0:6 = tf_executor.graph {
|
||||
%1:2 = tf_executor.island wraps "tf.opA"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i1>) -> tensor<i32>
|
||||
%2:2 = tf_executor.island wraps "tf.While"(%1#0) {name = "A", body = @while_body_with_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
|
||||
%3:2 = tf_executor.island wraps "tf.While"(%1#0) {name = "B", body = @while_body_with_wrong_cluster_attr, cond = @while_cond_with_wrong_cluster_attr, is_stateless = false, parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
|
||||
%4:2 = tf_executor.island wraps "tf.While"(%1#0) {name = "C", body = @while_body_without_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
|
||||
%6:2 = tf_executor.island wraps "tf.While"(%1#0) {name = "D", body = @while_body_without_cluster_attr, cond = @while_cond_without_cluster_attr, is_stateless = false, parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
|
||||
%5:2 = tf_executor.island wraps "tf.While"(%1#0) {name = "E", body = @while_body_with_cluster_attr, cond = @while_cond_without_cluster_attr, is_stateless = false, parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
|
||||
|
||||
// CHECK: "tf.opA"
|
||||
// CHECK-NOT: island
|
||||
// CHECK: name = "A"
|
||||
// CHECK-NOT: island
|
||||
// CHECK: name = "C"
|
||||
// CHECK-NOT: island
|
||||
// CHECK: name = "E"
|
||||
// CHECK: island {{.*}}name = "B"
|
||||
// CHECK: island {{.*}}name = "D"
|
||||
|
||||
tf_executor.fetch %1#0, %2#0, %3#0, %4#0, %5#0, %6#0 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
|
||||
}
|
||||
return %0#0 : tensor<i32>
|
||||
}
|
||||
|
||||
func @while_body_with_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
|
||||
%0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
func @while_cond_with_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
|
||||
%0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
|
||||
func @while_body_with_wrong_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
|
||||
%0 = "some.op"(%arg0) {_tpu_replicate = "wrong_cluster"} : (tensor<i32>) -> tensor<i32>
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
func @while_cond_with_wrong_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
|
||||
%0 = "some.op"(%arg0) {_tpu_replicate = "wrong_cluster"} : (tensor<i32>) -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
|
||||
func @while_body_without_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
|
||||
%0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i32>
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
func @while_cond_without_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
|
||||
%0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
|
@ -0,0 +1,44 @@
|
||||
// RUN: tf-opt %s -tf-executor-tpu-v1-island-inlining | FileCheck %s --dump-input=fail
|
||||
|
||||
// CHECK-NOT: tf.PartitionedCall
|
||||
// CHECK-NOT: module @_tpu_v1_compat_outlined
|
||||
|
||||
module {
|
||||
func @control_input(%arg0: tensor<i1>) -> tensor<i32> {
|
||||
%0:4 = tf_executor.graph {
|
||||
%outputs:4, %control = tf_executor.island wraps "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @_tpu_v1_compat_outlined::@_tpu_v1_compat_outlined_func0} : (tensor<i1>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>)
|
||||
tf_executor.fetch %outputs#0, %outputs#1, %outputs#2, %outputs#3 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
|
||||
}
|
||||
return %0#0 : tensor<i32>
|
||||
}
|
||||
module @_tpu_v1_compat_outlined {
|
||||
func @_tpu_v1_compat_outlined_func0(%arg0: tensor<i1>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
|
||||
"tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "device", num_replicas = 1 : i64, topology = "topology"} : () -> ()
|
||||
%0 = "tf.opA"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i1>) -> tensor<i32>
|
||||
%1 = "tf.While"(%0) {body = @while_body_with_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "A", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
|
||||
%2 = "tf.While"(%0) {body = @while_body_without_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "C", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
|
||||
%3 = "tf.While"(%0) {body = @while_body_with_cluster_attr, cond = @while_cond_without_cluster_attr, is_stateless = false, name = "E", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
|
||||
return %0, %1, %2, %3 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
|
||||
}
|
||||
func @while_body_with_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
|
||||
%0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
func @while_cond_with_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
|
||||
%0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
func @while_body_without_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
|
||||
%0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i32>
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
func @while_cond_without_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
|
||||
%0 = "tf.PartionedCalledOp"(%arg0) {f = @callee_func} : (tensor<i32>) -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
func @callee_func(%arg0: tensor<i32>) -> tensor<i1> {
|
||||
%0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,48 @@
|
||||
// RUN: tf-opt %s -tf-executor-tpu-v1-island-outlining | FileCheck %s --dump-input=fail
|
||||
|
||||
// CHECK: func @control_input
|
||||
// CHECK-NOT: func @
|
||||
// CHECK-LABEL: module @_tpu_v1_compat_outlined
|
||||
// CHECK: @_tpu_v1_compat_outlined_func0
|
||||
// CHECK: func @while_body_with_cluster_attr
|
||||
// CHECK: func @while_cond_with_cluster_attr
|
||||
// CHECK: func @while_body_without_cluster_attr
|
||||
// CHECK: func @while_cond_without_cluster_attr
|
||||
// CHECK: func @callee_func
|
||||
module {
|
||||
func @control_input(%arg0: tensor<i1>) -> tensor<i32> {
|
||||
%0:4 = tf_executor.graph {
|
||||
%outputs:4, %control = tf_executor.island {
|
||||
"tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
|
||||
%1 = "tf.opA"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i1>) -> tensor<i32>
|
||||
%2 = "tf.While"(%1) {body = @while_body_with_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "A", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
|
||||
%3 = "tf.While"(%1) {body = @while_body_without_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "C", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
|
||||
%4 = "tf.While"(%1) {body = @while_body_with_cluster_attr, cond = @while_cond_without_cluster_attr, is_stateless = false, name = "E", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
|
||||
tf_executor.yield %1, %2, %3, %4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
|
||||
}
|
||||
tf_executor.fetch %outputs#0, %outputs#1, %outputs#2, %outputs#3 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
|
||||
|
||||
}
|
||||
return %0#0 : tensor<i32>
|
||||
}
|
||||
func @while_body_with_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
|
||||
%0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
func @while_cond_with_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
|
||||
%0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
func @while_body_without_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
|
||||
%0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i32>
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
func @while_cond_without_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
|
||||
%0 = "tf.PartionedCalledOp"(%arg0) { f = @callee_func} : (tensor<i32>) -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
func @callee_func(%arg0: tensor<i32>) -> tensor<i1> {
|
||||
%0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
}
|
@ -1,41 +1,24 @@
|
||||
// RUN: tf-opt %s -tf-layout-assignment=force-data-format=NCHW -verify-diagnostics | FileCheck %s
|
||||
// RUN: tf-opt %s -tf-layout-optimization=force-data-format=NCHW -verify-diagnostics | FileCheck %s --dump-input=always
|
||||
|
||||
// CHECK-LABEL: func @transposeBiasAdd
|
||||
func @transposeBiasAdd(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
|
||||
func @transposeBiasAdd(%arg0: tensor<1x8x4x4xf32>, %arg1: tensor<8xf32>) -> tensor<1x8x4x4xf32> {
|
||||
|
||||
// Check that BiasAdd was converted to forced data format, and layout
|
||||
// dependent arguments and results passed through transpose nodes.
|
||||
// Convert input: NCHW -> NHWC
|
||||
%0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
|
||||
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x8x4x4xf32>, tensor<4xi64>) -> tensor<1x4x4x8xf32>
|
||||
|
||||
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
|
||||
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
|
||||
// CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%[[ARG_TRANSPOSE]], %arg1) {data_format = "NCHW"} {{.*}} tensor<1x8x4x4xf32>
|
||||
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
|
||||
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[BIAS_ADD]], %[[RES_PERM]])
|
||||
// CHECK: return %[[RES_TRANSPOSE]]
|
||||
%0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NHWC"} : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
|
||||
// Compute in NHWC
|
||||
%2 = "tf.BiasAdd"(%1, %arg1) {data_format = "NHWC"} : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
|
||||
|
||||
return %0 : tensor<1x4x4x8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @transposeBiasAddWithDefaultAttr
|
||||
func @transposeBiasAddWithDefaultAttr(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
|
||||
|
||||
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
|
||||
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
|
||||
// CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%[[ARG_TRANSPOSE]], %arg1) {data_format = "NCHW"} {{.*}} tensor<1x8x4x4xf32>
|
||||
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
|
||||
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[BIAS_ADD]], %[[RES_PERM]])
|
||||
// CHECK: return %[[RES_TRANSPOSE]]
|
||||
%0 = "tf.BiasAdd"(%arg0, %arg1) : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
|
||||
|
||||
return %0 : tensor<1x4x4x8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @transposeBiasWithUnknownShape
|
||||
func @transposeBiasWithUnknownShape(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<8xf32>) -> tensor<*xf32> {
|
||||
|
||||
// CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%[[ARG_TRANSPOSE]], %arg1) {data_format = "NCHW"} {{.*}} tensor<*xf32>
|
||||
%0 = "tf.BiasAdd"(%arg0, %arg1) : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<*xf32>
|
||||
|
||||
return %0 : tensor<*xf32>
|
||||
// Convert result back: NHWC -> NCHW
|
||||
%3 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
|
||||
%4 = "tf.Transpose"(%2, %3) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
|
||||
|
||||
// Check that BiasAdd computed in NCHW format, and all redundant transpose
|
||||
// operations removed from the function.
|
||||
|
||||
// CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW"} {{.*}} tensor<1x8x4x4xf32>
|
||||
// CHECK: return %[[BIAS_ADD]]
|
||||
|
||||
return %4 : tensor<1x8x4x4xf32>
|
||||
}
|
@ -0,0 +1,75 @@
|
||||
// RUN: tf-opt %s -tf-layout-assignment=force-data-format=NCHW -verify-diagnostics | FileCheck %s --dump-input=always
|
||||
|
||||
// CHECK-LABEL: func @transposeBiasAdd
|
||||
func @transposeBiasAdd(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
|
||||
|
||||
// Check that BiasAdd was converted to forced data format, and layout
|
||||
// dependent arguments and results passed through transpose nodes.
|
||||
|
||||
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
|
||||
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
|
||||
// CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%[[ARG_TRANSPOSE]], %arg1) {data_format = "NCHW"} {{.*}} tensor<1x8x4x4xf32>
|
||||
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
|
||||
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[BIAS_ADD]], %[[RES_PERM]])
|
||||
// CHECK: return %[[RES_TRANSPOSE]]
|
||||
%0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NHWC"} : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
|
||||
|
||||
return %0 : tensor<1x4x4x8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @transposeBiasAddWithDefaultAttr
|
||||
func @transposeBiasAddWithDefaultAttr(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
|
||||
|
||||
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
|
||||
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
|
||||
// CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%[[ARG_TRANSPOSE]], %arg1) {data_format = "NCHW"} {{.*}} tensor<1x8x4x4xf32>
|
||||
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
|
||||
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[BIAS_ADD]], %[[RES_PERM]])
|
||||
// CHECK: return %[[RES_TRANSPOSE]]
|
||||
%0 = "tf.BiasAdd"(%arg0, %arg1) : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
|
||||
|
||||
return %0 : tensor<1x4x4x8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @transposeBiasWithUnknownShape
|
||||
func @transposeBiasWithUnknownShape(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<8xf32>) -> tensor<*xf32> {
|
||||
|
||||
// CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%[[ARG_TRANSPOSE]], %arg1) {data_format = "NCHW"} {{.*}} tensor<*xf32>
|
||||
%0 = "tf.BiasAdd"(%arg0, %arg1) : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<*xf32>
|
||||
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @transposeConv2D
|
||||
func @transposeConv2D(%input: tensor<1x32x32x3xf32>, %filter: tensor<1x1x3x8xf32>) -> tensor<1x32x32x8xf32> {
|
||||
|
||||
// IMPORTANT: Tensor shapes do not match convolution parameters (stride,
|
||||
// dilations, etc...). This test only verifies that changing convolution data
|
||||
// layout will update all the attributes.
|
||||
|
||||
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
|
||||
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
|
||||
|
||||
// CHECK: %[[CONV2D:[0-9]*]] = "tf.Conv2D"(%[[ARG_TRANSPOSE]], %arg1)
|
||||
// CHECK-SAME: data_format = "NCHW"
|
||||
// CHECK-SAME: dilations = [1, 4, 2, 3]
|
||||
// CHECK-SAME: explicit_paddings = [1, 2, 7, 8, 3, 4, 5, 6]
|
||||
// CHECK-SAME: padding = "EXPLICIT"
|
||||
// CHECK-SAME: strides = [5, 8, 6, 7]
|
||||
// CHECK-SAME: (tensor<1x3x32x32xf32>, tensor<1x1x3x8xf32>) -> tensor<1x8x32x32xf32>
|
||||
|
||||
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
|
||||
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[CONV2D]], %[[RES_PERM]])
|
||||
// CHECK: return %[[RES_TRANSPOSE]]
|
||||
|
||||
%0 = "tf.Conv2D"(%input, %filter)
|
||||
{
|
||||
data_format = "NHWC",
|
||||
dilations = [1, 2, 3, 4],
|
||||
explicit_paddings = [1, 2, 3, 4, 5, 6, 7, 8],
|
||||
padding = "EXPLICIT",
|
||||
strides = [5, 6, 7, 8]
|
||||
} : (tensor<1x32x32x3xf32>, tensor<1x1x3x8xf32>) -> tensor<1x32x32x8xf32>
|
||||
|
||||
return %0 : tensor<1x32x32x8xf32>
|
||||
}
|
@ -0,0 +1,35 @@
|
||||
// RUN: tf-opt %s -tf-layout-assignment=force-data-format=NHWC -verify-diagnostics | FileCheck %s --dump-input=always
|
||||
|
||||
// CHECK-LABEL: func @transposeConv2D
|
||||
func @transposeConv2D(%input: tensor<1x3x32x32xf32>, %filter: tensor<1x1x3x8xf32>) -> tensor<1x8x32x32xf32> {
|
||||
|
||||
// IMPORTANT: Tensor shapes do not match convolution parameters (stride,
|
||||
// dilations, etc...). This test only verifies that changing convolution data
|
||||
// layout will update all the attributes.
|
||||
|
||||
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
|
||||
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
|
||||
|
||||
// CHECK: %[[CONV2D:[0-9]*]] = "tf.Conv2D"(%[[ARG_TRANSPOSE]], %arg1)
|
||||
// CHECK-SAME: data_format = "NHWC"
|
||||
// CHECK-SAME: dilations = [1, 3, 4, 2]
|
||||
// CHECK-SAME: explicit_paddings = [1, 2, 5, 6, 7, 8, 3, 4]
|
||||
// CHECK-SAME: padding = "EXPLICIT"
|
||||
// CHECK-SAME: strides = [5, 7, 8, 6]
|
||||
// CHECK-SAME: (tensor<1x32x32x3xf32>, tensor<1x1x3x8xf32>) -> tensor<1x32x32x8xf32>
|
||||
|
||||
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
|
||||
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[CONV2D]], %[[RES_PERM]])
|
||||
// CHECK: return %[[RES_TRANSPOSE]]
|
||||
|
||||
%0 = "tf.Conv2D"(%input, %filter)
|
||||
{
|
||||
data_format = "NCHW",
|
||||
dilations = [1, 2, 3, 4],
|
||||
explicit_paddings = [1, 2, 3, 4, 5, 6, 7, 8],
|
||||
padding = "EXPLICIT",
|
||||
strides = [5, 6, 7, 8]
|
||||
} : (tensor<1x3x32x32xf32>, tensor<1x1x3x8xf32>) -> tensor<1x8x32x32xf32>
|
||||
|
||||
return %0 : tensor<1x8x32x32xf32>
|
||||
}
|
@ -0,0 +1,67 @@
|
||||
// RUN: tf-opt %s -tf-move-transposes=direction=begin -verify-diagnostics | FileCheck %s --dump-input=always
|
||||
|
||||
// CHECK-LABEL: func @move_across_single_op
|
||||
func @move_across_single_op(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
|
||||
|
||||
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
|
||||
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
|
||||
// CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%[[ARG_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
|
||||
// CHECK: return %[[TANH]]
|
||||
|
||||
%0 = "tf.Tanh"(%arg0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
|
||||
%1 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
|
||||
%2 = "tf.Transpose"(%0, %1) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
|
||||
|
||||
return %2 : tensor<1x8x4x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @move_across_multiple_ops
|
||||
func @move_across_multiple_ops(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
|
||||
|
||||
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
|
||||
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
|
||||
// CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%[[ARG_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
|
||||
// CHECK: %[[RELU:[0-9]*]] = "tf.Relu"(%[[TANH]]) {{.*}} tensor<1x8x4x4xf32>
|
||||
// CHECK: return %[[RELU]]
|
||||
|
||||
%0 = "tf.Tanh"(%arg0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
|
||||
%1 = "tf.Relu"(%0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
|
||||
|
||||
%2 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
|
||||
%3 = "tf.Transpose"(%1, %2) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
|
||||
|
||||
return %3 : tensor<1x8x4x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @move_across_multi_operand_op
|
||||
func @move_across_multi_operand_op(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
|
||||
|
||||
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
|
||||
// CHECK: %[[ARG0_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
|
||||
// CHECK: %[[ARG1_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg1, %[[ARG_PERM]])
|
||||
// CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%[[ARG0_TRANSPOSE]], %[[ARG1_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
|
||||
// CHECK: return %[[ADD]]
|
||||
|
||||
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1x4x4x8xf32>, tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
|
||||
%1 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
|
||||
%2 = "tf.Transpose"(%0, %1) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
|
||||
|
||||
return %2 : tensor<1x8x4x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @move_with_multiple_uses
|
||||
func @move_with_multiple_uses(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
|
||||
|
||||
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
|
||||
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
|
||||
// CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%[[ARG_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
|
||||
// CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%[[TANH]], %[[TANH]]) {{.*}} tensor<1x8x4x4xf32>
|
||||
// CHECK: return %[[ADD]]
|
||||
|
||||
%0 = "tf.Tanh"(%arg0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
|
||||
%1 = "tf.AddV2"(%0, %0) : (tensor<1x4x4x8xf32>, tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
|
||||
%2 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
|
||||
%3 = "tf.Transpose"(%1, %2) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
|
||||
|
||||
return %3 : tensor<1x8x4x4xf32>
|
||||
}
|
@ -0,0 +1,120 @@
|
||||
// RUN: tf-opt %s -tf-move-transposes=direction=end -verify-diagnostics | FileCheck %s --dump-input=always
|
||||
|
||||
// CHECK-LABEL: func @move_across_single_op
|
||||
func @move_across_single_op(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
|
||||
|
||||
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
|
||||
// CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%arg0) {{.*}} tensor<1x4x4x8xf32>
|
||||
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[TANH]], %[[RES_PERM]]) {{.*}} tensor<1x8x4x4xf32>
|
||||
// CHECK: return %[[RES_TRANSPOSE]]
|
||||
|
||||
%0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
|
||||
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
|
||||
%2 = "tf.Tanh"(%1) : (tensor<1x8x4x4xf32>) -> tensor<1x8x4x4xf32>
|
||||
|
||||
return %2 : tensor<1x8x4x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @move_across_multiple_ops
|
||||
func @move_across_multiple_ops(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
|
||||
|
||||
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
|
||||
// CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%arg0) {{.*}} tensor<1x4x4x8xf32>
|
||||
// CHECK: %[[RELU:[0-9]*]] = "tf.Relu"(%[[TANH]]) {{.*}} tensor<1x4x4x8xf32>
|
||||
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[RELU]], %[[RES_PERM]])
|
||||
// CHECK: return %[[RES_TRANSPOSE]]
|
||||
|
||||
%0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
|
||||
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
|
||||
%2 = "tf.Tanh"(%1) : (tensor<1x8x4x4xf32>) -> tensor<1x8x4x4xf32>
|
||||
%3 = "tf.Relu"(%2) : (tensor<1x8x4x4xf32>) -> tensor<1x8x4x4xf32>
|
||||
|
||||
return %3 : tensor<1x8x4x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @move_across_multi_operand_op
|
||||
func @move_across_multi_operand_op(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
|
||||
|
||||
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
|
||||
// CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%arg0, %arg1) {{.*}} tensor<1x4x4x8xf32>
|
||||
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[ADD]], %[[RES_PERM]])
|
||||
// CHECK: return %[[RES_TRANSPOSE]]
|
||||
|
||||
%0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
|
||||
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
|
||||
%2 = "tf.Transpose"(%arg1, %0) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
|
||||
%3 = "tf.AddV2"(%1, %2) : (tensor<1x8x4x4xf32>, tensor<1x8x4x4xf32>) -> tensor<1x8x4x4xf32>
|
||||
|
||||
return %3 : tensor<1x8x4x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @fold_into_max_pool
|
||||
func @fold_into_max_pool(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x56x56x64xf32> {
|
||||
|
||||
// MaxPool operand transpose must be folded into the op and MaxPool
|
||||
// must use NCHW data format with updated kernel size and strides.
|
||||
|
||||
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
|
||||
// CHECK: %[[MAX_POOL:[0-9]*]] = "tf.MaxPool"(%arg0) {data_format = "NCHW", ksize = [1, 1, 3, 3], padding = "SAME", strides = [1, 1, 2, 2]} : (tensor<1x64x112x112xf32>) -> tensor<1x64x56x56xf32>
|
||||
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[MAX_POOL]], %[[RES_PERM]])
|
||||
// CHECK: return %[[RES_TRANSPOSE]]
|
||||
|
||||
// Transpose NCHW -> NHWC
|
||||
%0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
|
||||
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x64x112x112xf32>, tensor<4xi64>) -> tensor<1x112x112x64xf32>
|
||||
|
||||
// Compute MaxPool in NHWC format
|
||||
%2 = "tf.MaxPool"(%1)
|
||||
{
|
||||
data_format = "NHWC", ksize = [1, 3, 3, 1],
|
||||
padding = "SAME", strides = [1, 2, 2, 1]
|
||||
} : (tensor<1x112x112x64xf32>) -> tensor<1x56x56x64xf32>
|
||||
|
||||
return %2 : tensor<1x56x56x64xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @fold_into_mean
|
||||
func @fold_into_mean(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x64xf32> {
|
||||
|
||||
// CHECK: %[[RED_IDX:[0-9]*]] = "tf.Const"() {value = dense<[2, 3]> : tensor<2xi64>}
|
||||
// CHECK: %[[MEAN:[0-9]*]] = "tf.Mean"(%arg0, %[[RED_IDX]])
|
||||
// CHECK-SAME: (tensor<1x64x112x112xf32>, tensor<2xi64>) -> tensor<1x64xf32>
|
||||
// CHECK: return %[[MEAN]]
|
||||
|
||||
// Transpose NCHW -> NHWC
|
||||
%0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
|
||||
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x64x112x112xf32>, tensor<4xi64>) -> tensor<1x112x112x64xf32>
|
||||
|
||||
// Compute Mean over spatial dimensions in NHWC format.
|
||||
%2 = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
%3 = "tf.Mean"(%1, %2) : (tensor<1x112x112x64xf32>, tensor<2xi64>) -> tensor<1x64xf32>
|
||||
|
||||
return %3 : tensor<1x64xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @fold_into_fused_batch_norm
|
||||
func @fold_into_fused_batch_norm(%arg0: tensor<1x64x112x112xf32>, %arg1: tensor<64xf32>) -> tensor<1x112x112x64xf32> {
|
||||
|
||||
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
|
||||
// CHECK: "tf.FusedBatchNormV3"(%arg0, {{.*}} {data_format = "NCHW"
|
||||
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%y, %[[RES_PERM]])
|
||||
// CHECK: return %[[RES_TRANSPOSE]]
|
||||
|
||||
// Transpose NCHW -> NHWC
|
||||
%0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
|
||||
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x64x112x112xf32>, tensor<4xi64>) -> tensor<1x112x112x64xf32>
|
||||
|
||||
// Compute FusedBatchNormV3 in NHWC format
|
||||
%2, %batch_mean, %batch_var, %reserve_1, %reserve_2, %reserve_3
|
||||
= "tf.FusedBatchNormV3"(%1, %arg1, %arg1, %arg1, %arg1)
|
||||
{
|
||||
data_format = "NHWC",
|
||||
epsilon = 1.001 : f32,
|
||||
exponential_avg_factor = 1.0 : f32,
|
||||
is_training = false
|
||||
}
|
||||
: (tensor<1x112x112x64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
|
||||
-> (tensor<1x112x112x64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
|
||||
|
||||
return %2#0 : tensor<1x112x112x64xf32>
|
||||
}
|
@ -0,0 +1,194 @@
|
||||
// RUN: tf-opt %s -tf-parallel-execute-to-islands | FileCheck %s --dump-input=fail
|
||||
|
||||
// CHECK-LABEL: func @check_regions_to_islands
|
||||
func @check_regions_to_islands() {
|
||||
tf_executor.graph {
|
||||
tf_executor.island() {
|
||||
"tf_device.parallel_execute"() ({
|
||||
tf_device.return
|
||||
},
|
||||
{
|
||||
tf_device.return
|
||||
}) {} : () -> ()
|
||||
tf_executor.yield
|
||||
}
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: %[[ISLAND_INPUT_CTL:[a-z_0-9]*]] = tf_executor.island {
|
||||
// CHECK-NEXT: tf_executor.yield
|
||||
// CHECK: %[[ISLAND_1_CTL:[a-z_0-9]*]] = tf_executor.island(%[[ISLAND_INPUT_CTL]]) {
|
||||
// CHECK: tf_executor.yield
|
||||
// CHECK: %[[ISLAND_2_CTL:[a-z_0-9]*]] = tf_executor.island(%[[ISLAND_INPUT_CTL]]) {
|
||||
// CHECK: tf_executor.yield
|
||||
// CHECK: %{{.*}} = tf_executor.island(%[[ISLAND_1_CTL]], %[[ISLAND_2_CTL]]) {
|
||||
// CHECK-NEXT: tf_executor.yield
|
||||
|
||||
|
||||
// CHECK-LABEL: func @check_regions_to_islands_with_inputs
|
||||
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
|
||||
func @check_regions_to_islands_with_inputs(%arg0 : tensor<i1>) {
|
||||
tf_executor.graph {
|
||||
%1:2 = tf_executor.island {
|
||||
%2 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
|
||||
tf_executor.yield %2 : tensor<i1>
|
||||
}
|
||||
tf_executor.island() {
|
||||
"tf_device.parallel_execute"() ({
|
||||
%3 = "tf.opB"(%1#0) : (tensor<i1>) -> tensor<i1>
|
||||
tf_device.return %3 : tensor<i1>
|
||||
},
|
||||
{
|
||||
%5 = "tf.opC"(%1#0) : (tensor<i1>) -> tensor<i32>
|
||||
tf_device.return %5 : tensor<i32>
|
||||
}) {} : () -> (tensor<i1>, tensor<i32>)
|
||||
tf_executor.yield
|
||||
}
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: %[[INPUT_A:[a-z_0-9]*]], %{{.*}} = tf_executor.island {
|
||||
// CHECK-NEXT: %[[OP_A_OUTPUT:[a-z_0-9]*]] = "tf.opA"(%[[ARG_0]]) : (tensor<i1>) -> tensor<i1>
|
||||
// CHECK-NEXT: tf_executor.yield %[[OP_A_OUTPUT]] : tensor<i1>
|
||||
// CHECK: %[[INPUT_0:[a-z_0-9]*]], %[[INPUT_CONTROL:[a-z_0-9]*]] = tf_executor.island {
|
||||
// CHECK-NEXT: tf_executor.yield %[[INPUT_A]] : tensor<i1>
|
||||
// CHECK: %[[ISLAND_1_OUTPUT:[a-z_0-9]*]], %[[ISLAND_1_CTL:[a-z_0-9]*]] = tf_executor.island {
|
||||
// CHECK-NEXT: %[[OP_B_OUTPUT:[a-z_0-9]*]] = "tf.opB"(%[[INPUT_0]]) : (tensor<i1>) -> tensor<i1>
|
||||
// CHECK: tf_executor.yield %[[OP_B_OUTPUT]] : tensor<i1>
|
||||
// CHECK: %[[ISLAND_2_OUTPUT:[a-z_0-9]*]], %[[ISLAND_2_CTL:[a-z_0-9]*]] = tf_executor.island {
|
||||
// CHECK-NEXT: %[[OP_C_OUTPUT:[a-z_0-9]*]] = "tf.opC"(%outputs_0) : (tensor<i1>) -> tensor<i32>
|
||||
// CHECK: tf_executor.yield %[[OP_C_OUTPUT]] : tensor<i32>
|
||||
// CHECK: %{{.*}} = tf_executor.island(%[[ISLAND_1_CTL]], %[[ISLAND_2_CTL]]) {
|
||||
// CHECK-NEXT: tf_executor.yield
|
||||
|
||||
|
||||
// CHECK-LABEL: func @check_input_sink_island_forwards_control_inputs
|
||||
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
|
||||
func @check_input_sink_island_forwards_control_inputs(%arg0 : tensor<i1>) {
|
||||
tf_executor.graph {
|
||||
%1:2 = tf_executor.island {
|
||||
%2 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
|
||||
tf_executor.yield %2 : tensor<i1>
|
||||
}
|
||||
%7 = tf_executor.ControlTrigger {}
|
||||
%8 = tf_executor.ControlTrigger {}
|
||||
tf_executor.island(%7, %8) {
|
||||
"tf_device.parallel_execute"() ({
|
||||
%3 = "tf.opB"(%1#0) : (tensor<i1>) -> tensor<i1>
|
||||
tf_device.return %3 : tensor<i1>
|
||||
},
|
||||
{
|
||||
%5 = "tf.opC"() : () -> tensor<i32>
|
||||
tf_device.return %5 : tensor<i32>
|
||||
}) {} : () -> (tensor<i1>, tensor<i32>)
|
||||
tf_executor.yield
|
||||
}
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: %[[INPUT_A:[a-z_0-9]*]], %{{.*}} = tf_executor.island {
|
||||
// CHECK-NEXT: %[[OP_A_OUTPUT:[a-z_0-9]*]] = "tf.opA"(%[[ARG_0]]) : (tensor<i1>) -> tensor<i1>
|
||||
// CHECK-NEXT: tf_executor.yield %[[OP_A_OUTPUT]] : tensor<i1>
|
||||
// CHECK: %[[CT_0:[0-9]*]] = tf_executor.ControlTrigger
|
||||
// CHECK: %[[CT_1:[0-9]*]] = tf_executor.ControlTrigger
|
||||
// CHECK: %[[INPUT_0:[a-z_0-9]*]], %[[INPUT_CONTROL:[a-z_0-9]*]] = tf_executor.island(%[[CT_0]], %[[CT_1]]) {
|
||||
// CHECK-NEXT: tf_executor.yield %[[INPUT_A]] : tensor<i1>
|
||||
// CHECK: %[[ISLAND_1_OUTPUT:[a-z_0-9]*]], %[[ISLAND_1_CTL:[a-z_0-9]*]] = tf_executor.island {
|
||||
// CHECK-NEXT: %[[OP_B_OUTPUT:[a-z_0-9]*]] = "tf.opB"(%[[INPUT_0]]) : (tensor<i1>) -> tensor<i1>
|
||||
// CHECK: tf_executor.yield %[[OP_B_OUTPUT]] : tensor<i1>
|
||||
// CHECK: %[[ISLAND_2_OUTPUT:[a-z_0-9]*]], %[[ISLAND_2_CTL:[a-z_0-9]*]] = tf_executor.island(%[[INPUT_CONTROL]]) {
|
||||
// CHECK-NEXT: %[[OP_C_OUTPUT:[a-z_0-9]*]] = "tf.opC"() : () -> tensor<i32>
|
||||
// CHECK: tf_executor.yield %[[OP_C_OUTPUT]] : tensor<i32>
|
||||
// CHECK: %{{.*}} = tf_executor.island(%[[ISLAND_1_CTL]], %[[ISLAND_2_CTL]]) {
|
||||
// CHECK-NEXT: tf_executor.yield
|
||||
|
||||
|
||||
// CHECK-LABEL: func @check_control_dep_added_when_region_does_not_have_inputs
|
||||
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
|
||||
func @check_control_dep_added_when_region_does_not_have_inputs(%arg0 : tensor<i1>) {
|
||||
tf_executor.graph {
|
||||
%1:2 = tf_executor.island {
|
||||
%2 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
|
||||
tf_executor.yield %2 : tensor<i1>
|
||||
}
|
||||
%7:3 = tf_executor.island() {
|
||||
%8:2 = "tf_device.parallel_execute"() (
|
||||
{
|
||||
%3 = "tf.opB"() : () -> tensor<i1>
|
||||
tf_device.return %3 : tensor<i1>
|
||||
},
|
||||
{
|
||||
%5 = "tf.opC"(%1#0) : (tensor<i1>) -> tensor<i32>
|
||||
tf_device.return %5 : tensor<i32>
|
||||
}
|
||||
) {} : () -> (tensor<i1>, tensor<i32>)
|
||||
|
||||
tf_executor.yield %8#0, %8#1 : tensor<i1>, tensor<i32>
|
||||
}
|
||||
|
||||
tf_executor.island {
|
||||
"tf.opD"(%7#0, %7#1) : (tensor<i1>, tensor<i32>) -> ()
|
||||
tf_executor.yield
|
||||
}
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: %[[INPUT_A:[a-z_0-9]*]], %{{.*}} = tf_executor.island {
|
||||
// CHECK-NEXT: %[[OP_A_OUTPUT:[a-z_0-9]*]] = "tf.opA"(%[[ARG_0]]) : (tensor<i1>) -> tensor<i1>
|
||||
// CHECK-NEXT: tf_executor.yield %[[OP_A_OUTPUT]] : tensor<i1>
|
||||
// CHECK: %[[INPUT_0:[a-z_0-9]*]], %[[INPUT_CTL:[a-z_0-9]*]] = tf_executor.island {
|
||||
// CHECK-NEXT: tf_executor.yield %[[INPUT_A]] : tensor<i1>
|
||||
// CHECK: %[[ISLAND_1_OUTPUT:[a-z_0-9]*]], %{{.*}} = tf_executor.island(%[[INPUT_CTL]]) {
|
||||
// CHECK-NEXT: %[[OP_B_OUTPUT:[a-z_0-9]*]] = "tf.opB"() : () -> tensor<i1>
|
||||
// CHECK: tf_executor.yield %[[OP_B_OUTPUT]] : tensor<i1>
|
||||
// CHECK: %[[ISLAND_2_OUTPUT:[a-z_0-9]*]], %{{.*}} = tf_executor.island {
|
||||
// CHECK-NEXT: %[[OP_C_OUTPUT:[a-z_0-9]*]] = "tf.opC"(%outputs_0) : (tensor<i1>) -> tensor<i32>
|
||||
// CHECK: tf_executor.yield %[[OP_C_OUTPUT]] : tensor<i32>
|
||||
// CHECK: %{{.*}} = tf_executor.island {
|
||||
// CHECK-NEXT: tf_executor.yield %[[ISLAND_1_OUTPUT]], %[[ISLAND_2_OUTPUT]]
|
||||
|
||||
|
||||
// CHECK-LABEL: func @check_output_barrier_correctly_forwards_outputs
|
||||
func @check_output_barrier_correctly_forwards_outputs(%arg0 : tensor<i1>) -> tensor<i1> {
|
||||
%0 = tf_executor.graph {
|
||||
%1:2 = tf_executor.island {
|
||||
%2 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
|
||||
tf_executor.yield %2 : tensor<i1>
|
||||
}
|
||||
%8:3 = tf_executor.island() {
|
||||
%7:2 = "tf_device.parallel_execute"() ({
|
||||
%3 = "tf.opB"() : () -> tensor<i1>
|
||||
tf_device.return %3 : tensor<i1>
|
||||
},
|
||||
{
|
||||
%5 = "tf.opC"(%1#0) : (tensor<i1>) -> tensor<i32>
|
||||
tf_device.return %5 : tensor<i32>
|
||||
}) {} : () -> (tensor<i1>, tensor<i32>)
|
||||
tf_executor.yield %7#0, %7#1 : tensor<i1>, tensor<i32>
|
||||
}
|
||||
tf_executor.fetch %8#0 : tensor<i1>
|
||||
}
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK: %[[INPUT_A:[a-z_0-9]*]], %{{.*}} = tf_executor.island {
|
||||
// CHECK-NEXT: %[[OP_A_OUTPUT:[a-z_0-9]*]] = "tf.opA"(%[[ARG_0]]) : (tensor<i1>) -> tensor<i1>
|
||||
// CHECK-NEXT: tf_executor.yield %[[OP_A_OUTPUT]] : tensor<i1>
|
||||
// CHECK: %[[INPUT_0:[a-z_0-9]*]], %[[INPUT_CTL:[a-z_0-9]*]] = tf_executor.island {
|
||||
// CHECK-NEXT: tf_executor.yield %[[INPUT_A]] : tensor<i1>
|
||||
// CHECK: %[[ISLAND_1_OUTPUT:[a-z_0-9]*]], %{{.*}} = tf_executor.island(%[[INPUT_CTL]]) {
|
||||
// CHECK-NEXT: %[[OP_B_OUTPUT:[a-z_0-9]*]] = "tf.opB"() : () -> tensor<i1>
|
||||
// CHECK: tf_executor.yield %[[OP_B_OUTPUT]] : tensor<i1>
|
||||
// CHECK: %[[ISLAND_2_OUTPUT:[a-z_0-9]*]], %{{.*}} = tf_executor.island {
|
||||
// CHECK-NEXT: %[[OP_C_OUTPUT:[a-z_0-9]*]] = "tf.opC"(%[[INPUT_0]]) : (tensor<i1>) -> tensor<i32>
|
||||
// CHECK: tf_executor.yield %[[OP_C_OUTPUT]] : tensor<i32>
|
||||
// CHECK: %[[OUTPUT_SINK_OUTPUT:[a-z_0-9]*]]:2, %[[OUTPUT_SINK_CTL:[a-z_0-9]*]] = tf_executor.island {
|
||||
// CHECK-NEXT: tf_executor.yield %[[ISLAND_1_OUTPUT]], %[[ISLAND_2_OUTPUT]] : tensor<i1>, tensor<i32>
|
@ -542,3 +542,116 @@ func @if_else(%arg0: tensor<*x!tf.resource<tensor<4xf32>>>, %arg1: tensor<*x!tf.
|
||||
-> (tensor<*x!tf.resource<tensor<4xf32>>>) {
|
||||
return %arg1 : tensor<*x!tf.resource<tensor<4xf32>>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that the pass lifts resources on two partitioned call ops sharing the
|
||||
// same callee. The lifting should clone the callee then modify the clone.
|
||||
|
||||
// CHECK-LABEL: @launch_with_partitioned_call
|
||||
func @launch_with_partitioned_call() -> tensor<f32> {
|
||||
// CHECK: %[[VH:.*]] = "tf.VarHandleOp"()
|
||||
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
|
||||
// CHECK: %[[CONST:.*]] = "tf.Const"()
|
||||
%1 = "tf.Const"() {value = dense<10.0> : tensor<f32>} : () -> tensor<f32>
|
||||
// CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VH]])
|
||||
// CHECK: %[[LAUNCH:.*]] = "tf_device.launch"()
|
||||
%2 = "tf_device.launch"() ( {
|
||||
// CHECK: %[[PC0:.*]] = "tf.PartitionedCall"(%[[CONST]], %[[READ]], %[[CONST]])
|
||||
// CHECK-SAME: f = @callee_resource_lifted
|
||||
%3 = "tf.PartitionedCall"(%1, %0, %1) {f = @callee, config = "", config_proto = "", executor_type = ""}
|
||||
: (tensor<f32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: %[[PC1:.*]] = "tf.PartitionedCall"(%[[CONST]], %[[READ]], %[[CONST]])
|
||||
// CHECK-SAME: f = @callee_resource_lifted
|
||||
%4 = "tf.PartitionedCall"(%1, %0, %1) {f = @callee, config = "", config_proto = "", executor_type = ""}
|
||||
: (tensor<f32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[PC0]], %[[PC1]])
|
||||
%5 = "tf.AddV2"(%3, %4) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: tf_device.return %[[ADD]] : tensor<f32>
|
||||
tf_device.return %5 : tensor<f32>
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<f32>
|
||||
return %2 : tensor<f32>
|
||||
}
|
||||
// CHECK: @callee(%[[OA0:.*]]: tensor<f32>, %[[OA1:.*]]: tensor<*x!tf.resource<tensor<f32>>>, %[[OA2:.*]]: tensor<f32>) -> tensor<f32>
|
||||
func @callee(%arg0: tensor<f32>, %arg1: tensor<*x!tf.resource<tensor<f32>>>, %arg2: tensor<f32>) -> tensor<f32> {
|
||||
// CHECK: "tf.ReadVariableOp"(%[[OA1]])
|
||||
%0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||
%1 = "tf.AddV2"(%0, %arg0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
%2 = "tf.AddV2"(%1, %arg2) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
return %2 : tensor<f32>
|
||||
}
|
||||
// CHECK: func @callee_resource_lifted(%[[A0:.*]]: tensor<f32>, %[[A1:.*]]: tensor<f32>, %[[A2:.*]]: tensor<f32>) -> tensor<f32>
|
||||
// CHECK-NEXT: %[[ADD0:.*]] = "tf.AddV2"(%[[A1]], %[[A0]])
|
||||
// CHECK-NEXT: %[[ADD1:.*]] = "tf.AddV2"(%[[ADD0]], %[[A2]])
|
||||
// CHECK-NEXT: return %[[ADD1]]
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that the pass lifts resources on two stateful partitioned call ops
|
||||
// sharing the same callee. The lifting should clone the callee then modify the
|
||||
// clone.
|
||||
|
||||
// CHECK-LABEL: @launch_with_stateful_partitioned_call
|
||||
func @launch_with_stateful_partitioned_call() -> () {
|
||||
// CHECK: %[[VH0:.*]] = "tf.VarHandleOp"()
|
||||
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
|
||||
// CHECK: %[[VH1:.*]] = "tf.VarHandleOp"()
|
||||
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource<tensor<f32>>>
|
||||
// CHECK: %[[CONST:.*]] = "tf.Const"()
|
||||
%2 = "tf.Const"() {value = dense<10.0> : tensor<f32>} : () -> tensor<f32>
|
||||
// CHECK-DAG: %[[READ0:.*]] = "tf.ReadVariableOp"(%[[VH0]])
|
||||
// CHECK-DAG: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[VH1]])
|
||||
// CHECK: %[[LAUNCH:.*]] = "tf_device.launch"()
|
||||
"tf_device.launch"() ( {
|
||||
// CHECK: %[[PC0:.*]] = "tf.StatefulPartitionedCall"(%[[READ0]], %[[READ1]], %[[CONST]])
|
||||
// CHECK-SAME: f = @callee_resource_lifted
|
||||
%3 = "tf.StatefulPartitionedCall"(%0, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""}
|
||||
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
|
||||
// CHECK: %[[PC1:.*]] = "tf.StatefulPartitionedCall"(%[[PC0]], %[[READ1]], %[[CONST]])
|
||||
// CHECK-SAME: f = @callee_resource_lifted
|
||||
%4 = "tf.StatefulPartitionedCall"(%3, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""}
|
||||
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
|
||||
// CHECK: tf_device.return %[[PC1]] : tensor<f32>
|
||||
tf_device.return
|
||||
// CHECK: {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<f32>
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
|
||||
// CHECK: "tf.AssignVariableOp"(%[[VH0]], %[[LAUNCH]])
|
||||
return
|
||||
}
|
||||
// CHECK: @callee(%[[OA0:.*]]: tensor<*x!tf.resource<tensor<f32>>>, %[[OA1:.*]]: tensor<*x!tf.resource<tensor<f32>>>, %[[OA2:.*]]: tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
|
||||
func @callee(%arg0: tensor<*x!tf.resource<tensor<f32>>>, %arg1: tensor<*x!tf.resource<tensor<f32>>>, %arg2: tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>> {
|
||||
// CHECK: "tf.ReadVariableOp"(%[[OA1]])
|
||||
%0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||
%1 = "tf.AddV2"(%0, %arg2) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
"tf.AssignVariableOp"(%arg0, %1) {dtype = i32} : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
|
||||
return %arg0 : tensor<*x!tf.resource<tensor<f32>>>
|
||||
}
|
||||
// CHECK: func @callee_resource_lifted(%[[A0:.*]]: tensor<f32>, %[[A1:.*]]: tensor<f32>, %[[A2:.*]]: tensor<f32>) -> tensor<f32>
|
||||
// CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[A1]], %[[A2]])
|
||||
// CHECK-NEXT: return %[[ADD]]
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that the pass reports error on called function that has resource output
|
||||
// which doesn't alias an input.
|
||||
|
||||
func @launch_with_stateful_partitioned_call() -> () {
|
||||
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
|
||||
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource<tensor<f32>>>
|
||||
%2 = "tf.Const"() {value = dense<10.0> : tensor<f32>} : () -> tensor<f32>
|
||||
"tf_device.launch"() ( {
|
||||
%3 = "tf.StatefulPartitionedCall"(%0, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""}
|
||||
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
|
||||
%4 = "tf.StatefulPartitionedCall"(%3, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""}
|
||||
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
|
||||
tf_device.return
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
|
||||
return
|
||||
}
|
||||
// expected-error @+1 {{Unsupported function call: resource return value does not alias an input.}}
|
||||
func @callee(%arg0: tensor<*x!tf.resource<tensor<f32>>>, %arg1: tensor<*x!tf.resource<tensor<f32>>>, %arg2: tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>> {
|
||||
%0 = "tf._Unknown_"() : () -> tensor<*x!tf.resource<tensor<f32>>>
|
||||
return %0 : tensor<*x!tf.resource<tensor<f32>>>
|
||||
}
|
||||
|
@ -45,6 +45,17 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
|
||||
return %1 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<?xf32>
|
||||
func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
|
||||
br ^bb1
|
||||
^bb1:
|
||||
// CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
// CHECK: return %[[IDENTITY]] : tensor<?xf32>
|
||||
%ret = "tf.Identity"(%arg0) : (tensor<?xf32>) -> tensor<*xf32>
|
||||
return %ret : tensor<*xf32>
|
||||
}
|
||||
|
||||
|
||||
// Tests the case where an inference opportunity relies on folding.
|
||||
|
||||
// CHECK-LABEL: func @simple_folding
|
||||
|
@ -46,7 +46,7 @@ class TestModule(tf.Module):
|
||||
# CHECK: "tf_saved_model.global_tensor"() {sym_name = "[[CONST:[a-zA-Z_0-9]+]]", tf_saved_model.exported_names = [], type = tensor<f32>, value = dense<4.300000e+01> : tensor<f32>} : () -> ()
|
||||
# CHECK: func {{@[a-zA-Z_0-9]+}}(
|
||||
# CHECK-SAME: %arg0: tensor<f32> {tf_saved_model.index_path = [0]},
|
||||
# CHECK-SAME: %arg1: tensor<*x!tf.resource> {tf_saved_model.bound_input = @[[VAR]]},
|
||||
# CHECK-SAME: %arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @[[VAR]]},
|
||||
# CHECK-SAME: %arg2: tensor<f32> {tf_saved_model.bound_input = @[[CONST]]}) -> (
|
||||
# CHECK-SAME: tensor<f32> {tf_saved_model.index_path = []})
|
||||
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["some_function"]
|
||||
|
@ -46,7 +46,7 @@ class TestModule(tf.Module):
|
||||
#
|
||||
# CHECK: func {{@[a-zA-Z_0-9]+}}(
|
||||
# CHECK-SAME: %arg0: tensor<f32> {tf_saved_model.index_path = [0]},
|
||||
# CHECK-SAME: %arg1: tensor<*x!tf.resource> {tf_saved_model.bound_input = {{@[a-zA-Z_0-9]+}}}
|
||||
# CHECK-SAME: %arg1: tensor<!tf.resource<{{.*}}>> {tf_saved_model.bound_input = {{@[a-zA-Z_0-9]+}}}
|
||||
# CHECK-SAME: ) -> (
|
||||
# CHECK-SAME: tensor<f32> {tf_saved_model.index_path = [0]},
|
||||
# CHECK-SAME: tensor<f32> {tf_saved_model.index_path = [1]})
|
||||
@ -55,7 +55,7 @@ class TestModule(tf.Module):
|
||||
#
|
||||
# CHECK: func {{@[a-zA-Z_0-9]+}}(
|
||||
# CHECK-SAME: %arg0: tensor<f32> {tf_saved_model.index_path = [0]},
|
||||
# CHECK-SAME: %arg1: tensor<*x!tf.resource> {tf_saved_model.bound_input = {{@[a-zA-Z_0-9]+}}}
|
||||
# CHECK-SAME: %arg1: tensor<!tf.resource<{{.*}}>> {tf_saved_model.bound_input = {{@[a-zA-Z_0-9]+}}}
|
||||
# CHECK-SAME: ) -> (
|
||||
# CHECK-SAME: tensor<f32> {tf_saved_model.index_path = [0]},
|
||||
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [1]})
|
||||
|
@ -25,8 +25,8 @@ module attributes {tf_saved_model.semantics} {
|
||||
// CHECK: tf_saved_model.global_tensor
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<1.0> : tensor<f32> } : () -> ()
|
||||
|
||||
// CHECK: func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v})
|
||||
func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v})
|
||||
// CHECK: func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
|
||||
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
|
||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
// CHECK-NOT: tf.Const
|
||||
return
|
||||
|
@ -26,7 +26,7 @@ module attributes {tf_saved_model.semantics} {
|
||||
func @__concrete_function_run_computation(
|
||||
%arg0: tensor<f32> {tf_saved_model.index_path = [0, "foo"]},
|
||||
%arg1: tensor<1x64xf32> {tf_saved_model.bound_input = @some_constant},
|
||||
%arg2: tensor<*x!tf.resource> {tf_saved_model.bound_input = @some_variable}
|
||||
%arg2: tensor<!tf.resource<tensor<?x64xf32>>> {tf_saved_model.bound_input = @some_variable}
|
||||
) -> (
|
||||
tensor<f32> {tf_saved_model.index_path = [0, "bar"]}
|
||||
) attributes { tf_saved_model.exported_names = ["some_func"] }
|
||||
|
@ -219,8 +219,8 @@ module attributes {tf_saved_model.semantics} {
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.0> : tensor<f32> } : () -> ()
|
||||
// expected-error@+1 {{duplicate 'tf_saved_model.bound_input' binding}}
|
||||
func @f(
|
||||
%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v},
|
||||
%arg1: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}
|
||||
%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v},
|
||||
%arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}
|
||||
) attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
return
|
||||
}
|
||||
@ -232,9 +232,9 @@ module attributes {tf_saved_model.semantics} {
|
||||
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<?xf32>, value = dense<1.> : tensor<1xf32> } : () -> ()
|
||||
// expected-error@+1 {{can only apply 'tf_saved_model' argument attributes to exported functions}}
|
||||
func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v})
|
||||
func @f(%arg0: tensor<!tf.resource<tensor<?xf32>>> {tf_saved_model.bound_input = @v})
|
||||
-> (tensor<?xf32> {tf_saved_model.index_path = []}) {
|
||||
%0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>) -> tensor<?xf32>
|
||||
%0 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<?xf32>>>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
}
|
||||
@ -244,7 +244,7 @@ module attributes {tf_saved_model.semantics} {
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<?xf32>, value = dense<1.> : tensor<1xf32> } : () -> ()
|
||||
// expected-error@+1 {{bound inputs for mutable 'tf_saved_model.global_tensor's must be tensors of '!tf.resource'}}
|
||||
// expected-error@+1 {{mutable bound input with type 'tensor<f32>' expected to have type 'tensor<!tf.resource<tensor<?xf32>>>'}}
|
||||
func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @v})
|
||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
return
|
||||
@ -257,7 +257,7 @@ module attributes {tf_saved_model.semantics} {
|
||||
|
||||
"tf_saved_model.global_tensor"() { sym_name = "v", type = tensor<1xf32>, value = dense<1.> : tensor<1xf32> } : () -> ()
|
||||
// expected-error@+1 {{bound input for immutable 'tf_saved_model.global_tensor' must match the global tensor's type}}
|
||||
func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v})
|
||||
func @f(%arg0: tensor<!tf.resource<tensor<1xf32>>> {tf_saved_model.bound_input = @v})
|
||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
return
|
||||
}
|
||||
|
@ -14,10 +14,10 @@ module attributes {tf_saved_model.semantics} {
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||
|
||||
// CHECK: func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @v})
|
||||
func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
|
||||
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
|
||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
// CHECK-NOT: tf.ReadVariableOp
|
||||
%val = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>) -> tensor<f32>
|
||||
%val = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||
// CHECK: return %arg0
|
||||
return %val : tensor<f32>
|
||||
}
|
||||
@ -35,12 +35,12 @@ module attributes {tf_saved_model.semantics} {
|
||||
// CHECK-SAME: } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||
|
||||
// CHECK: func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v})
|
||||
func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v})
|
||||
// CHECK: func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
|
||||
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
|
||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
%c0 = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
|
||||
// CHECK: tf.AssignVariableOp
|
||||
"tf.AssignVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor<f32>) -> ()
|
||||
"tf.AssignVariableOp"(%arg0, %c0) : (tensor<!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -57,10 +57,10 @@ module attributes {tf_saved_model.semantics} {
|
||||
// CHECK-SAME: } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", tf_saved_model.exported_names = ["v"], type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||
|
||||
// CHECK: func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v})
|
||||
func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
|
||||
// CHECK: func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
|
||||
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
|
||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
%val = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>) -> tensor<f32>
|
||||
%val = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||
return %val : tensor<f32>
|
||||
}
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user