Merge branch 'master' into dlpack

commit 0dad831803
Jinjing Zhou, 2020-02-21 20:04:36 +08:00, committed by GitHub
GPG Key ID: 4AEE18F83AFDEB23 (no known key found for this signature in database)
937 changed files with 32314 additions and 31199 deletions

View File

@ -319,11 +319,13 @@ build:xla --define=with_xla_support=true
# BEGIN TF REMOTE BUILD EXECUTION OPTIONS
# Options when using remote execution
# WARNING: THESE OPTIONS WON'T WORK IF YOU DO NOT HAVE PROPER AUTHENTICATION AND PERMISSIONS
# Flag to enable remote config
common --experimental_repo_remote_exec
build:rbe --action_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1
build:rbe --auth_enabled=true
build:rbe --auth_scope=https://www.googleapis.com/auth/cloud-source-tools
build:rbe --google_default_credentials
build:rbe --bes_backend=buildeventservice.googleapis.com
build:rbe --bes_best_effort=false
build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations"
build:rbe --bes_timeout=600s
build:rbe --define=EXECUTOR=remote
@ -336,7 +338,7 @@ build:rbe --spawn_strategy=remote,worker,standalone,local
test:rbe --test_env=USER=anon
# Attempt to minimize the amount of data transfer between bazel and the remote
# workers:
build:rbe --experimental_inmemory_jdeps_files --experimental_inmemory_dotd_files --experimental_remote_download_outputs=toplevel
build:rbe --remote_download_toplevel
build:rbe_linux --config=rbe
build:rbe_linux --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin"

View File

@ -10,13 +10,20 @@ labels: 'type:bug'
we only address code/doc bugs, performance issues, feature requests and
build/installation issues on GitHub. tag:bug_template
**System information** - Have I written custom code (as opposed to using a stock
example script provided in TensorFlow): - OS Platform and Distribution (e.g.,
Linux Ubuntu 16.04): - Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
the issue happens on mobile device: - TensorFlow installed from (source or
binary): - TensorFlow version (use command below): - Python version: - Bazel
version (if compiling from source): - GCC/Compiler version (if compiling from
source): - CUDA/cuDNN version: - GPU model and memory:
**System information**
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
- TensorFlow installed from (source or binary):
- TensorFlow version (use command below):
- Python version:
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version:
- GPU model and memory:
You can collect some of this information using our environment capture
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
@ -28,8 +35,9 @@ tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
**Describe the expected behavior**
**Code to reproduce the issue** Provide a reproducible test case that is the
bare minimum necessary to generate the problem.
**Standalone code to reproduce the issue**
Provide a reproducible test case that is the bare minimum necessary to generate
the problem. If possible, please share a link to Colab/Jupyter/any notebook.
**Other info / logs** Include any logs or source code that would be helpful to
diagnose the problem. If including tracebacks, please include the full

View File

@ -17,8 +17,14 @@ labels: 'comp:lite'
# Copy and paste here
```
**Standalone code to reproduce the issue**
Provide a reproducible test case that is the bare minimum necessary to generate
the problem. If possible, please share a link to Colab/Jupyter/any notebook.
Also, please include a link to a GraphDef or the model if possible.
**Any other info / logs**
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.
Include any logs or source code that would be helpful to diagnose the problem.
If including tracebacks, please include the full traceback. Large logs and files
should be attached.

View File

@ -1,6 +1,7 @@
---
name: TensorFlow Lite New Converter Issue
about: Use this template for reporting issues during model conversion to TFLite
labels: 'TFLiteConverter'
---
@ -12,6 +13,7 @@ about: Use this template for reporting issues during model conversion to TFLite
**Command used to run the converter or code if you're using the Python API**
If possible, please share a link to Colab/Jupyter/any notebook.
```
# Copy and paste here the exact command

View File

@ -11,13 +11,20 @@ As per our
we only address code/doc bugs, performance issues, feature requests and
build/installation issues on GitHub. tag:performance_template
**System information** - Have I written custom code (as opposed to using a stock
example script provided in TensorFlow): - OS Platform and Distribution (e.g.,
Linux Ubuntu 16.04): - Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
the issue happens on mobile device: - TensorFlow installed from (source or
binary): - TensorFlow version (use command below): - Python version: - Bazel
version (if compiling from source): - GCC/Compiler version (if compiling from
source): - CUDA/cuDNN version: - GPU model and memory:
**System information**
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
- TensorFlow installed from (source or binary):
- TensorFlow version (use command below):
- Python version:
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version:
- GPU model and memory:
You can collect some of this information using our environment capture
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
@ -29,8 +36,9 @@ tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
**Describe the expected behavior**
**Code to reproduce the issue** Provide a reproducible test case that is the
bare minimum necessary to generate the problem.
**Standalone code to reproduce the issue**
Provide a reproducible test case that is the bare minimum necessary to generate
the problem. If possible, please share a link to Colab/Jupyter/any notebook.
**Other info / logs** Include any logs or source code that would be helpful to
diagnose the problem. If including tracebacks, please include the full

View File

@ -113,3 +113,28 @@ http_archive(
"https://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip",
],
)
# Required for dependency @com_github_grpc_grpc
load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
grpc_deps()
load(
"@build_bazel_rules_apple//apple:repositories.bzl",
"apple_rules_dependencies",
)
apple_rules_dependencies()
load(
"@build_bazel_apple_support//lib:repositories.bzl",
"apple_support_dependencies",
)
apple_support_dependencies()
load("@upb//bazel:repository_defs.bzl", "bazel_version_repository")
bazel_version_repository(name = "bazel_version")

View File

@ -49,7 +49,7 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
_TF_WORKSPACE_ROOT = ''
_TF_BAZELRC = ''
_TF_CURRENT_BAZEL_VERSION = None
_TF_MIN_BAZEL_VERSION = '1.2.1'
_TF_MIN_BAZEL_VERSION = '2.0.0'
_TF_MAX_BAZEL_VERSION = '2.0.0'
NCCL_LIB_PATHS = [

View File

@ -505,13 +505,15 @@ selects.config_setting_group(
package_group(
name = "internal",
packages = [
# To pass open source testing in the pip Kokoros.
"//bazel_pip/tensorflow/...",
"//learning/brain/swift/x10/...",
"//perftools/accelerators/xprof/api/...",
"//third_party/py/autograph/...",
"//third_party/swift/tensorflow/x10/...",
"//tensorflow/...",
"//tensorflow_estimator/python/estimator/...",
"//tensorflow_models/official/...",
"//third_party/py/autograph/...",
"//third_party/swift/tensorflow/x10/...",
],
)
@ -545,8 +547,8 @@ cc_library(
name = "grpc",
visibility = ["//visibility:public"],
deps = select({
":linux_s390x": ["@grpc//:grpc_unsecure"],
"//conditions:default": ["@grpc"],
":linux_s390x": ["@com_github_grpc_grpc//:grpc_unsecure"],
"//conditions:default": ["@com_github_grpc_grpc//:grpc"],
}),
)
@ -554,8 +556,8 @@ cc_library(
name = "grpc++",
visibility = ["//visibility:public"],
deps = select({
":linux_s390x": ["@grpc//:grpc++_unsecure"],
"//conditions:default": ["@grpc//:grpc++"],
":linux_s390x": ["@com_github_grpc_grpc//:grpc++_unsecure"],
"//conditions:default": ["@com_github_grpc_grpc//:grpc++"],
}),
)

View File

@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/net.h"
#include "tensorflow/core/platform/platform.h"
@ -816,12 +817,15 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
const int num_inputs = input_shapes->num_items;
NodeDef node_def;
node_def.set_name(tfe_op->operation.Name());
node_def.set_op(tfe_op->operation.Name());
node_def.set_name(tfe_op->operation->Name());
node_def.set_op(tfe_op->operation->Name());
for (int i = 0; i < num_inputs; ++i) {
node_def.add_input("dummy_input");
}
tfe_op->operation.Attrs().FillAttrValueMap(node_def.mutable_attr());
tensorflow::down_cast<tensorflow::OperationInterface*>(
tfe_op->operation.get())
->Attrs()
.FillAttrValueMap(node_def.mutable_attr());
const tensorflow::OpRegistrationData* op_reg_data;
status->status =

View File

@ -28,6 +28,8 @@ tf_cuda_library(
"c_api_debug.cc",
"c_api_experimental.h",
"c_api_internal.h",
"operation_interface.cc",
"operation_interface.h",
"tensor_handle_interface.h",
],
hdrs = ["c_api.h"],
@ -56,6 +58,7 @@ tf_cuda_library(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/platform:casts",
"//tensorflow/core/platform:errors",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/profiler/lib:traceme",
@ -93,6 +96,7 @@ filegroup(
"c_api_experimental.h",
"c_api_internal.h",
"dlpack.h",
"operation_interface.h",
"tensor_handle_interface.h",
],
visibility = [
@ -105,6 +109,7 @@ tf_cuda_library(
name = "c_api_internal",
srcs = [
"c_api_experimental.h",
"operation_interface.h",
"tensor_handle_interface.h",
],
hdrs = ["c_api_internal.h"],
@ -129,6 +134,7 @@ tf_cuda_library(
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"@com_google_absl//absl/container:fixed_array",
],
)

View File

@ -27,7 +27,6 @@ limitations under the License.
// clang-format on
#include "absl/algorithm/container.h"
#include "absl/container/fixed_array.h"
#include "absl/memory/memory.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
@ -95,14 +94,6 @@ using tensorflow::string;
namespace {
const tensorflow::OpDef* GetOpDef(TFE_Op* op, TF_Status* status) {
const tensorflow::OpDef* op_def = op->operation.OpDef();
if (op_def) return op_def;
status->status =
tensorflow::OpDefForOp(op->operation.Name().c_str(), &op_def);
return op_def;
}
bool IsCPU(
absl::variant<tensorflow::Device*, tensorflow::CustomDevice*> variant) {
if (VariantDeviceIsCustom(variant)) {
@ -1125,9 +1116,13 @@ TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
return retval;
} else {
tensorflow::Tensor tensor;
if (IsCPU(handle_->device())) {
if (IsCPU(handle_->device()) || handle_->HasLocalMirror(nullptr)) {
const tensorflow::Tensor* src = nullptr;
*status = handle_->Tensor(&src);
if (handle_->HasLocalMirror(nullptr)) {
*status = handle_->TensorFromDevice(nullptr, &src);
} else {
*status = handle_->Tensor(&src);
}
if (!status->ok()) return nullptr;
tensor = *src;
} else {
@ -1135,6 +1130,13 @@ TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
CHECK_NE(ctx, nullptr);
*status = handle_->CopyToDevice(*ctx, ctx->HostCPU(), &tensor);
if (!status->ok()) return nullptr;
if (handle_->ImplicitMirroring()) {
*status = handle_->AddEmptyLocalMirror(nullptr);
if (!status->ok()) return nullptr;
Tensor mirror = tensor;
*status = handle_->SetTensor(std::move(mirror), nullptr);
if (!status->ok()) return nullptr;
}
}
return tensorflow::TF_TensorFromTensor(tensor, status);
}
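With implicit mirroring enabled, the first Resolve of a device-backed handle now caches a host-side copy on the handle, so subsequent resolves read the local mirror instead of copying from the device again. A minimal sketch of that flow through the public C API (assuming `hgpu` is a GPU-backed TFE_TensorHandle and `status` a valid TF_Status; error checks elided):
```
TFE_TensorHandleEnableImplicitMirroring(hgpu, status);
// First resolve copies device memory to host and stores a local mirror.
TF_Tensor* first = TFE_TensorHandleResolve(hgpu, status);
// Second resolve is served from the mirror, with no further device copy.
TF_Tensor* second = TFE_TensorHandleResolve(hgpu, status);
TF_DeleteTensor(first);
TF_DeleteTensor(second);
```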
@ -1199,18 +1201,11 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
}
if (dtype == TF_STRING || dtype == TF_RESOURCE ||
!tensorflow::DataTypeCanUseMemcpy(
static_cast<tensorflow::DataType>(dtype))) {
status->status = tensorflow::errors::InvalidArgument(
"Trying to create a tensor with a pointer to non-pod memory.");
deallocator(data, len, deallocator_arg);
return nullptr;
}
// TODO(apassos) do we need to wrap the deallocator here to make sure to sync
// the device?
TF_ManagedBuffer* buf =
new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
new TF_ManagedBuffer(data, len, deallocator, deallocator_arg,
/*owns_memory=*/false);
tensorflow::Tensor t(static_cast<tensorflow::DataType>(dtype),
tensorflow::TensorShape(dimvec), buf);
@ -1261,9 +1256,8 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status) {
std::unique_ptr<TFE_Op> new_op(
new TFE_Op{tensorflow::EagerOperation(ctx->context)});
status->status =
new_op->operation.Reset(op_or_function_name, nullptr, false, nullptr);
new TFE_Op{std::make_unique<tensorflow::OperationInterface>(ctx)});
status->status = new_op->operation->Reset(op_or_function_name, nullptr);
if (!status->status.ok()) {
new_op.reset();
}
@ -1273,49 +1267,51 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
void TFE_DeleteOp(TFE_Op* op) { delete op; }
void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
status->status = op->operation.SetDeviceName(device_name);
status->status = op->operation->SetDeviceName(device_name);
}
const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
tensorflow::Device* device = (op->operation.Device() == nullptr)
? op->operation.EagerContext().HostCPU()
: op->operation.Device();
return device->name().c_str();
return op->operation->DeviceName().c_str();
}
void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
op->operation.SetUseXla(enable);
#ifndef TENSORFLOW_EAGER_USE_XLA
#ifdef TENSORFLOW_EAGER_USE_XLA
tensorflow::Status s = op->operation->SetUseXla(enable);
if (!s.ok()) {
LOG(ERROR) << "Could not enable XLA compilation for op: " << s;
}
#else
LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not "
"built with XLA support.";
#endif // TENSORFLOW_EAGER_USE_XLA
}
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
tensorflow::TensorHandle* h =
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
input->handle.get())
->Handle();
op->operation.AddInput(h);
status->status = op->operation.MaybeInferSingleInputAttrs(h);
status->status = op->operation->AddInput(input->handle);
}
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
TF_Status* status) {
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>> handles(
num_inputs);
for (int i = 0; i < num_inputs; ++i) {
op->operation.AddInput(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
inputs[i]->handle.get())
->Handle());
handles[i].reset(inputs[i]->handle->Copy());
}
status->status = op->operation.InferInputListAttrs(num_inputs);
status->status = op->operation->AddInputList(handles);
}
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
unsigned char* is_list, TF_Status* status) {
TF_AttrType ret = TF_ATTR_INT;
status->status = tensorflow::AttrTypeByName(*op->operation.AttrTypes(),
attr_name, &ret, is_list);
const tensorflow::AttrTypeMap* attr_types_;
bool is_function;
status->status = tensorflow::AttrTypeMapForOp(op->operation->Name().c_str(),
&attr_types_, &is_function);
if (!status->status.ok()) {
return ret;
}
status->status =
tensorflow::AttrTypeByName(*attr_types_, attr_name, &ret, is_list);
return ret;
}
@ -1336,221 +1332,150 @@ TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
size_t length) {
op->operation.MutableAttrs()->Set(
attr_name,
tensorflow::StringPiece(static_cast<const char*>(value), length));
auto s = op->operation->SetAttrString(
attr_name, static_cast<const char*>(value), length);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
op->operation.MutableAttrs()->Set(attr_name, static_cast<int64>(value));
auto s = op->operation->SetAttrInt(attr_name, value);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
op->operation.MutableAttrs()->Set(attr_name, value);
auto s = op->operation->SetAttrFloat(attr_name, value);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
op->operation.MutableAttrs()->Set(attr_name, (value == 0) ? false : true);
auto s = op->operation->SetAttrBool(attr_name, (value == 0) ? false : true);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
op->operation.MutableAttrs()->Set(attr_name,
static_cast<tensorflow::DataType>(value));
auto s = op->operation->SetAttrType(attr_name, value);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
const int num_dims, TF_Status* out_status) {
if (num_dims > tensorflow::TensorShape::MaxDimensions()) {
TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
tensorflow::strings::StrCat(
"Value specified for `", attr_name, "` has ", num_dims,
" dimensions which is over the limit of ",
tensorflow::TensorShape::MaxDimensions(), ".")
.c_str());
return;
}
tensorflow::TensorShapeProto proto;
if (num_dims < 0) {
proto.set_unknown_rank(true);
} else {
for (int d = 0; d < num_dims; ++d) {
proto.add_dim()->set_size(dims[d]);
}
}
op->operation.MutableAttrs()->Set(attr_name, proto);
out_status->status = op->operation->SetAttrShape(attr_name, dims, num_dims);
}
void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
const TFE_Op* value) {
tensorflow::AttrValue attr_value;
tensorflow::NameAttrList* func = attr_value.mutable_func();
func->set_name(value->operation.Name());
value->operation.Attrs().FillAttrValueMap(func->mutable_attr());
op->operation.MutableAttrs()->Set(attr_name, attr_value);
auto s = op->operation->SetAttrFunction(attr_name, value->operation);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
const char* data, size_t length) {
tensorflow::AttrValue attr_value;
tensorflow::NameAttrList* func = attr_value.mutable_func();
func->set_name(data, length);
op->operation.MutableAttrs()->Set(attr_name, attr_value);
auto s = op->operation->SetAttrFunctionName(attr_name, data, length);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
TF_Status* status) {
tensorflow::Tensor t;
status->status = TF_TensorToTensor(tensor, &t);
if (status->status.ok()) op->operation.MutableAttrs()->Set(attr_name, t);
status->status = op->operation->SetAttrTensor(attr_name, tensor);
}
void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
const void* const* values, const size_t* lengths,
int num_values) {
std::vector<tensorflow::StringPiece> v(num_values);
for (int i = 0; i < num_values; ++i) {
v[i] = tensorflow::StringPiece(static_cast<const char*>(values[i]),
lengths[i]);
auto s =
op->operation->SetAttrStringList(attr_name, values, lengths, num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
op->operation.MutableAttrs()->Set(attr_name, v);
}
void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
const float* values, int num_values) {
op->operation.MutableAttrs()->Set(
attr_name, tensorflow::gtl::ArraySlice<const float>(values, num_values));
auto s = op->operation->SetAttrFloatList(attr_name, values, num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
const int64_t* values, int num_values) {
op->operation.MutableAttrs()->Set(
attr_name, tensorflow::gtl::ArraySlice<const int64>(
reinterpret_cast<const int64*>(values), num_values));
auto s = op->operation->SetAttrIntList(attr_name, values, num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
const TF_DataType* values, int num_values) {
op->operation.MutableAttrs()->Set(
attr_name,
tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
reinterpret_cast<const tensorflow::DataType*>(values), num_values));
auto s = op->operation->SetAttrTypeList(attr_name, values, num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
const unsigned char* values, int num_values) {
std::unique_ptr<bool[]> b(new bool[num_values]);
for (int i = 0; i < num_values; ++i) {
b[i] = values[i];
auto s = op->operation->SetAttrBoolList(attr_name, values, num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
op->operation.MutableAttrs()->Set(
attr_name, tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values));
}
void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
const int64_t** dims, const int* num_dims,
int num_values, TF_Status* out_status) {
std::unique_ptr<tensorflow::TensorShapeProto[]> proto(
new tensorflow::TensorShapeProto[num_values]);
for (int i = 0; i < num_values; ++i) {
const auto num_dims_i = num_dims[i];
if (num_dims_i > tensorflow::TensorShape::MaxDimensions()) {
TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
tensorflow::strings::StrCat(
"Value specified for `", attr_name, "` has ", num_dims_i,
" dimensions which is over the limit of ",
tensorflow::TensorShape::MaxDimensions(), ".")
.c_str());
return;
}
if (num_dims_i < 0) {
proto[i].set_unknown_rank(true);
} else {
const int64_t* dims_i = dims[i];
auto proto_i = &proto[i];
for (int d = 0; d < num_dims_i; ++d) {
proto_i->add_dim()->set_size(dims_i[d]);
}
}
}
op->operation.MutableAttrs()->Set(
attr_name, tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>(
proto.get(), num_values));
out_status->status =
op->operation->SetAttrShapeList(attr_name, dims, num_dims, num_values);
}
void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
const TFE_Op** value, int num_values) {
std::unique_ptr<tensorflow::NameAttrList[]> funcs(
new tensorflow::NameAttrList[num_values]);
for (int i = 0; i < num_values; i++) {
funcs[i].set_name(value[i]->operation.Name());
value[i]->operation.Attrs().FillAttrValueMap(funcs[i].mutable_attr());
auto s = op->operation->SetAttrFunctionList(attr_name, value, num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
op->operation.MutableAttrs()->Set(
attr_name, tensorflow::gtl::ArraySlice<const tensorflow::NameAttrList>(
funcs.get(), num_values));
}
TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
const char* input_name,
TF_Status* status) {
const tensorflow::OpDef* op_def = GetOpDef(op, status);
if (!status->status.ok()) {
return -1;
}
tensorflow::AttrValueMap attrs;
op->operation.Attrs().FillAttrValueMap(&attrs);
tensorflow::NameRangeMap name_ranges;
status->status = tensorflow::NameRangesForNode(
tensorflow::AttrSlice(&attrs), *op_def, &name_ranges, nullptr);
if (!status->status.ok()) {
return -1;
}
auto iter = name_ranges.find(input_name);
if (iter == name_ranges.end()) {
status->status = tensorflow::errors::InvalidArgument("Input '", input_name,
"' not found");
return -1;
}
return iter->second.second - iter->second.first;
int ret = -1;
status->status = op->operation->InputLength(input_name, &ret);
return ret;
}
TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
const char* output_name,
TF_Status* status) {
const tensorflow::OpDef* op_def = GetOpDef(op, status);
if (!status->status.ok()) {
return -1;
}
tensorflow::AttrValueMap attrs;
op->operation.Attrs().FillAttrValueMap(&attrs);
tensorflow::NameRangeMap name_ranges;
status->status = tensorflow::NameRangesForNode(
tensorflow::AttrSlice(&attrs), *op_def, nullptr, &name_ranges);
if (!status->status.ok()) {
return -1;
}
auto iter = name_ranges.find(output_name);
if (iter == name_ranges.end()) {
status->status = tensorflow::errors::InvalidArgument(
"Output '", output_name, "' not found");
return -1;
}
return iter->second.second - iter->second.first;
int ret = -1;
status->status = op->operation->OutputLength(output_name, &ret);
return ret;
}
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
VLOG(1) << "Calling TFE_Execute() on op " << op;
status->status = tensorflow::EagerExecute(&op->operation,
handle_retvals.data(), num_retvals);
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>> handles(
*num_retvals);
status->status = op->operation->Execute(&handles, num_retvals);
if (!status->status.ok()) {
return;
}
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle_retvals[i])};
retvals[i] = new TFE_TensorHandle{std::move(handles[i])};
}
}
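TFE_Execute now traffics in AbstractTensorHandleInterface objects internally, but the caller-facing contract is unchanged. A minimal usage sketch, mirroring the AddOp/ExecuteAdd test added later in this commit (`ctx`, `a`, `b`, and `status` assumed valid; error checks elided):
```
TFE_Op* add = TFE_NewOp(ctx, "AddV2", status);
TFE_OpAddInput(add, a, status);
TFE_OpAddInput(add, b, status);
TFE_OpSetAttrType(add, "T", TFE_TensorHandleDataType(a));
TFE_TensorHandle* retval = nullptr;
int num_retvals = 1;  // in: output capacity, out: number of outputs produced
TFE_Execute(add, &retval, &num_retvals, status);
TFE_DeleteOp(add);
```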
@ -1678,6 +1603,23 @@ void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context->StartStep(); }
void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context->EndStep(); }
void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs) {
auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
op->operation.get());
*attrs = TFE_OpAttrs(&operation->Attrs());
}
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
tensorflow::AttrValueMap m;
attrs->attributes->FillAttrValueMap(&m);
auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
op->operation.get());
tensorflow::AttrBuilder* destination = operation->MutableAttrs();
for (auto attribute : m) {
destination->Set(attribute.first, attribute.second);
}
}
namespace tensorflow {
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
const tensorflow::AttrValue& default_value,
@ -1797,10 +1739,10 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
op->Inputs()[i])});
}
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
// TODO(allenl): figure out how to get attrs from EagerOperation
TF_Status status;
TFE_OpAttrs attributes(&op->Attrs());
device_.execute(inputs.size(), inputs.data(), op->Name().c_str(),
num_retvals, outputs.data(), &status, info_);
&attributes, num_retvals, outputs.data(), &status, info_);
if (status.status.ok()) {
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(

View File

@ -31,20 +31,14 @@ using tensorflow::string;
void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status) {
if (op_to_reset) {
status->status = op_to_reset->operation.Reset(
op_or_function_name, raw_device_name, false, nullptr);
status->status =
op_to_reset->operation->Reset(op_or_function_name, raw_device_name);
} else {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"op_to_reset should not be nullptr");
}
}
void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
op->operation.ConsumeInput(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle());
}
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
ctx->context->SetShouldStoreGraphs(true);
}
@ -520,8 +514,7 @@ void TFE_DeleteCancellationManager(
void TFE_OpSetCancellationManager(TFE_Op* op,
TFE_CancellationManager* cancellation_manager,
TF_Status* status) {
op->operation.SetCancellationManager(
&cancellation_manager->cancellation_manager);
status->status = op->operation->SetCancellationManager(cancellation_manager);
}
TFE_Executor* TFE_NewExecutor(bool is_async) {
@ -569,3 +562,22 @@ void TFE_TensorHandleEnableImplicitMirroring(TFE_TensorHandle* h,
h->handle->EnableImplicitMirroring();
status->status = tensorflow::Status::OK();
}
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
TF_Buffer* buf, TF_Status* status) {
auto* function_def = ctx->context->FindFunctionDef(function_name);
if (function_def == nullptr) {
status->status = tensorflow::errors::NotFound(
"Unable to find FunctionDef with name: ", function_name);
return;
}
string str = function_def->SerializeAsString();
void* data = tensorflow::port::Malloc(str.length());
str.copy(static_cast<char*>(data), str.length(), 0);
buf->data = data;
buf->length = str.length();
buf->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
status->status = tensorflow::Status::OK();
}
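A short usage sketch for the new accessor (the function name "f" is illustrative; it assumes such a function was previously registered on `ctx`, and error handling is elided):
```
TF_Buffer* buf = TF_NewBuffer();
TFE_ContextGetFunctionDef(ctx, "f", buf, status);
if (TF_GetCode(status) == TF_OK) {
  // buf->data and buf->length now hold the serialized FunctionDef proto;
  // the buffer frees the memory via buf->data_deallocator.
}
TF_DeleteBuffer(buf);
```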

View File

@ -34,9 +34,6 @@ TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Op* op_to_reset,
const char* raw_device_name,
TF_Status* status);
TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,
TF_Status* status);
// Enables only graph collection in RunMetadata on the functions executed from
// this context.
TF_CAPI_EXPORT extern void TFE_ContextEnableGraphCollection(TFE_Context* ctx);
@ -424,7 +421,27 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
TF_Buffer* buf);
#define TFE_CUSTOM_DEVICE_VERSION 0
// APIs for generically dealing with op attributes (e.g. when forwarding them
// through custom device implementations).
//
// TODO(allenl): Currently these are black boxes, but we should have some way to
// inspect values. This would let people e.g. copy over most attributes and then
// modify some based on their values.
// A reference to an op's name -> attribute mapping
typedef struct TFE_OpAttrs TFE_OpAttrs;
// Fetch a struct with a reference to information about attributes of `op`.
//
// The `attrs` struct does not own any memory, and `op` must outlive it.
TF_CAPI_EXPORT extern void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs);
// Add attributes in `attrs` to `op`.
//
// Does not overwrite or update existing attributes, but adds new ones.
TF_CAPI_EXPORT extern void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs);
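A short sketch of forwarding attributes between ops with this pair, matching the TestTFE_OpGetAttrs unit test further down. Note that the TFE_OpAttrs value is a non-owning view, so `src` must outlive it, and stack allocation compiles only where the struct definition is visible (the test includes c_api_internal.h); error checks elided:
```
TFE_Op* src = TFE_NewOp(ctx, "VarHandleOp", status);
TFE_OpSetAttrType(src, "dtype", TF_INT64);
TFE_OpAttrs attrs;            // non-owning reference into `src`
TFE_OpGetAttrs(src, &attrs);
TFE_Op* dst = TFE_NewOp(ctx, "VarHandleOp", status);
TFE_OpAddAttrs(dst, &attrs);  // adds `dtype`; existing attrs are not overwritten
```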
#define TFE_CUSTOM_DEVICE_VERSION 1
// Struct to be filled in
typedef struct TFE_CustomDevice {
@ -441,10 +458,10 @@ typedef struct TFE_CustomDevice {
void* device_info);
// Method to execute an operation.
// TODO(allenl) figure out a generic way of passing attrs here
void (*execute)(int num_inputs, TFE_TensorHandle** inputs,
const char* operation_name, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s, void* device_info);
const char* operation_name, const TFE_OpAttrs* attributes,
int* num_outputs, TFE_TensorHandle** outputs, TF_Status* s,
void* device_info);
// Method to delete a device.
void (*delete_device)(void* device_info);
@ -475,6 +492,11 @@ typedef struct TFE_CustomDevice {
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
const char* device_name, void* device_info);
TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx,
const char* function_name,
TF_Buffer* buf,
TF_Status* status);
#ifdef __cplusplus
} /* end extern "C" */
#endif

View File

@ -27,12 +27,12 @@ limitations under the License.
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h"
@ -89,7 +89,7 @@ struct TFE_TensorDebugInfo {
};
struct TFE_Op {
tensorflow::EagerOperation operation;
std::unique_ptr<AbstractOperationInterface> operation;
};
struct TFE_MonitoringCounterCell {
@ -236,4 +236,13 @@ struct TFE_Executor {
tensorflow::EagerExecutor* unowned_executor;
};
struct TFE_OpAttrs {
explicit TFE_OpAttrs() : attributes(nullptr) {}
explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value)
: attributes(value) {}
const tensorflow::AttrBuilder* attributes;
};
#endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_

View File

@ -17,6 +17,8 @@ limitations under the License.
#include <string.h>
#include <string>
#include "absl/strings/match.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
@ -367,7 +369,7 @@ TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevicesAsync) {
void TensorHandleSilentCopy(bool async,
TFE_ContextDevicePlacementPolicy global_policy,
TFE_ContextDevicePlacementPolicy thread_policy,
bool cpu_op) {
bool mirror, bool cpu_op) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
@ -390,6 +392,12 @@ void TensorHandleSilentCopy(bool async,
TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
hcpu, ctx, gpu_device_name.c_str(), status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
if (mirror) {
TFE_TensorHandleEnableImplicitMirroring(hcpu, status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
TFE_TensorHandleEnableImplicitMirroring(hgpu, status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
}
TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
if (cpu_op) {
@ -414,10 +422,23 @@ void TensorHandleSilentCopy(bool async,
hgpu->handle.get())
->Handle();
// The input handles should never change since they have been mirrored.
ASSERT_EQ(matmul->operation.Inputs()[0], arg0);
ASSERT_EQ(matmul->operation.Inputs()[1], arg1);
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
matmul->operation.get());
if (mirror) {
// The input handles should never change since they have been mirrored.
ASSERT_EQ(op->GetInput(0), arg0);
ASSERT_EQ(op->GetInput(1), arg1);
} else {
if (cpu_op) {
ASSERT_EQ(op->GetInput(0), arg0);
// The GPU handle should be replaced with a CPU copy
ASSERT_NE(op->GetInput(1), arg1);
} else {
// The CPU handle should be replaced with a GPU copy
ASSERT_NE(op->GetInput(0), arg0);
ASSERT_EQ(op->GetInput(1), arg1);
}
}
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteTensorHandle(hgpu);
@ -433,19 +454,27 @@ void TensorHandleSilentCopy(bool async,
}
TEST(CAPI, TensorHandleSilentCopy) {
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT,
TFE_DEVICE_PLACEMENT_SILENT, false);
TFE_DEVICE_PLACEMENT_SILENT, false, false);
}
TEST(CAPI, TensorHandleSilentCopyAsync) {
TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_SILENT,
TFE_DEVICE_PLACEMENT_SILENT, false);
TFE_DEVICE_PLACEMENT_SILENT, false, false);
}
TEST(CAPI, TensorHandleSilentCopyLocalPolicy) {
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_EXPLICIT,
TFE_DEVICE_PLACEMENT_SILENT, false);
TFE_DEVICE_PLACEMENT_SILENT, false, false);
}
TEST(CAPI, TensorHandleSilentCopyLocalPolicyAsync) {
TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_EXPLICIT,
TFE_DEVICE_PLACEMENT_SILENT, false);
TFE_DEVICE_PLACEMENT_SILENT, false, false);
}
TEST(CAPI, TensorHandleMirrorCopy) {
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT,
TFE_DEVICE_PLACEMENT_SILENT, true, false);
}
TEST(CAPI, TensorHandleMirrorCopyCpu) {
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT,
TFE_DEVICE_PLACEMENT_SILENT, true, true);
}
void SetAndGetOpDevices(bool async) {
@ -581,6 +610,91 @@ TEST(CAPI, TensorHandleDevices) {
TFE_DeleteContext(ctx);
}
void ExecuteAdd(bool async, bool forward_input) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* n = TestMatrixTensorHandle100x100();
// If a GPU exists, copy the handle to GPU so that we can exercise
// unprotecting a mirror.
std::string gpu_device_name;
if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
TFE_TensorHandle* n_gpu =
TFE_TensorHandleCopyToDevice(n, ctx, gpu_device_name.c_str(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandleEnableImplicitMirroring(n_gpu, status);
TFE_DeleteTensorHandle(n);
n = n_gpu;
}
TFE_TensorHandle* m = TestMatrixTensorHandle100x100();
// Store pointer to raw buffer for validation of forwarding behaviour.
TF_Tensor* orig = TFE_TensorHandleResolve(n, status);
void* orig_ptr = TF_TensorData(orig);
TF_DeleteTensor(orig);
TFE_Op* add_op = AddOp(ctx, n, m);
std::string cpu_device_name;
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
TFE_OpSetDevice(add_op, cpu_device_name.c_str(), status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
if (forward_input) {
TFE_DeleteTensorHandle(n);
}
int num_retvals = 1;
if (async) {
// Enqueue dummy ops so we backlog async execution & actually test async.
for (int i = 0; i < 10000; ++i) {
TFE_TensorHandle* dummy = nullptr;
TFE_Execute(add_op, &dummy, &num_retvals, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(dummy);
}
}
TFE_TensorHandle* retval = nullptr;
TFE_Execute(add_op, &retval, &num_retvals, status);
EXPECT_EQ(1, num_retvals);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
if (!forward_input) {
TFE_DeleteTensorHandle(n);
}
TFE_DeleteOp(add_op);
TF_Tensor* t = TFE_TensorHandleResolve(retval, status);
if (forward_input || async) {
EXPECT_EQ(orig_ptr, TF_TensorData(t));
} else {
EXPECT_NE(orig_ptr, TF_TensorData(t));
}
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(m);
TFE_DeleteTensorHandle(retval);
TFE_DeleteContext(ctx);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float result[100 * 100] = {0};
EXPECT_EQ(sizeof(result), TF_TensorByteSize(t));
memcpy(&result[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
for (int i = 0; i < 100 * 100; ++i) {
EXPECT_EQ(2.0f, result[i]);
}
TF_DeleteStatus(status);
}
TEST(CAPI, ExecuteAdd) { ExecuteAdd(false, false); }
TEST(CAPI, ExecuteAddAsync) { ExecuteAdd(true, false); }
TEST(CAPI, ExecuteAddForward) { ExecuteAdd(false, true); }
TEST(CAPI, ExecuteAddForwardAsync) { ExecuteAdd(true, true); }
void Execute_MatMul_CPU(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
@ -1219,6 +1333,14 @@ TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
TFE_DeleteTensorHandle(h_shares_tensor);
}
tensorflow::AttrValueMap ExtractAttrs(TFE_Op* op) {
tensorflow::AttrValueMap attr_values;
tensorflow::down_cast<tensorflow::OperationInterface*>(op->operation.get())
->Attrs()
.FillAttrValueMap(&attr_values);
return attr_values;
}
TEST(CAPI, TestTFE_OpInferSingleInputAttrs) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
@ -1235,8 +1357,7 @@ TEST(CAPI, TestTFE_OpInferSingleInputAttrs) {
TFE_OpAddInput(minOp, axis, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::AttrValueMap attr_values;
minOp->operation.Attrs().FillAttrValueMap(&attr_values);
tensorflow::AttrValueMap attr_values = ExtractAttrs(minOp);
tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
EXPECT_NE(attr_found, attr_values.cend());
EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_FLOAT);
@ -1275,8 +1396,7 @@ TEST(CAPI, TestTFE_OpInferSingleTypeInputListAttrs) {
TFE_OpAddInputList(concatOp, inputs, 2, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::AttrValueMap attr_values;
concatOp->operation.Attrs().FillAttrValueMap(&attr_values);
tensorflow::AttrValueMap attr_values = ExtractAttrs(concatOp);
tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
EXPECT_NE(attr_found, attr_values.cend());
EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_FLOAT);
@ -1316,8 +1436,7 @@ TEST(CAPI, TestTFE_OpInferMixedTypeInputListAttrs) {
TFE_OpAddInputList(assertOp, data, 3, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::AttrValueMap attr_values;
assertOp->operation.Attrs().FillAttrValueMap(&attr_values);
tensorflow::AttrValueMap attr_values = ExtractAttrs(assertOp);
tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
EXPECT_NE(attr_found, attr_values.cend());
EXPECT_EQ(attr_found->second.list().type(0), tensorflow::DataType::DT_BOOL);
@ -1353,16 +1472,15 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) {
TFE_TensorHandle* inputs[] = {input1, input2};
TFE_OpAddInput(concatOp, dim, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
CHECK(concatOp->operation.OpDef());
CHECK(concatOp->operation->OpDef());
TFE_OpAddInput(concatOp, inputs[0], status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_FALSE(concatOp->operation.OpDef())
EXPECT_FALSE(concatOp->operation->OpDef())
<< "Inference context is still present";
TFE_OpAddInput(concatOp, inputs[1], status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::AttrValueMap attr_values;
concatOp->operation.Attrs().FillAttrValueMap(&attr_values);
tensorflow::AttrValueMap attr_values = ExtractAttrs(concatOp);
EXPECT_EQ(attr_values.find("T"), attr_values.end());
EXPECT_EQ(attr_values.find("N"), attr_values.end());
@ -1449,4 +1567,40 @@ TEST(CAPI, TestTFE_OpGetInputAndOutputLengthsFailForUnknownArguments) {
TFE_DeleteContext(ctx);
}
TEST(CAPI, TestTFE_OpGetAttrs) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
TFE_OpAttrs attributes;
TFE_OpGetAttrs(var_op, &attributes);
TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status);
TFE_OpSetAttrType(copy_op, "dtype", TF_FLOAT);
TFE_OpAddAttrs(copy_op, &attributes);
unsigned char is_list = 0;
ASSERT_EQ(TF_ATTR_TYPE,
TFE_OpGetAttrType(copy_op, "dtype", &is_list, status));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(TF_ATTR_SHAPE,
TFE_OpGetAttrType(copy_op, "shape", &is_list, status));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::AttrValueMap attr_values;
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
copy_op->operation.get());
op->Attrs().FillAttrValueMap(&attr_values);
EXPECT_EQ(tensorflow::DT_FLOAT, attr_values.find("dtype")->second.type());
TF_DeleteStatus(status);
TFE_DeleteOp(var_op);
TFE_DeleteOp(copy_op);
TFE_DeleteContext(ctx);
}
} // namespace

View File

@ -131,6 +131,21 @@ TFE_TensorHandle* TestMatrixTensorHandle3X2() {
return th;
}
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
TF_Status* status = TF_NewStatus();
TFE_Op* op = TFE_NewOp(ctx, "AddV2", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, a, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, b, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
return op;
}
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
TF_Status* status = TF_NewStatus();

View File

@ -42,6 +42,9 @@ TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2();
// Return a tensor handle containing a 3x2 matrix of floats
TFE_TensorHandle* TestMatrixTensorHandle3X2();
// Return an add op adding `a` and `b`.
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
// Return a matmul op multiplying `a` by `b`.
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/test.h"
namespace {
@ -31,6 +32,8 @@ struct LoggingDevice {
tensorflow::string underlying_device;
// Set to true whenever a TensorHandle is copied onto the device
bool* arrived_flag;
// Set to true whenever an operation is executed
bool* executed_flag;
};
struct LoggedTensor {
@ -81,12 +84,14 @@ TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_TensorHandle* tensor,
}
void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs,
const char* operation_name, int* num_outputs,
const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s,
void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_Op* op(TFE_NewOp(dev->ctx, operation_name, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddAttrs(op, attributes);
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
for (int j = 0; j < num_inputs; ++j) {
TFE_TensorHandle* input = inputs[j];
@ -115,6 +120,7 @@ void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs,
outputs[i] = MakeLoggedTensorHandle(dev->ctx, dev->device_name,
std::move(logged_tensor), s);
}
*(dev->executed_flag) = true;
}
void DeleteLoggingDevice(void* device_info) {
@ -122,7 +128,7 @@ void DeleteLoggingDevice(void* device_info) {
}
void RegisterLoggingDevice(TFE_Context* context, const char* name,
bool* arrived_flag) {
bool* arrived_flag, bool* executed_flag) {
TFE_CustomDevice custom_device;
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
@ -131,6 +137,7 @@ void RegisterLoggingDevice(TFE_Context* context, const char* name,
LoggingDevice* device = new LoggingDevice;
device->ctx = context;
device->arrived_flag = arrived_flag;
device->executed_flag = executed_flag;
device->device_name = name;
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
TFE_RegisterCustomDevice(context, custom_device, name, device);
@ -144,13 +151,15 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
TFE_DeleteContextOptions(opts);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
bool arrived = false;
bool executed = false;
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context, name, &arrived);
RegisterLoggingDevice(context, name, &arrived, &executed);
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
ASSERT_FALSE(arrived);
TFE_TensorHandle* hdevice =
TFE_TensorHandleCopyToDevice(hcpu, context, name, status.get());
ASSERT_TRUE(arrived);
ASSERT_FALSE(executed);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> matmul(
MatMulOp(context, hcpu, hdevice), TFE_DeleteOp);
@ -160,6 +169,7 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
int num_retvals = 1;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
TFE_DeleteTensorHandle(retval);
TFE_DeleteTensorHandle(hcpu);
@ -167,4 +177,118 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
TFE_DeleteContext(context);
}
TEST(CUSTOM_DEVICE, ResetOperation) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts, status.get()), TFE_DeleteContext);
TFE_DeleteContextOptions(opts);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
bool arrived = false;
bool executed = false;
const char* custom_device_name =
"/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), custom_device_name, &arrived, &executed);
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> reused_op(
TFE_NewOp(context.get(), "Identity", status.get()), TFE_DeleteOp);
TFE_OpReset(reused_op.get(), "Identity", custom_device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(tensorflow::string(TFE_OpGetDevice(reused_op.get(), status.get())),
tensorflow::string(custom_device_name));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpReset(reused_op.get(), "Identity",
"/job:localhost/replica:0/task:0/device:CPU:0", status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(tensorflow::string(TFE_OpGetDevice(reused_op.get(), status.get())),
tensorflow::string("/job:localhost/replica:0/task:0/device:CPU:0"));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
TEST(CUSTOM_DEVICE, MakeVariable) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
bool arrived = false;
bool executed = false;
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), name, &arrived, &executed);
// Create a variable handle placed on the custom device.
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
TFE_OpSetAttrShape(op.get(), "shape", {}, 0, status.get());
TFE_OpSetAttrString(op.get(), "container", "", 0);
TFE_OpSetAttrString(op.get(), "shared_name", "", 0);
TFE_OpSetDevice(op.get(), name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* var_handle = nullptr;
int num_retvals = 1;
executed = false;
TFE_Execute(op.get(), &var_handle, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
auto handle_cleaner = tensorflow::gtl::MakeCleanup(
[var_handle]() { TFE_DeleteTensorHandle(var_handle); });
// Assign to the variable, copying to the custom device.
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle);
op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpAddInput(op.get(), one.get(), status.get());
TFE_OpSetDevice(op.get(), name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
executed = false;
num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
// Read the variable's value.
op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpSetDevice(op.get(), name, status.get());
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
executed = false;
num_retvals = 1;
TFE_TensorHandle* var_value = nullptr;
TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
auto value_cleaner = tensorflow::gtl::MakeCleanup(
[var_value]() { TFE_DeleteTensorHandle(var_value); });
ASSERT_EQ(tensorflow::string(name),
tensorflow::string(
TFE_TensorHandleBackingDeviceName(var_value, status.get())));
TFE_TensorHandle* var_value_unpacked =
reinterpret_cast<LoggedTensor*>(
TFE_TensorHandleDevicePointer(var_value, status.get()))
->tensor;
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> resolved_value(
TFE_TensorHandleResolve(var_value_unpacked, status.get()),
TF_DeleteTensor);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(111., *static_cast<float*>(TF_TensorData(resolved_value.get())));
// Free the backing buffer for the variable.
op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpSetDevice(op.get(), name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
} // namespace

View File

@ -0,0 +1,312 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/operation_interface.h"
#include "absl/container/fixed_array.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
OperationInterface::OperationInterface(TFE_Context* ctx)
: operation_(ctx->context) {}
const string& OperationInterface::DeviceName() const {
absl::variant<Device*, CustomDevice*> variant_device =
(operation_.Device() == kVariantDeviceNull)
? operation_.EagerContext().HostCPU()
: operation_.Device();
return absl::visit([](auto* d) -> const string& { return d->name(); },
variant_device);
}
Status OperationInterface::SetDeviceName(const char* name) {
return operation_.SetDeviceName(name);
}
Status OperationInterface::SetAttrString(const char* attr_name,
const char* data, size_t length) {
operation_.MutableAttrs()->Set(attr_name, StringPiece(data, length));
return Status::OK();
}
Status OperationInterface::SetAttrInt(const char* attr_name, int64_t value) {
operation_.MutableAttrs()->Set(attr_name, static_cast<int64>(value));
return Status::OK();
}
Status OperationInterface::SetAttrFloat(const char* attr_name, float value) {
operation_.MutableAttrs()->Set(attr_name, value);
return Status::OK();
}
Status OperationInterface::SetAttrBool(const char* attr_name, bool value) {
operation_.MutableAttrs()->Set(attr_name, value);
return Status::OK();
}
Status OperationInterface::SetAttrType(const char* attr_name,
TF_DataType value) {
operation_.MutableAttrs()->Set(attr_name, static_cast<DataType>(value));
return Status::OK();
}
Status OperationInterface::SetAttrShape(const char* attr_name,
const int64_t* dims,
const int num_dims) {
if (num_dims > TensorShape::MaxDimensions()) {
return errors::InvalidArgument("Value specified for `", attr_name, "` has ",
num_dims,
" dimensions which is over the limit of ",
TensorShape::MaxDimensions(), ".");
}
TensorShapeProto proto;
if (num_dims < 0) {
proto.set_unknown_rank(true);
} else {
for (int d = 0; d < num_dims; ++d) {
proto.add_dim()->set_size(dims[d]);
}
}
operation_.MutableAttrs()->Set(attr_name, proto);
return Status::OK();
}
Status OperationInterface::SetAttrFunction(
const char* attr_name,
const std::unique_ptr<AbstractOperationInterface>& value) {
AttrValue attr_value;
NameAttrList* func = attr_value.mutable_func();
func->set_name(value->Name());
OperationInterface* value_operation =
tensorflow::down_cast<OperationInterface*>(value.get());
value_operation->operation_.Attrs().FillAttrValueMap(func->mutable_attr());
operation_.MutableAttrs()->Set(attr_name, attr_value);
return Status::OK();
}
Status OperationInterface::SetAttrFunctionName(const char* attr_name,
const char* data,
size_t length) {
AttrValue attr_value;
NameAttrList* func = attr_value.mutable_func();
func->set_name(data, length);
operation_.MutableAttrs()->Set(attr_name, attr_value);
return Status::OK();
}
Status OperationInterface::SetAttrTensor(const char* attr_name,
TF_Tensor* tensor) {
Tensor t;
TF_RETURN_IF_ERROR(TF_TensorToTensor(tensor, &t));
operation_.MutableAttrs()->Set(attr_name, t);
return Status::OK();
}
Status OperationInterface::SetAttrStringList(const char* attr_name,
const void* const* values,
const size_t* lengths,
int num_values) {
std::vector<StringPiece> v(num_values);
for (int i = 0; i < num_values; ++i) {
v[i] = StringPiece(static_cast<const char*>(values[i]), lengths[i]);
}
operation_.MutableAttrs()->Set(attr_name, v);
return Status::OK();
}
Status OperationInterface::SetAttrFloatList(const char* attr_name,
const float* values,
int num_values) {
operation_.MutableAttrs()->Set(
attr_name, gtl::ArraySlice<const float>(values, num_values));
return Status::OK();
}
Status OperationInterface::SetAttrIntList(const char* attr_name,
const int64_t* values,
int num_values) {
operation_.MutableAttrs()->Set(
attr_name, gtl::ArraySlice<const int64>(
reinterpret_cast<const int64*>(values), num_values));
return Status::OK();
}
Status OperationInterface::SetAttrTypeList(const char* attr_name,
const TF_DataType* values,
int num_values) {
operation_.MutableAttrs()->Set(
attr_name, gtl::ArraySlice<const DataType>(
reinterpret_cast<const DataType*>(values), num_values));
return Status::OK();
}
Status OperationInterface::SetAttrBoolList(const char* attr_name,
const unsigned char* values,
int num_values) {
std::unique_ptr<bool[]> b(new bool[num_values]);
for (int i = 0; i < num_values; ++i) {
b[i] = values[i];
}
operation_.MutableAttrs()->Set(
attr_name, gtl::ArraySlice<const bool>(b.get(), num_values));
return Status::OK();
}
Status OperationInterface::SetAttrShapeList(const char* attr_name,
const int64_t** dims,
const int* num_dims,
int num_values) {
std::unique_ptr<TensorShapeProto[]> proto(new TensorShapeProto[num_values]);
for (int i = 0; i < num_values; ++i) {
const auto num_dims_i = num_dims[i];
if (num_dims_i > TensorShape::MaxDimensions()) {
return errors::InvalidArgument(
strings::StrCat("Value specified for `", attr_name, "` has ",
num_dims_i, " dimensions which is over the limit of ",
TensorShape::MaxDimensions(), "."));
}
if (num_dims_i < 0) {
proto[i].set_unknown_rank(true);
} else {
const int64_t* dims_i = dims[i];
auto proto_i = &proto[i];
for (int d = 0; d < num_dims_i; ++d) {
proto_i->add_dim()->set_size(dims_i[d]);
}
}
}
operation_.MutableAttrs()->Set(
attr_name, gtl::ArraySlice<TensorShapeProto>(proto.get(), num_values));
return Status::OK();
}
Status OperationInterface::SetAttrFunctionList(const char* attr_name,
const TFE_Op** value,
int num_values) {
std::unique_ptr<NameAttrList[]> funcs(new NameAttrList[num_values]);
for (int i = 0; i < num_values; i++) {
auto value_operation =
tensorflow::down_cast<OperationInterface*>(value[i]->operation.get());
funcs[i].set_name(value_operation->operation_.Name());
value_operation->operation_.Attrs().FillAttrValueMap(
funcs[i].mutable_attr());
}
operation_.MutableAttrs()->Set(
attr_name, gtl::ArraySlice<const NameAttrList>(funcs.get(), num_values));
return Status::OK();
}
const OpDef* OperationInterface::GetOpDef(Status* status) {
const tensorflow::OpDef* op_def = operation_.OpDef();
if (op_def) return op_def;
*status = OpDefForOp(Name(), &op_def);
return op_def;
}
Status OperationInterface::InputLength(const char* input_name, int* length) {
Status status;
const tensorflow::OpDef* op_def = GetOpDef(&status);
if (!status.ok()) {
return status;
}
AttrValueMap attrs;
operation_.Attrs().FillAttrValueMap(&attrs);
NameRangeMap name_ranges;
TF_RETURN_IF_ERROR(
NameRangesForNode(AttrSlice(&attrs), *op_def, &name_ranges, nullptr));
auto iter = name_ranges.find(input_name);
if (iter == name_ranges.end()) {
return errors::InvalidArgument("Input '", input_name, "' not found");
}
*length = iter->second.second - iter->second.first;
return Status::OK();
}
Status OperationInterface::OutputLength(const char* output_name, int* length) {
Status status;
const tensorflow::OpDef* op_def = GetOpDef(&status);
if (!status.ok()) {
return status;
}
AttrValueMap attrs;
operation_.Attrs().FillAttrValueMap(&attrs);
NameRangeMap name_ranges;
TF_RETURN_IF_ERROR(
NameRangesForNode(AttrSlice(&attrs), *op_def, nullptr, &name_ranges));
auto iter = name_ranges.find(output_name);
if (iter == name_ranges.end()) {
return errors::InvalidArgument("Output '", output_name, "' not found");
}
*length = iter->second.second - iter->second.first;
return Status::OK();
}
Status OperationInterface::AddInput(
const std::unique_ptr<AbstractTensorHandleInterface>& input) {
TensorHandle* h =
tensorflow::down_cast<TensorHandleInterface*>(input.get())->Handle();
operation_.AddInput(h);
return operation_.MaybeInferSingleInputAttrs(h);
}
Status OperationInterface::AddInputList(
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
inputs) {
for (auto& input : inputs) {
TensorHandle* h =
tensorflow::down_cast<TensorHandleInterface*>(input.get())->Handle();
operation_.AddInput(h);
}
return operation_.InferInputListAttrs(inputs.size());
}
Status OperationInterface::Execute(
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
int* num_retvals) {
absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
TF_RETURN_IF_ERROR(
EagerExecute(&operation_, handle_retvals.data(), num_retvals));
for (int i = 0; i < *num_retvals; ++i) {
retvals->at(i).reset(
new tensorflow::TensorHandleInterface(handle_retvals[i]));
}
return Status::OK();
}
Status OperationInterface::SetCancellationManager(
TFE_CancellationManager* cancellation_manager) {
operation_.SetCancellationManager(
&cancellation_manager->cancellation_manager);
return Status::OK();
}
Status OperationInterface::SetUseXla(bool enable) {
operation_.SetUseXla(enable);
return Status::OK();
}
} // namespace tensorflow
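For orientation, a minimal sketch of how this class is reached from user code through the eager C API; each TFE_* call below forwards to the matching OperationInterface method. A TFE_Context* ctx and input handles a and b are assumed to exist, and error checking is elided:
TF_Status* status = TF_NewStatus();
TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);   // OperationInterface::Reset
TFE_OpSetAttrType(matmul, "T", TF_FLOAT);            // SetAttrType
TFE_OpAddInput(matmul, a, status);                   // AddInput
TFE_OpAddInput(matmul, b, status);
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, retvals, &num_retvals, status);  // Execute
TFE_DeleteOp(matmul);
TF_DeleteStatus(status);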

View File

@ -0,0 +1,188 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
#define TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
#include <memory>
#include "absl/container/fixed_array.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
// Abstract interface to an operation.
class AbstractOperationInterface {
public:
virtual ~AbstractOperationInterface() {}
virtual void Clear() = 0;
virtual tensorflow::Status Reset(const char* op,
const char* raw_device_name) = 0;
virtual const tensorflow::string& Name() const = 0;
virtual const tensorflow::string& DeviceName() const = 0;
virtual tensorflow::Status SetDeviceName(const char* name) = 0;
virtual tensorflow::Status AddInput(
const std::unique_ptr<AbstractTensorHandleInterface>& input) = 0;
virtual tensorflow::Status AddInputList(
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
inputs) = 0;
virtual tensorflow::Status Execute(
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
int* num_retvals) = 0;
virtual const tensorflow::OpDef* OpDef() const = 0;
virtual tensorflow::Status SetAttrString(const char* attr_name,
const char* data, size_t length) = 0;
virtual tensorflow::Status SetAttrInt(const char* attr_name,
int64_t value) = 0;
virtual tensorflow::Status SetAttrFloat(const char* attr_name,
float value) = 0;
virtual tensorflow::Status SetAttrBool(const char* attr_name, bool value) = 0;
virtual tensorflow::Status SetAttrType(const char* attr_name,
TF_DataType value) = 0;
virtual tensorflow::Status SetAttrShape(const char* attr_name,
const int64_t* dims,
const int num_dims) = 0;
virtual tensorflow::Status SetAttrFunction(
const char* attr_name,
const std::unique_ptr<AbstractOperationInterface>& value) = 0;
virtual tensorflow::Status SetAttrFunctionName(const char* attr_name,
const char* value,
size_t length) = 0;
virtual tensorflow::Status SetAttrTensor(const char* attr_name,
TF_Tensor* tensor) = 0;
virtual tensorflow::Status SetAttrStringList(const char* attr_name,
const void* const* values,
const size_t* lengths,
int num_values) = 0;
virtual tensorflow::Status SetAttrFloatList(const char* attr_name,
const float* values,
int num_values) = 0;
virtual tensorflow::Status SetAttrIntList(const char* attr_name,
const int64_t* values,
int num_values) = 0;
virtual tensorflow::Status SetAttrTypeList(const char* attr_name,
const TF_DataType* values,
int num_values) = 0;
virtual tensorflow::Status SetAttrBoolList(const char* attr_name,
const unsigned char* values,
int num_values) = 0;
virtual tensorflow::Status SetAttrShapeList(const char* attr_name,
const int64_t** dims,
const int* num_dims,
int num_values) = 0;
virtual tensorflow::Status SetAttrFunctionList(const char* attr_name,
const TFE_Op** value,
int num_values) = 0;
virtual tensorflow::Status InputLength(const char* input_name,
int* length) = 0;
virtual tensorflow::Status OutputLength(const char* output_name,
int* length) = 0;
// Experimental
virtual tensorflow::Status SetUseXla(bool enable) {
return tensorflow::errors::Unimplemented("SetUseXla not implemented");
}
virtual tensorflow::Status SetCancellationManager(
TFE_CancellationManager* cancellation_manager) {
return tensorflow::errors::Unimplemented(
"SetCancellationManager not implemented");
}
};
namespace tensorflow {
class OpDef;
class OperationInterface : public AbstractOperationInterface {
public:
explicit OperationInterface(TFE_Context* ctx);
~OperationInterface() override {}
void Clear() override { operation_.Clear(); }
Status Reset(const char* op, const char* raw_device_name) override {
return operation_.Reset(op, raw_device_name, false, nullptr);
}
const string& Name() const override { return operation_.Name(); }
const string& DeviceName() const override;
Status SetDeviceName(const char* name) override;
Status AddInput(
const std::unique_ptr<AbstractTensorHandleInterface>& input) override;
Status AddInputList(
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
inputs) override;
Status Execute(
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
int* num_retvals) override;
const tensorflow::OpDef* OpDef() const override {
return operation_.OpDef();
}
Status SetAttrString(const char* attr_name, const char* data,
size_t length) override;
Status SetAttrInt(const char* attr_name, int64_t value) override;
Status SetAttrFloat(const char* attr_name, float value) override;
Status SetAttrBool(const char* attr_name, bool value) override;
Status SetAttrType(const char* attr_name, TF_DataType value) override;
Status SetAttrShape(const char* attr_name, const int64_t* dims,
const int num_dims) override;
Status SetAttrFunction(
const char* attr_name,
const std::unique_ptr<AbstractOperationInterface>& value) override;
Status SetAttrFunctionName(const char* attr_name, const char* data,
size_t length) override;
Status SetAttrTensor(const char* attr_name, TF_Tensor* tensor) override;
Status SetAttrStringList(const char* attr_name, const void* const* values,
const size_t* lengths, int num_values) override;
Status SetAttrFloatList(const char* attr_name, const float* values,
int num_values) override;
Status SetAttrIntList(const char* attr_name, const int64_t* values,
int num_values) override;
Status SetAttrTypeList(const char* attr_name, const TF_DataType* values,
int num_values) override;
Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
int num_values) override;
Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
const int* num_dims, int num_values) override;
Status SetAttrFunctionList(const char* attr_name, const TFE_Op** value,
int num_values) override;
Status InputLength(const char* input_name, int* length) override;
Status OutputLength(const char* output_name, int* length) override;
Status SetUseXla(bool enable) override;
Status SetCancellationManager(
TFE_CancellationManager* cancellation_manager) override;
// TODO(gjn): Remove once TFE_InferShapes is removed
const tensorflow::AttrBuilder& Attrs() const { return operation_.Attrs(); }
tensorflow::AttrBuilder* MutableAttrs() { return operation_.MutableAttrs(); }
const TensorHandle* GetInput(int i) const { return operation_.Inputs()[i]; }
private:
const tensorflow::OpDef* GetOpDef(Status* status);
EagerOperation operation_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/c/tf_tensor.h"
#include <memory>
#include <vector>
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
@ -64,25 +65,41 @@ void deallocate_buffer(void* data, size_t len, void* arg) {
}
} // namespace tensorflow
namespace {
TF_Tensor* CreateTensor(TF_ManagedBuffer* buf, TF_DataType dtype,
const int64_t* dims, int num_dims, size_t len) {
std::vector<tensorflow::int64> dimvec(num_dims);
for (int i = 0; i < num_dims; ++i) {
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
}
// TODO(gjn): Make the choice of interface a compile-time configuration.
tensorflow::TensorInterface ret(
Tensor(static_cast<tensorflow::DataType>(dtype),
tensorflow::TensorShape(dimvec), buf));
buf->Unref();
size_t elem_size = TF_DataTypeSize(dtype);
if (elem_size > 0 && len < (elem_size * ret.NumElements())) {
return nullptr;
}
return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(ret)};
}
} // namespace
TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims,
int num_dims, size_t len) {
void* data = tensorflow::allocate_tensor("TF_AllocateTensor", len,
tensorflow::cpu_allocator());
return TF_NewTensor(dtype, dims, num_dims, data, len,
tensorflow::deallocate_buffer,
tensorflow::cpu_allocator());
TF_ManagedBuffer* buf =
new TF_ManagedBuffer(data, len, tensorflow::deallocate_buffer,
tensorflow::cpu_allocator(), /*owns_memory=*/true);
return CreateTensor(buf, dtype, dims, num_dims, len);
}
TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
void* data, size_t len,
void (*deallocator)(void* data, size_t len, void* arg),
void* deallocator_arg) {
std::vector<tensorflow::int64> dimvec(num_dims);
for (int i = 0; i < num_dims; ++i) {
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
}
TF_ManagedBuffer* buf = nullptr;
if (dtype != TF_STRING && dtype != TF_RESOURCE &&
tensorflow::DataTypeCanUseMemcpy(
@ -97,24 +114,17 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
// Other types have the same representation, so copy only if it is safe to
// do so.
buf = new TF_ManagedBuffer(tensorflow::allocate_tensor("TF_NewTensor", len),
len, tensorflow::deallocate_buffer, nullptr);
len, tensorflow::deallocate_buffer, nullptr,
/*owns_memory=*/true);
std::memcpy(buf->data(), data, len);
// Free the original buffer.
deallocator(data, len, deallocator_arg);
} else {
buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg,
/*owns_memory=*/false);
}
// TODO(gjn): Make the choice of interface a compile-time configuration.
tensorflow::TensorInterface ret(
Tensor(static_cast<tensorflow::DataType>(dtype),
tensorflow::TensorShape(dimvec), buf));
buf->Unref();
size_t elem_size = TF_DataTypeSize(dtype);
if (elem_size > 0 && len < (elem_size * ret.NumElements())) {
return nullptr;
}
return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(ret)};
return CreateTensor(buf, dtype, dims, num_dims, len);
}
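// Illustrative sketch, not part of the change above: how the owns_memory flag
// plays out for the two creation paths. Shapes and values are arbitrary
// example data.
//
//   int64_t dims[1] = {4};
//   // TF_AllocateTensor: TensorFlow allocates and owns the buffer
//   // (owns_memory=true), so the runtime may forward/reuse it.
//   TF_Tensor* owned = TF_AllocateTensor(TF_FLOAT, dims, 1, 4 * sizeof(float));
//   // TF_NewTensor with a memcpy-able dtype adopts the caller's buffer
//   // (owns_memory=false) and calls the deallocator once the tensor is freed.
//   float* data = new float[4]();
//   TF_Tensor* adopted = TF_NewTensor(
//       TF_FLOAT, dims, 1, data, 4 * sizeof(float),
//       [](void* d, size_t, void*) { delete[] static_cast<float*>(d); },
//       /*deallocator_arg=*/nullptr);
//   TF_DeleteTensor(owned);
//   TF_DeleteTensor(adopted);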
TF_Tensor* TF_TensorMaybeMove(TF_Tensor* t) {

View File

@ -38,11 +38,12 @@ class TF_ManagedBuffer : public tensorflow::TensorBuffer {
public:
TF_ManagedBuffer(void* data, size_t len,
void (*deallocator)(void* data, size_t len, void* arg),
void* deallocator_arg)
void* deallocator_arg, bool owns_memory)
: TensorBuffer(data),
len_(len),
deallocator_(deallocator),
deallocator_arg_(deallocator_arg) {}
deallocator_arg_(deallocator_arg),
owns_memory_(owns_memory) {}
~TF_ManagedBuffer() override {
(*deallocator_)(data(), len_, deallocator_arg_);
@ -57,13 +58,13 @@ class TF_ManagedBuffer : public tensorflow::TensorBuffer {
proto->set_allocator_name(tensorflow::cpu_allocator()->Name());
}
// When owns_memory_ is false, this prevents input forwarding from mutating
// the buffer.
bool OwnsMemory() const override { return false; }
bool OwnsMemory() const override { return owns_memory_; }
private:
const size_t len_;
void (*const deallocator_)(void* data, size_t len, void* arg);
void* const deallocator_arg_;
bool owns_memory_;
};
namespace tensorflow {

View File

@ -15,13 +15,12 @@ limitations under the License.
#include <vector>
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"
namespace tensorflow {
namespace ops {
namespace {
@ -90,15 +89,25 @@ Status QuantizeAndDequantizeGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad);
Status QuantizeAndDequantizeV2Grad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
grad_outputs->push_back(Identity(scope, grad_inputs[0]));
grad_outputs->push_back(NoGradient());
grad_outputs->push_back(NoGradient());
Status QuantizeAndDequantizeV2GradHelper(const Scope& scope,
const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
Input input = Shape(scope, op.input(0));
Input input_min = op.input(1);
Input input_max = op.input(2);
int64 axis;
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
auto qdq_v2_grad = QuantizeAndDequantizeV2Grad(
scope, grad_inputs[0], input, input_min, input_max,
QuantizeAndDequantizeV2Grad::Axis(axis));
grad_outputs->push_back(qdq_v2_grad.input_backprop);
grad_outputs->push_back(qdq_v2_grad.input_min_backprop);
grad_outputs->push_back(qdq_v2_grad.input_max_backprop);
return scope.status();
}
REGISTER_GRADIENT_OP("QuantizeAndDequantizeV2", QuantizeAndDequantizeV2Grad);
REGISTER_GRADIENT_OP("QuantizeAndDequantizeV2",
QuantizeAndDequantizeV2GradHelper);
Status QuantizeAndDequantizeV3Grad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,

View File

@ -68,6 +68,7 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/platform:resource_loader",
],
)

View File

@ -21,15 +21,22 @@ limitations under the License.
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
constexpr char kTestDataPbTxt[] =
"cc/saved_model/testdata/half_plus_two_pbtxt/00000123";
constexpr char kTestDataSharded[] =
"cc/saved_model/testdata/half_plus_two/00000123";
string TestDataPbTxt() {
return io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
"half_plus_two_pbtxt", "00000123");
}
string TestDataSharded() {
return io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
"half_plus_two", "00000123");
}
class ReaderTest : public ::testing::Test {
protected:
@ -49,8 +56,7 @@ class ReaderTest : public ::testing::Test {
TEST_F(ReaderTest, TagMatch) {
MetaGraphDef meta_graph_def;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
&meta_graph_def));
CheckMetaGraphDef(meta_graph_def);
@ -59,8 +65,7 @@ TEST_F(ReaderTest, TagMatch) {
TEST_F(ReaderTest, NoTagMatch) {
MetaGraphDef meta_graph_def;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"missing-tag"},
&meta_graph_def);
EXPECT_FALSE(st.ok());
@ -73,8 +78,7 @@ TEST_F(ReaderTest, NoTagMatch) {
TEST_F(ReaderTest, NoTagMatchMultiple) {
MetaGraphDef meta_graph_def;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
Status st = ReadMetaGraphDefFromSavedModel(
export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def);
EXPECT_FALSE(st.ok());
@ -87,8 +91,7 @@ TEST_F(ReaderTest, NoTagMatchMultiple) {
TEST_F(ReaderTest, PbtxtFormat) {
MetaGraphDef meta_graph_def;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPbTxt);
const string export_dir = GetDataDependencyFilepath(TestDataPbTxt());
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
&meta_graph_def));
CheckMetaGraphDef(meta_graph_def);
@ -97,8 +100,7 @@ TEST_F(ReaderTest, PbtxtFormat) {
TEST_F(ReaderTest, InvalidExportPath) {
MetaGraphDef meta_graph_def;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), "missing-path");
const string export_dir = GetDataDependencyFilepath("missing-path");
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
&meta_graph_def);
EXPECT_FALSE(st.ok());

View File

@ -84,6 +84,7 @@ tf_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/platform:resource_loader",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support", # fixdeps: keep
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/codegen.h"
#include <algorithm>
#include <string>
#include <vector>
@ -29,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@ -139,23 +141,40 @@ TEST_F(ParseCppClassTest, ParseFail) {
static void CompareWithGoldenFile(
const string& tensorflow_relative_golden_file_name,
const string& expected_contents) {
const string& expected_contents, bool ignore_cr) {
// Get rid of all CR characters; we may be running under Windows.
string sanitized_expected_contents(expected_contents);
if (ignore_cr) {
sanitized_expected_contents.erase(
std::remove(sanitized_expected_contents.begin(),
sanitized_expected_contents.end(), '\r'),
sanitized_expected_contents.end());
}
// To update the golden file, flip update_golden to true and run the
// following:
// bazel test --test_strategy=local \
// third_party/tensorflow/compiler/aot:codegen_test
const bool update_golden = false;
const string golden_file_name = io::JoinPath(
testing::TensorFlowSrcRoot(), tensorflow_relative_golden_file_name);
string golden_file_name;
if (update_golden) {
golden_file_name = io::JoinPath(testing::TensorFlowSrcRoot(),
tensorflow_relative_golden_file_name);
TF_EXPECT_OK(
WriteStringToFile(Env::Default(), golden_file_name, expected_contents));
}
golden_file_name =
GetDataDependencyFilepath(tensorflow_relative_golden_file_name);
string golden_file_contents;
TF_ASSERT_OK(ReadFileToString(Env::Default(), golden_file_name,
&golden_file_contents));
if (ignore_cr) {
golden_file_contents.erase(std::remove(golden_file_contents.begin(),
golden_file_contents.end(), '\r'),
golden_file_contents.end());
}
EXPECT_EQ(golden_file_contents, expected_contents);
}
@ -229,14 +248,18 @@ TEST(CodegenTest, Golden) {
// The other fields in metadata_result are tested as part of the generated
// header test.
CompareWithGoldenFile("compiler/aot/codegen_test_o.golden",
metadata_result.object_file_data);
// This specific golden test checks a binary file. It could potentially run
// into issues due to ABIs not being stable, but so far it has not.
// If we see any ABI issues, we should reconsider this specific test case.
CompareWithGoldenFile("tensorflow/compiler/aot/codegen_test_o.golden",
metadata_result.object_file_data, false);
string header;
TF_ASSERT_OK(
GenerateHeader(opts, config, compile_result, metadata_result, &header));
CompareWithGoldenFile("compiler/aot/codegen_test_h.golden", header);
CompareWithGoldenFile("tensorflow/compiler/aot/codegen_test_h.golden", header,
true);
}
} // namespace
} // namespace tfcompile

View File

@ -1883,6 +1883,8 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"EmptyTensorList",
"ExtractImagePatches",
"Igamma",
"IgammaGradA",
"RandomGammaGrad",
"Igammac",
"FFT",
"FFT2D",
@ -1909,7 +1911,6 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"LinSpace",
"ListDiff",
"LogMatrixDeterminant",
"LowerBound",
"MatMul",
"MatrixBandPart",
"MatrixDiag",
@ -2036,7 +2037,6 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"TensorScatterUpdate",
"TridiagonalSolve",
"TruncatedNormal",
"UpperBound",
"UnsortedSegmentMax",
"UnsortedSegmentMin",
"UnsortedSegmentProd",

View File

@ -20,15 +20,17 @@ limitations under the License.
namespace tensorflow {
bool XlaKernelCreator::CanCreateKernel(const FunctionLibraryRuntime& flr,
const NodeDef& node_def) const {
return CanCreateXlaKernel(node_def);
bool XlaKernelCreator::CanCreateKernel(
const FunctionLibraryRuntime& flr,
const std::shared_ptr<const NodeProperties>& props) const {
return CanCreateXlaKernel(props->node_def);
}
Status XlaKernelCreator::CreateKernel(FunctionLibraryRuntime* flr,
const NodeDef& node_def,
std::unique_ptr<OpKernel>* kernel) const {
return CreateXlaKernel(flr, node_def, kernel);
Status XlaKernelCreator::CreateKernel(
FunctionLibraryRuntime* flr,
const std::shared_ptr<const NodeProperties>& props,
std::unique_ptr<OpKernel>* kernel) const {
return CreateXlaKernel(flr, props->node_def, kernel);
}
namespace {

View File

@ -29,11 +29,13 @@ class XlaKernelCreator : public CustomKernelCreator {
// Given the node properties 'props' and the function library runtime 'flr',
// returns true if the wrapped NodeDef is a call to a compilable function
// defined in 'flr', with the kXlaCompileAttr set.
bool CanCreateKernel(const FunctionLibraryRuntime& flr,
const NodeDef& node_def) const override;
bool CanCreateKernel(
const FunctionLibraryRuntime& flr,
const std::shared_ptr<const NodeProperties>& props) const override;
// Given a supported NodeDef, returns an XlaLaunchOp that computes the node.
Status CreateKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
Status CreateKernel(FunctionLibraryRuntime* flr,
const std::shared_ptr<const NodeProperties>& props,
std::unique_ptr<OpKernel>* kernel) const override;
};

View File

@ -30,10 +30,12 @@ limitations under the License.
namespace tensorflow {
NodeDef ToNodeDef(const string& text) {
std::shared_ptr<NodeProperties> ToNodeProperties(const string& text) {
NodeDef node_def;
DataTypeVector dummy;
EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
return node_def;
return std::make_shared<NodeProperties>(nullptr, std::move(node_def), dummy,
dummy);
}
// Create a FunctionDef that takes one resource and one regular param
@ -98,11 +100,11 @@ TEST_F(XlaKernelCreatorTest, OneFloatOneResourceArgument) {
(*fdef.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
Init({fdef});
XlaKernelCreator xla_kernel_creator;
NodeDef callsite =
ToNodeDef(R"pb(
auto callsite =
ToNodeProperties(R"pb(
name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b'
)pb");
(*callsite.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
(*(callsite->node_def.mutable_attr()))["_XlaMustCompile"] = BoolAttr(true);
// Note: need to set attribute on the created node.
Status status = xla_kernel_creator.CreateKernel(flr_, callsite, &kernel_);
@ -127,13 +129,14 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrNotSet) {
Init({fdef});
XlaKernelCreator xla_kernel_creator;
Status status = xla_kernel_creator.CreateKernel(flr_, ToNodeDef(R"proto(
name: 'XTimesY'
op: 'XTimesY'
input: 'a'
input: 'b'
)proto"),
&kernel_);
Status status =
xla_kernel_creator.CreateKernel(flr_, ToNodeProperties(R"proto(
name: 'XTimesY'
op: 'XTimesY'
input: 'a'
input: 'b'
)proto"),
&kernel_);
EXPECT_TRUE(errors::IsInternal(status)) << status.ToString();
}
@ -143,13 +146,14 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrIsSetToFalse) {
Init({fdef});
XlaKernelCreator xla_kernel_creator;
Status status = xla_kernel_creator.CreateKernel(flr_, ToNodeDef(R"proto(
name: 'XTimesY'
op: 'XTimesY'
input: 'a'
input: 'b'
)proto"),
&kernel_);
Status status =
xla_kernel_creator.CreateKernel(flr_, ToNodeProperties(R"proto(
name: 'XTimesY'
op: 'XTimesY'
input: 'a'
input: 'b'
)proto"),
&kernel_);
EXPECT_TRUE(errors::IsInternal(status)) << status.ToString();
}

View File

@ -218,12 +218,13 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
Device* dev = flr->device();
Status s;
OpKernelConstruction construction(
DeviceType(dev->device_type()), dev,
dev->GetAllocator(AllocatorAttributes()), &node_def,
&fbody->fdef.signature(), flr, dev->resource_manager(), fbody->arg_types,
input_memory_types, fbody->ret_types, output_memory_types,
flr->graph_def_version(), &s);
auto props = std::make_shared<NodeProperties>(
&fbody->fdef.signature(), node_def, fbody->arg_types, fbody->ret_types);
OpKernelConstruction construction(DeviceType(dev->device_type()), dev,
dev->GetAllocator(AllocatorAttributes()),
flr, dev->resource_manager(), props,
input_memory_types, output_memory_types,
flr->graph_def_version(), &s);
*kernel = absl::make_unique<XlaLocalLaunchBase>(
&construction, constant_arg_indices, resource_arg_indices, function,

View File

@ -0,0 +1,3 @@
# TensorFlow MLIR
These are the docs for: https://www.tensorflow.org/mlir

View File

@ -0,0 +1,26 @@
upper_tabs:
# Tabs left of dropdown menu
- include: /_upper_tabs_left.yaml
- include: /api_docs/_upper_tabs_api.yaml
# Dropdown menu
- name: Resources
path: /resources
is_default: true
menu:
- include: /resources/_menu_toc.yaml
lower_tabs:
# Subsite tabs
other:
- name: Guide
contents:
- title: Overview
path: /mlir/overview
- heading: Dialects
- title: Overview
path: /mlir/dialects
- title: TensorFlow
path: /mlir/tf_ops
- title: TensorFlow Lite
path: /mlir/tfl_ops
- include: /_upper_tabs_right.yaml

View File

@ -0,0 +1,54 @@
book_path: /mlir/_book.yaml
project_path: /mlir/_project.yaml
description: <!--no description-->
landing_page:
custom_css_path: /site-assets/css/style.css
rows:
- heading: MLIR unifies the infrastructure for high-performance ML models in TensorFlow.
items:
- description: >
The <a href="https://mlir.llvm.org/" class="external">MLIR</a> project defines a common
intermediate representation (IR) that unifies the infrastructure required to execute high
performance machine learning models in TensorFlow and similar ML frameworks. This project
will include the application of HPC techniques, along with integration of
search algorithms like reinforcement learning. MLIR aims to reduce the
cost to bring up new hardware, and improve usability for existing
TensorFlow users.
- code_block: |
<pre class = "prettyprint">
// Syntactically similar to LLVM:
func @testFunction(%arg0: i32) -> i32 {
%x = call @thingToCall(%arg0) : (i32) -> i32
br ^bb1
^bb1:
%y = addi %x, %x : i32
return %y : i32
}
</pre>
- classname: devsite-landing-row-cards
items:
- heading: "Multi-Level Intermediate Representation for Compiler Infrastructure"
youtube_id: qzljG6DKgic
buttons:
- label: Watch the video
path: https://www.youtube.com/watch?v=qzljG6DKgic
- heading: "A new intermediate representation and compiler framework"
image_path: /resources/images/tf-logo-card-16x9.png
path: https://blog.tensorflow.org/2019/04/mlir-new-intermediate-representation.html
buttons:
- label: Read on TensorFlow blog
path: https://blog.tensorflow.org/2019/04/mlir-new-intermediate-representation.html
- heading: MLIR on GitHub
image_path: /resources/images/github-card-16x9.png
path: https://github.com/llvm/llvm-project/tree/master/mlir
buttons:
- label: View on GitHub
path: https://github.com/llvm/llvm-project/tree/master/mlir
- heading: TensorFlow MLIR on GitHub
image_path: /resources/images/github-card-16x9.png
path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir
buttons:
- label: View on GitHub
path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir

View File

@ -0,0 +1,37 @@
# MLIR dialects
## Overview
To separate different hardware and software targets, MLIR has “dialects”,
including:
* TensorFlow IR, which represents all things possible in TensorFlow graphs.
* XLA HLO IR, which is designed to take advantage of XLA's compilation
abilities (with output to, among other things, TPUs).
* An experimental affine dialect, which focuses on
[polyhedral representations](https://en.wikipedia.org/wiki/Polytope_model)
and optimizations.
* LLVM IR, which has a 1:1 mapping between it and LLVM's own representation,
allowing MLIR to emit GPU and CPU code through LLVM.
* TensorFlow Lite, which will translate to running code on mobile platforms.
Each dialect consists of a set of defined operations which have invariants
placed on them, like: “This is a binary operator, and the inputs and outputs
have the same types.”
## Adding to MLIR
MLIR has no fixed/built-in list of globally known operations (no “intrinsics”).
Dialects can define entirely custom types, which is how MLIR can model things
like the LLVM IR type system (which has first class aggregates), domain
abstractions important for ML-optimized accelerators like quantized types, and
even the Swift or Clang type systems (which are built around Swift/Clang
declaration nodes) in the future.
If you want to connect a new low-level compiler, you would create a new dialect
and the lowerings between the TensorFlow Graph dialect and your dialect.
This smooths the path for hardware and compiler makers. You can even target
dialects at different levels in the same model; the higher-level optimizers
will respect the unfamiliar parts of the IR and wait for a lower level to
handle them.
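As a rough sketch of what this looks like in code (simplified against the MLIR C++ API of this era, with placeholder names rather than a real dialect), a dialect is a C++ class that registers its operations:

// A hypothetical "mylite" dialect; all names below are placeholders.
#include "mlir/IR/Dialect.h"

class MyLiteDialect : public mlir::Dialect {
 public:
  explicit MyLiteDialect(mlir::MLIRContext *context)
      : mlir::Dialect(/*name=*/"mylite", context) {
    // Register this dialect's operations, typically generated from ODS.
    addOperations<
#define GET_OP_LIST
#include "mylite_ops.cc.inc"
        >();
  }
};

Lowerings into or out of such a dialect are then written as rewrite patterns over the registered operations.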

File diff suppressed because one or more lines are too long

(New image file; size: 148 KiB)

View File

@ -0,0 +1,36 @@
# MLIR
## Overview
MLIR, or Multi-Level Intermediate Representation, is a representation format
and library of compiler utilities that sits between the model representation
and low-level compilers/executors that generate hardware-specific code.
MLIR is, at its heart, a flexible infrastructure for modern optimizing
compilers. This means it consists of a specification for intermediate
representations (IR) and a code toolkit to perform transformations on that
representation. (In compiler parlance, as you move from higher-level
representations to lower-level representations, these transformations can be
called “lowerings.”)
MLIR is highly influenced by [LLVM](https://llvm.org/) and unabashedly reuses
many great ideas from it. It has a flexible type system, and allows
representing, analyzing and transforming graphs combining multiple levels of
abstraction in the same compilation unit. These abstractions include TensorFlow
operations, nested polyhedral loop regions, and even LLVM instructions and fixed
hardware operations and types.
We expect MLIR to be of interest to many groups, including:
* Compiler researchers and implementers looking to optimize performance and
memory consumption of machine learning models
* Hardware makers looking for a way to connect their hardware to TensorFlow,
such as TPUs, portable neural hardware in phones, and other custom ASICs
* People writing language bindings that want to take advantage of optimizing
compilers and hardware acceleration.
The TensorFlow ecosystem contains a number of compilers and optimizers that
operate at multiple levels of the software and hardware stack. We expect the
gradual adoption of MLIR to simplify every aspect of this stack.
<img alt="MLIR overview diagram" src="./images/mlir-infra.svg"/>

View File

@ -208,6 +208,7 @@ cc_library(
"ir/tfl_ops.h.inc",
"ir/tfl_ops_interface.cc.inc",
"ir/tfl_ops_interface.h.inc",
"runtime_verifiers.inc",
"utils/attribute_utils.cc",
],
hdrs = [
@ -303,12 +304,14 @@ cc_library(
"transforms/optimize_functional_ops.cc",
"transforms/prepare_composite_functions_tf.cc",
"transforms/prepare_tf.cc",
"transforms/runtime_type_verify.cc",
"transforms/split_merged_operands.cc",
"transforms/trim_functions_tf.cc",
"transforms/unroll_batch_matmul.cc",
"transforms/while_loop_outline.cc",
],
hdrs = [
"ir/tfl_ops_interface.h.inc",
"transforms/dilated_conv.h",
"transforms/passes.h",
"transforms/unroll_batch_matmul.h",
@ -461,9 +464,9 @@ cc_library(
)
tf_native_cc_binary(
name = "operator-converter-gen",
name = "converter-gen",
srcs = [
"operator_converter_gen.cc",
"converter_gen.cc",
],
deps = [
"@llvm-project//llvm:support",
@ -473,14 +476,18 @@ tf_native_cc_binary(
)
gentbl(
name = "operator_converter_inc",
name = "converter_inc",
tbl_outs = [
(
"", # This driver has no options.
"--gen-operator-converters",
"operator_converters.inc",
),
(
"--gen-runtime-verifiers",
"runtime_verifiers.inc",
),
],
tblgen = ":operator-converter-gen",
tblgen = ":converter-gen",
td_file = "ir/tfl_ops.td",
td_srcs = [
":tensorflow_lite_ops_td_files",
@ -582,7 +589,6 @@ cc_library(
"@com_google_absl//absl/strings",
"@flatbuffers",
"@llvm-project//llvm:support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
@ -645,12 +651,14 @@ tf_cc_binary(
"//tensorflow/compiler/mlir:init_mlir",
"//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:errors",
"//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
],
)
@ -694,7 +702,6 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -727,7 +734,6 @@ cc_library(
"//tensorflow/lite/tools/optimize:quantize_weights",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",

View File

@ -28,6 +28,9 @@ limitations under the License.
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
#include "mlir/TableGen/Attribute.h" // TF:llvm-project
#include "mlir/TableGen/Format.h" // TF:llvm-project
#include "mlir/TableGen/Operator.h" // TF:llvm-project
#include "mlir/TableGen/Predicate.h" // TF:llvm-project
using llvm::DefInit;
using llvm::dyn_cast;
@ -41,6 +44,19 @@ using llvm::SmallVector;
using llvm::StringInit;
using llvm::StringRef;
enum ActionType {
OpConv,
RuntimeVerify,
};
// NOLINTNEXTLINE
llvm::cl::opt<ActionType> action(
llvm::cl::desc("Action to perform:"),
llvm::cl::values(clEnumValN(OpConv, "gen-operator-converters",
"Generate operator converters"),
clEnumValN(RuntimeVerify, "gen-runtime-verifiers",
"Generate TFLite runtime verifiers")));
// Returns the associated option name for the given op definition.
static inline std::string GetOperatorOptionName(const Record &def) {
assert(def.getName().startswith("TFL_") && "unexpected op prefix");
@ -342,8 +358,101 @@ static bool OperatorWritersMain(raw_ostream &os, RecordKeeper &records) {
return false;
}
static void GenOperandResultVerifier(raw_ostream &os,
llvm::ArrayRef<llvm::Init *> values,
StringRef valueKind) {
mlir::tblgen::FmtContext fctx;
bool first = true;
for (auto static_value : llvm::enumerate(values)) {
auto *definit = llvm::cast<llvm::DefInit>(static_value.value());
auto *val = definit->getDef()->getValue("tflRuntimeTypePredicate");
if (!val) continue;
// Create code block on first type to verify.
if (first) {
os << " {\n";
os << " unsigned index = " << static_value.index() << ";\n";
first = false;
}
mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue()));
auto desc =
definit->getDef()->getValueAsString("tflRuntimeTypeDescription");
// Emit a loop to check all the dynamic values in the pack.
os << formatv(" for (Value v : top.getODS{0}{1}s({2})) {{\n",
// Capitalize the first letter to match the function name
valueKind.substr(0, 1).upper(), valueKind.substr(1),
static_value.index());
os << " (void)v;\n"
<< " if (!("
<< tgfmt(pred.getCondition(), &fctx.withSelf("v.getType()")) << ")) {\n"
<< formatv(
" return op->emitOpError(\"{0} #\") << index "
"<< \" must be {1}, but got \" << v.getType();\n",
valueKind, desc)
<< " }\n" // if
<< " ++index;\n"
<< " }\n"; // for
}
// Emit closing brace if needed.
if (!first) os << " }\n";
}
// NOLINTNEXTLINE
static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
emitSourceFileHeader("MLIR TFLite Runtime Verifiers", os);
// Retrieve all the definitions derived from TFL_Op and sort by record name.
std::vector<Record *> defs = records.getAllDerivedDefinitions("Op");
llvm::sort(defs, LessRecord());
// Iterate through all the ops defined.
for (const auto *def : defs) {
mlir::tblgen::Operator op(*def);
if (!op.getTrait("TflRuntimeVerifyOpInterface::Trait")) continue;
mlir::tblgen::FmtContext verify_ctx;
os << "::mlir::LogicalResult " << op.getCppClassName()
<< "::VerifyTflRuntimeTypes(::mlir::Operation *op) {\n";
os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
verify_ctx.withOp("top");
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
auto &value = op.getOperand(i);
// Skip from the first variadic operand onward for now. Else the getOperand
// index used below doesn't match.
if (value.isVariadic()) break;
if (!value.name.empty())
verify_ctx.addSubst(value.name, formatv("op->getOperand({0})", i));
}
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
auto &value = op.getResult(i);
// Skip from the first variadic result onward for now. Else the getResult
// index used below doesn't match.
if (value.isVariadic()) break;
if (!value.name.empty())
verify_ctx.addSubst(value.name, formatv("op->getResult({0})", i));
}
GenOperandResultVerifier(os, def->getValueAsDag("arguments")->getArgs(),
"operand");
GenOperandResultVerifier(os, def->getValueAsDag("results")->getArgs(),
"result");
os << " return mlir::success();\n}\n";
}
return false;
}
int main(int argc, char **argv) {
llvm::InitLLVM y(argc, argv);
llvm::cl::ParseCommandLineOptions(argc, argv);
return TableGenMain(argv[0], &OperatorWritersMain);
if (action == ActionType::OpConv)
return TableGenMain(argv[0], &OperatorWritersMain);
return TableGenMain(argv[0], &RuntimeVerifierWriterMain);
}
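Pieced together from the os << statements above, the verifier emitted for a single op has roughly this shape; FooOp and the predicate stand in for a real TFL op and its tflRuntimeTypePredicate:

::mlir::LogicalResult FooOp::VerifyTflRuntimeTypes(::mlir::Operation *op) {
  auto top = cast<FooOp>(op); (void)top;
  {
    unsigned index = 0;
    for (Value v : top.getODSOperands(0)) {
      (void)v;
      if (!(/* tflRuntimeTypePredicate applied to v.getType() */ true)) {
        return op->emitOpError("operand #")
               << index << " must be <description>, but got " << v.getType();
      }
      ++index;
    }
  }
  return mlir::success();
}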

View File

@ -71,4 +71,23 @@ def TFL_SparseOp : OpInterface<"SparseOpInterface"> {
];
}
//===----------------------------------------------------------------------===//
// TFL runtime type verification of operand/result types.
def TFL_RuntimeVerification : OpInterface<"TflRuntimeVerifyOpInterface"> {
let description = [{
Interface for verifying TFLite runtime op compatibility.
This verifies that converted TFLite ops have operand/result types
supported by the TFLite runtime.
}];
let methods = [
StaticInterfaceMethod<
[{Returns whether the op's operands/results are supported by runtime.}],
"LogicalResult", "VerifyTflRuntimeTypes", (ins "Operation*":$op)
>,
];
}
#endif // TFL_OP_INTERFACES
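A pass can then walk every op implementing this interface and invoke the verifier. A minimal sketch against the MLIR pass API of this era (modeled on, but not copied from, the runtime type verify pass added elsewhere in this change):

class RuntimeTypeVerifyPass
    : public mlir::FunctionPass<RuntimeTypeVerifyPass> {
 public:
  void runOnFunction() override {
    getFunction().walk([&](TflRuntimeVerifyOpInterface op) {
      if (failed(op.VerifyTflRuntimeTypes(op.getOperation())))
        signalPassFailure();
    });
  }
};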

View File

@ -723,12 +723,11 @@ static LogicalResult Verify(PackOp op) {
}
// Make sure all inputs have the same shape and element type.
// TODO(rahulsp): Simplify once b/135032064 is fixed.
for (Value operand : op.getOperands()) {
auto other_type = operand.getType().cast<ShapedType>();
if (input_type != other_type)
// TODO(b/135032063): Simplify once fixed.
for (Type operand_type : op.getOperandTypes()) {
if (failed(mlir::verifyCompatibleShape(input_type, operand_type)))
return op.emitOpError("operands should be of the same type. got ")
<< input_type << ", " << other_type;
<< input_type << ", " << operand_type;
}
return success();
@ -1872,6 +1871,7 @@ LogicalResult WhileOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) {
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.cc.inc"
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
#include "tensorflow/compiler/mlir/lite/runtime_verifiers.inc"
Operation *TensorFlowLiteDialect::materializeConstant(OpBuilder &builder,
Attribute value,

File diff suppressed because it is too large

View File

@ -282,6 +282,7 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
if (pass_config.legalize_tf_while) {
pm.addPass(mlir::TFL::CreateWhileOutlinePass());
}
pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,

View File

@ -150,7 +150,8 @@ struct QuantizationPattern : public RewritePattern {
explicit QuantizationPattern(MLIRContext* context, bool enable_verify,
float error_tolerance, bool single_layer_verify)
: RewritePattern(DQ::getOperationName(), 1, context),
// Set the benefit to a large number so this pattern is always preferred.
: RewritePattern(DQ::getOperationName(), 300, context),
enable_verify(enable_verify),
error_tolerance(error_tolerance),
single_layer_verify(single_layer_verify) {}
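// For context: the second RewritePattern constructor argument is the pattern
// benefit. When several patterns match the same op, the greedy rewrite driver
// tries higher-benefit patterns first, so raising this from 1 to 300 makes
// the quantization pattern win over ordinary benefit-1 patterns that also
// match the dequantize op.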

View File

@ -178,15 +178,20 @@ func @inputsAfterOutputs() {
// -----
// expected-error@+1 {{Found malformed ophint regions: missing inputs or outputs.}}
module {
func @extractOphintFailure() {
func @extractOphintSame() {
%0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x16x1xf32>
%1 = call @AnotherFunc(%0) : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
%2 = "tf.Sigmoid"(%1) {T = "tfdtype$DT_FLOAT", name = "Sigmoid"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
%3 = "tf.Mul"(%2, %1) {T = "tfdtype$DT_FLOAT", name = "mul"} : (tensor<1x16x16x1xf32>, tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
%4 = "tf.Identity"(%3) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "d4b1eb00b81211e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation-d4b1eb00b81211e99426dc4a3e957995-0-None-None"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
return
// CHECK: [[VAL_0:%.*]] = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x16x1xf32>
// CHECK: [[VAL_1:%.*]] = call @AnotherFunc([[VAL_0]]) : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
// CHECK: [[VAL_2:%.*]] = "tf.Sigmoid"([[VAL_1]]) {T = "tfdtype$DT_FLOAT", name = "Sigmoid"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
// CHECK: [[VAL_3:%.*]] = "tf.Mul"([[VAL_2]], [[VAL_1]]) {T = "tfdtype$DT_FLOAT", name = "mul"} : (tensor<1x16x16x1xf32>, tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
// CHECK: [[VAL_4:%.*]] = "tf.Identity"([[VAL_3]]) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "d4b1eb00b81211e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation-d4b1eb00b81211e99426dc4a3e957995-0-None-None"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
}
func @AnotherFunc(%arg0: tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> {

View File

@ -739,6 +739,15 @@ func @matrix_diag_v3(%arg0: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
// CHECK: return [[VAL_6]] : tensor<8x16x16xf32>
}
func @matrix_set_diag(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
%0 = "tf.MatrixSetDiag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
return %0 : tensor<3x3xi32>
// CHECK-LABEL: func @matrix_set_diag(
// CHECK: [[VAL_0:%.*]] = "tfl.matrix_set_diag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
// CHECK: return [[VAL_0]]
}
func @maximum(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> {
%0 = "tf.Maximum"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32>
@ -1364,3 +1373,83 @@ func @reciprocal_i64(%arg0: tensor<8xi64>) -> tensor<8xi64> {
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<1xi64>, tensor<8xi64>) -> tensor<8xi64>
// CHECK: return
}
func @random_uniform() -> tensor<2x5xf32> {
%0 = "tf.Const"() { value = dense<[2, 5]> : tensor<2xi32> } : () -> tensor<2xi32>
%1 = "tf.RandomUniform"(%0) { seed = 1, seed2 = 0} : (tensor<2xi32>) -> tensor<2x5xf32>
return %1 : tensor<2x5xf32>
// CHECK-LABEL: random_uniform
// CHECK: %[[CST:.*]] = constant dense
// CHECK: return %[[CST:.*]] : tensor<2x5xf32>
}
func @random_uniform_no_fold(%arg0: tensor<2xi32>) -> tensor<2x5xf32> {
%1 = "tf.RandomUniform"(%arg0) { seed = 0, seed2 = 0} : (tensor<2xi32>) -> tensor<2x5xf32>
return %1 : tensor<2x5xf32>
// CHECK-LABEL: random_uniform_no_fold
// CHECK: %[[RANDOM:.*]] = "tf.RandomUniform"
}
func @random_uniform_no_fold2(%arg0: tensor<2xi32>) -> tensor<*xf32> {
%1 = "tf.RandomUniform"(%arg0) { seed = 1, seed2 = 2} : (tensor<2xi32>) -> tensor<*xf32>
return %1 : tensor<*xf32>
// CHECK-LABEL: random_uniform_no_fold2
// CHECK: %[[RANDOM:.*]] = "tf.RandomUniform"
}
func @random_uniform_no_fold3() -> tensor<2x5xf64> {
%0 = "tf.Const"() { value = dense<[2, 5]> : tensor<2xi32> } : () -> tensor<2xi32>
%1 = "tf.RandomUniform"(%0) { seed = 1, seed2 = 0} : (tensor<2xi32>) -> tensor<2x5xf64>
return %1 : tensor<2x5xf64>
// CHECK-LABEL: random_uniform_no_fold3
// CHECK: %[[RANDOM:.*]] = "tf.RandomUniform"
}
func @LstmWithoutProjection(%arg: tensor<28x1x28xf32>) -> (tensor<28x1x16xf32>) {
%1 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<16x28xf32>} : () -> tensor<16x28xf32>
%2 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<16x16xf32>} : () -> tensor<16x16xf32>
%3 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<16xf32>} : () -> tensor<16xf32>
%4 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<1x16xf32>} : () -> tensor<1x16xf32>
%5 = "tf.Const"() {device = "", dtype = f32, value = dense<-1.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
%6:3 = "tf.UnidirectionalSequenceLstm"(%arg, %1, %1, %1, %1, %2, %2, %2, %2, %3, %3, %3, %3, %3, %3, %3, %5, %5, %4, %4) {_tflite_input_indices = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 18, 19], device = ""} : (tensor<28x1x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1x16xf32>, tensor<1x16xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<28x1x16xf32>)
return %6#2 : tensor<28x1x16xf32>
}
// CHECK: func @LstmWithoutProjection([[VAL_0:%.*]]: tensor<28x1x28xf32>) -> tensor<28x1x16xf32> {
// CHECK: [[VAL_1:%.*]] = constant dense<0.000000e+00> : tensor<16x28xf32>
// CHECK: [[VAL_2:%.*]] = constant dense<0.000000e+00> : tensor<16x16xf32>
// CHECK: [[VAL_3:%.*]] = constant dense<0.000000e+00> : tensor<16xf32>
// CHECK: [[VAL_4:%.*]] = constant dense<0.000000e+00> : tensor<1x16xf32>
// CHECK: [[VAL_5:%.*]] = constant unit
// CHECK: [[VAL_6:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_1]], [[VAL_1]], [[VAL_1]], [[VAL_1]], [[VAL_2]], [[VAL_2]], [[VAL_2]], [[VAL_2]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_5]], [[VAL_5]], [[VAL_4]], [[VAL_4]], [[VAL_5]], [[VAL_5]], [[VAL_5]], [[VAL_5]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<28x1x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, none, none, tensor<1x16xf32>, tensor<1x16xf32>, none, none, none, none) -> tensor<28x1x16xf32>
// CHECK: return [[VAL_6]] : tensor<28x1x16xf32>
// CHECK: }
func @LstmWithProjection(%arg: tensor<28x1x16xf32>) -> (tensor<28x1x8xf32>) {
%1 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<16x16xf32>} : () -> tensor<16x16xf32>
%2 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<16x8xf32>} : () -> tensor<16x8xf32>
%3 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<16xf32>} : () -> tensor<16xf32>
%4 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<1x16xf32>} : () -> tensor<1x16xf32>
%5 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<8x16xf32>} : () -> tensor<8x16xf32>
%6 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<1x8xf32>} : () -> tensor<1x8xf32>
%7 = "tf.Const"() {device = "", dtype = f32, value = dense<-1.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
%8:3 = "tf.UnidirectionalSequenceLstm"(%arg, %1, %1, %1, %1, %2, %2, %2, %2, %7, %7, %7, %3, %3, %3, %3, %5, %7, %6, %4) {_tflite_input_indices = [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 18, 19], device = ""} : (tensor<28x1x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<8x16xf32>, tensor<1xf32>, tensor<1x8xf32>, tensor<1x16xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<28x1x8xf32>)
return %8#2 : tensor<28x1x8xf32>
}
// CHECK-LABEL: func @LstmWithProjection(
// CHECK-SAME: [[VAL_7:%.*]]: tensor<28x1x16xf32>) -> tensor<28x1x8xf32> {
// CHECK: [[VAL_8:%.*]] = constant dense<0.000000e+00> : tensor<16x16xf32>
// CHECK: [[VAL_9:%.*]] = constant dense<0.000000e+00> : tensor<16x8xf32>
// CHECK: [[VAL_10:%.*]] = constant dense<0.000000e+00> : tensor<16xf32>
// CHECK: [[VAL_11:%.*]] = constant dense<0.000000e+00> : tensor<1x16xf32>
// CHECK: [[VAL_12:%.*]] = constant dense<0.000000e+00> : tensor<8x16xf32>
// CHECK: [[VAL_13:%.*]] = constant dense<0.000000e+00> : tensor<1x8xf32>
// CHECK: [[VAL_14:%.*]] = constant unit
// CHECK: [[VAL_15:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_8]], [[VAL_8]], [[VAL_8]], [[VAL_8]], [[VAL_9]], [[VAL_9]], [[VAL_9]], [[VAL_9]], [[VAL_14]], [[VAL_14]], [[VAL_14]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_12]], [[VAL_14]], [[VAL_13]], [[VAL_11]], [[VAL_14]], [[VAL_14]], [[VAL_14]], [[VAL_14]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<28x1x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, none, none, none, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<8x16xf32>, none, tensor<1x8xf32>, tensor<1x16xf32>, none, none, none, none) -> tensor<28x1x8xf32>
// CHECK: return [[VAL_15]] : tensor<28x1x8xf32>
// CHECK: }


@ -1,4 +1,4 @@
// RUN: tf-opt -split-input-file -verify-diagnostics %s | FileCheck %s --dump-input-on-failure
// RUN: tf-opt -split-input-file -verify-diagnostics -tfl-runtime-verify %s | FileCheck %s --dump-input-on-failure
// Unary math ops
// -----
@ -878,6 +878,14 @@ func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
// -----
func @packUnranked(%arg0: tensor<2xi32>, %arg1: tensor<*xi32>) -> tensor<2x2xi32> {
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32}
%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2xi32>, tensor<*xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// -----
func @packInputRank(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<1x4x2xi32> {
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 2 : i32, values_count = 2 : i32}
%0 = "tfl.pack"(%arg0, %arg1) {axis = 2 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<1x4x2xi32>


@ -511,3 +511,34 @@ func @PadStridedSliceNewAxisMask2(%arg0: tensor<4x64x64x1xf32>) -> tensor<1x4x64
%1 = "tf.StridedSlice"(%0, %cst, %cst, %cst_0) {Index = i32, T = f32, _output_shapes = ["tfshape$dim { size: 1 } dim { size: 4 } dim { size: 64 } dim { size: 64 }"], begin_mask = 6 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 1 : i64, shrink_axis_mask = 0 : i64} : (tensor<4x64x64xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x4x64x64xf32>
return %1 : tensor<1x4x64x64xf32>
}
// CHECK-LABEL: @MatrixSetDiagV2Conversion
func @MatrixSetDiagV2Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
%cst = constant dense<0> : tensor<i32>
%0 = "tf.MatrixSetDiagV2"(%arg0, %arg1, %cst) : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
return %0 : tensor<3x3xi32>
// CHECK: %[[RES:.*]] = "tf.MatrixSetDiag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
// CHECK: return %[[RES]]
}
// CHECK-LABEL: @MatrixSetDiagV2NonZeroK
func @MatrixSetDiagV2NonZeroK(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
%cst = constant dense<1> : tensor<i32>
%0 = "tf.MatrixSetDiagV2"(%arg0, %arg1, %cst) : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
return %0 : tensor<3x3xi32>
// CHECK: %[[CST:.*]] = constant dense<1> : tensor<i32>
// CHECK: %[[RES:.*]] = "tf.MatrixSetDiagV2"(%arg0, %arg1, %[[CST]]) : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
// CHECK: return %[[RES]]
}
// CHECK-LABEL: @MatrixSetDiagV3Conversion
func @MatrixSetDiagV3Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
%cst = constant dense<0> : tensor<i32>
%0 = "tf.MatrixSetDiagV3"(%arg0, %arg1, %cst) : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
return %0 : tensor<3x3xi32>
// CHECK: %[[RES:.*]] = "tf.MatrixSetDiag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
// CHECK: return %[[RES]]
}


@ -2,39 +2,44 @@
// RUN: tf-opt %s -tfl-prepare-quantize -tfl-quantize -tfl-numeric-verify | FileCheck --check-prefix=DEBUG %s
// CHECK-LABEL: QuantizeFloatConst
func @QuantizeFloatConst() -> tensor<f32> {
func @QuantizeFloatConst() -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>> {
%0 = constant dense<-0.1> : tensor<2x2xf32>
%1 = "tfl.quantize"(%0) {qtype = tensor<!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
%2 = "tfl.dequantize"(%1) : (tensor<!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>) -> tensor<f32>
return %2 : tensor<f32>
%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
return %1 : tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<0> : tensor<2x2xi8>}
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[cst]])
// CHECK: return %[[dq]] : tensor<f32>
// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<0> : tensor<2x2xi8>}
// CHECK: return %[[cst]]
}
// CHECK-LABEL: QuantizeDenseFloatConst
func @QuantizeDenseFloatConst() -> tensor<2x2xf32> {
func @QuantizeDenseFloatConst() -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>> {
%0 = constant dense<[[-0.1, 1.0], [1.0, 3.0]]> : tensor<2x2xf32>
%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
%2 = "tfl.dequantize"(%1) : (tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>) -> tensor<2x2xf32>
return %2 : tensor<2x2xf32>
return %1 : tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<{{\[\[}}0, -1], {{\[}}-1, -1]]> : tensor<2x2xi8>}
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[cst]])
// CHECK: return %[[dq]] : tensor<2x2xf32>
// CHECK: return %[[cst]]
}
// CHECK-LABEL: QuantizeSplatFloatConst
func @QuantizeSplatFloatConst() -> tensor<2x2xf32> {
func @QuantizeSplatFloatConst() -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>> {
%0 = constant dense<3.0> : tensor<2x2xf32>
%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
return %1 : tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<-1> : tensor<2x2xi8>}
// CHECK: return %[[cst]]
}
// CHECK-LABEL: NotQuantizeFloatConst
func @NotQuantizeFloatConst() -> tensor<2x2xf32> {
%0 = constant dense<-0.1> : tensor<2x2xf32>
%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
%2 = "tfl.dequantize"(%1) : (tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>) -> tensor<2x2xf32>
return %2 : tensor<2x2xf32>
// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<-1> : tensor<2x2xi8>}
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[cst]])
// CHECK: return %[[dq]] : tensor<2x2xf32>
// CHECK: %[[cst:.*]] = constant dense<-1.000000e-01> : tensor<2x2xf32>
// CHECK: return %[[cst]] : tensor<2x2xf32>
}
// CHECK-LABEL: DequantizeAndQuantize


@ -24,6 +24,7 @@ limitations under the License.
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/FileUtilities.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/init_mlir.h"
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
@ -32,8 +33,10 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/stream_executor/lib/statusor.h"
@ -130,12 +133,24 @@ int main(int argc, char **argv) {
llvm::SourceMgr source_mgr;
mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context);
StatusOr<mlir::OwningModuleRef> module =
tensorflow::LoadFromGraphdefOrMlirSource(
input_file_name, input_mlir, use_splatted_constant, custom_opdefs,
debug_info_file, input_arrays, input_dtypes, input_shapes,
output_arrays,
/*prune_unused_nodes=*/true, &source_mgr, &context);
StatusOr<mlir::OwningModuleRef> module;
// TODO(b/147435528): We need to test the e2e behavior once the graph freezing
// inside mlir is done.
if (import_saved_model || import_saved_model_v1) {
if (input_mlir)
module = tensorflow::errors::InvalidArgument(
"Importing saved model should not have input_mlir set");
module = tensorflow::ImportSavedModel(
import_saved_model, import_saved_model_v1, input_file_name,
saved_model_tags, saved_model_exported_names, &context);
} else {
module = tensorflow::LoadFromGraphdefOrMlirSource(
input_file_name, input_mlir, use_splatted_constant, custom_opdefs,
debug_info_file, input_arrays, input_dtypes, input_shapes,
output_arrays,
/*prune_unused_nodes=*/true, &source_mgr, &context);
}
// If errors occur, the library call above has already logged the error
// message, so we can just return here.
@ -182,6 +197,7 @@ int main(int argc, char **argv) {
pass_config.inline_functions = inline_functions;
tensorflow::AddTFToTFLConversionPasses(pass_config, &pm);
pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
std::string result;
auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer(


@ -22,6 +22,33 @@ using llvm::cl::opt;
opt<std::string> input_file_name(llvm::cl::Positional,
llvm::cl::desc("<input file>"),
llvm::cl::init("-"));
// NOLINTNEXTLINE
opt<bool> import_saved_model(
"savedmodel-to-mlir",
llvm::cl::desc("Import a saved model to its MLIR representation"),
llvm::cl::value_desc("dir"));
// NOLINTNEXTLINE
opt<bool> import_saved_model_v1(
"savedmodel-v1-to-mlir",
llvm::cl::desc("Import a saved model V1 to its MLIR representation"),
llvm::cl::value_desc("dir"));
// NOLINTNEXTLINE
opt<std::string> saved_model_tags(
"tf-savedmodel-tags",
llvm::cl::desc("Tags used to indicate which MetaGraphDef to import, "
"separated by ','"),
llvm::cl::init("serve"));
// NOLINTNEXTLINE
opt<std::string> saved_model_exported_names(
"tf-savedmodel-exported-names",
llvm::cl::desc("Names to export from SavedModel, separated by ','. Empty "
"(the default) means export all."),
llvm::cl::init(""));
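// Hypothetical invocation combining the flags above (the binary name and
// paths are assumptions for illustration, not part of this patch):
//   tf_tfl_translate --savedmodel-to-mlir --tf-savedmodel-tags=serve \
//       --tf-savedmodel-exported-names=serving_default \
//       /path/to/saved_model -o /tmp/out.mlir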
// NOLINTNEXTLINE
opt<std::string> output_file_name("o", llvm::cl::desc("<output file>"),
llvm::cl::value_desc("filename"),


@ -39,4 +39,10 @@ extern llvm::cl::opt<bool> inline_functions;
extern llvm::cl::list<std::string> custom_opdefs;
extern llvm::cl::opt<bool> emit_quant_adaptor_ops;
extern llvm::cl::opt<std::string> quant_stats_file_name;
// Import saved model.
extern llvm::cl::opt<bool> import_saved_model;
extern llvm::cl::opt<bool> import_saved_model_v1;
extern llvm::cl::opt<std::string> saved_model_tags;
extern llvm::cl::opt<std::string> saved_model_exported_names;
#endif // TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_TRANSLATE_CL_H_


@ -15,6 +15,10 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
#include <string>
#include <unordered_set>
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Parser.h" // TF:llvm-project
@ -155,4 +159,37 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
return Status::OK();
}
StatusOr<mlir::OwningModuleRef> ImportSavedModel(
bool import_saved_model, bool import_saved_model_v1,
const std::string& input_filename, const std::string& saved_model_tags,
const std::string& saved_model_exported_names, mlir::MLIRContext* context) {
if (import_saved_model) {
std::unordered_set<std::string> tags =
absl::StrSplit(saved_model_tags, ',');
std::vector<std::string> exported_names =
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
auto module = tensorflow::SavedModelToMlirImport(
input_filename, tags, absl::Span<std::string>(exported_names), context);
if (!module)
return tensorflow::errors::InvalidArgument("failed to open input file");
return module;
} else if (import_saved_model_v1) {
std::unordered_set<std::string> tags =
absl::StrSplit(saved_model_tags, ',');
auto module =
tensorflow::SavedModelV1ToMlirImport(input_filename, tags, context);
if (!module)
return tensorflow::errors::InvalidArgument("failed to open input file");
return module;
} else {
return tensorflow::errors::InvalidArgument(
"Should be either saved model v1 or v2");
}
}
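// Usage sketch (tag values illustrative): `--tf-savedmodel-tags=serve,gpu`
// splits into the tag set {"serve", "gpu"}; an empty
// `--tf-savedmodel-exported-names` (the default) exports everything, since
// absl::SkipEmpty drops the empty token.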
} // namespace tensorflow


@ -40,6 +40,12 @@ LoadFromGraphdefOrMlirSource(
absl::string_view output_arrays, bool prune_unused_nodes,
llvm::SourceMgr* source_mgr, mlir::MLIRContext* context);
// Load Saved model (either v1 or v2) into MLIR.
stream_executor::port::StatusOr<mlir::OwningModuleRef> ImportSavedModel(
bool import_saved_model, bool import_saved_model_v1,
const std::string& input_filename, const std::string& saved_model_tags,
const std::string& saved_model_exported_names, mlir::MLIRContext* context);
// Taking an MLIR module in the TF executor dialect and a set of parameters,
// this applies a set of passes to convert the module to the TF Lite dialect
// and serializes the result to a string. Depending on an attribute in the module


@ -698,11 +698,10 @@ void ExtractOphintPass::runOnModule() {
if (ophint_composite_ops.empty()) continue;
// Verify: Make sure all ophint_composite_ops are valid.
// If not valid, we just don't do anything.
for (const auto& kv : ophint_composite_ops) {
if (failed(kv.getValue().VerifyOphint())) {
module.emitError()
<< "Found malformed ophint regions: missing inputs or outputs.";
return signalPassFailure();
return;
}
}


@ -365,3 +365,7 @@ def : Pat<
/*padding=*/ $padding,
/*stride_h=*/ ExtractI32At<1>:$strides,
/*stride_w=*/ ExtractI32At<2>:$strides)>;
def : Pat<
(TF_MatrixSetDiagOp $input, $diagonal),
(TFL_MatrixSetDiagOp $input, $diagonal)>;


@ -49,6 +49,8 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
namespace mlir {
@ -61,6 +63,9 @@ namespace {
using xla::Status;
using xla::StatusOr;
constexpr char kUnidirectionalSequenceLstm[] = "tf.UnidirectionalSequenceLstm";
constexpr char kTfLiteInputIndices[] = "_tflite_input_indices";
// Legalize operations in functions.
struct LegalizeTF : public FunctionPass<LegalizeTF> {
void runOnFunction() override;
@ -114,9 +119,54 @@ DECL_CONVERT_OP(SplitV);
DECL_CONVERT_OP(StridedSlice);
DECL_CONVERT_OP(Unpack);
DECL_CONVERT_OP(Reciprocal);
DECL_CONVERT_OP(RandomUniform);
#undef DECL_CONVERT_OP
PatternMatchResult ConvertTFRandomUniformOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto random_uniform_op = cast<TF::RandomUniformOp>(op);
if (random_uniform_op.seed() == 0 && random_uniform_op.seed2() == 0) {
return matchFailure();
}
if (!random_uniform_op.dtype().isF32()) {
return matchFailure();
}
typedef tensorflow::random::UniformDistribution<
tensorflow::random::PhiloxRandom, float>
Distribution;
tensorflow::random::PhiloxRandom generator(
random_uniform_op.seed().getSExtValue(),
random_uniform_op.seed2().getSExtValue());
Distribution dist;
int num_elements = 0;
if (auto output_type =
random_uniform_op.output().getType().dyn_cast_or_null<ShapedType>()) {
if (auto ranked_output = output_type.dyn_cast_or_null<RankedTensorType>()) {
if (!ranked_output.hasRank() || ranked_output.getNumDynamicDims() != 0) {
return matchFailure();
}
num_elements = output_type.getNumElements();
size_t offset = 0;
size_t num_samples = Distribution::kResultElementCount;
llvm::SmallVector<float, 32> data;
data.resize(num_elements);
while (offset < num_elements) {
const typename Distribution::ResultType samples = dist(&generator);
std::copy(&samples[0],
&samples[0] + std::min(num_samples, data.size() - offset),
&data[0] + offset);
offset += num_samples;
}
auto output_data = DenseFPElementsAttr::get(output_type, data);
rewriter.replaceOpWithNewOp<ConstantOp>(op, output_type, output_data);
return matchSuccess();
}
}
return matchFailure();
}
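// Illustrative sketch of the fold (IR names and shapes are assumptions):
//   %shape = "tf.Const"() {value = dense<[2, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
//   %rand = "tf.RandomUniform"(%shape) {seed = 87654321 : i64, seed2 = 0 : i64}
//             : (tensor<2xi32>) -> tensor<2x2xf32>
// Since at least one seed is non-zero and the result is a fully ranked f32
// tensor, the pattern above replaces %rand with a constant holding the four
// Philox-generated samples, making the result deterministic across runs.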
PatternMatchResult ConvertTFConcatOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_concat_op = cast<TF::ConcatOp>(op);
@ -514,6 +564,74 @@ PatternMatchResult ConvertTFReciprocalOp::matchAndRewrite(
return matchSuccess();
}
// Legalize unidirectional sequence lstm.
struct LegalizeUnidirectionalSequenceLstm : public RewritePattern {
explicit LegalizeUnidirectionalSequenceLstm(MLIRContext* context)
: RewritePattern(kUnidirectionalSequenceLstm, 1, context) {}
PatternMatchResult matchAndRewrite(Operation* op,
PatternRewriter& rewriter) const override {
auto tflite_indices_attr =
op->getAttrOfType<ArrayAttr>(kTfLiteInputIndices);
if (!tflite_indices_attr) return matchFailure();
SmallVector<int64_t, 20> tflite_indices;
for (auto index_attr : tflite_indices_attr.getValue()) {
IntegerAttr index = index_attr.cast<IntegerAttr>();
tflite_indices.push_back(index.getInt());
}
// Optional input placeholder.
Value none = rewriter.create<mlir::ConstantOp>(
op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());
// Populate inputs.
// UnidirectionalSequenceLstm is expected to have 24 inputs.
SmallVector<Value, 24> inputs;
int count = 0;
int total_ophint_converted_inputs = tflite_indices.size();
for (int i = 0; i < 24; ++i) {
if (count < total_ophint_converted_inputs && tflite_indices[count] == i) {
// Specified input.
inputs.push_back(op->getOperand(i));
count++;
} else {
// Unspecified input.
inputs.push_back(none);
}
}
// Populate outputs.
// UnidirectionalSequenceLstm should only have 1 output, and that is the
// original ophint converted node's 3rd output.
SmallVector<Type, 4> result_types;
result_types.push_back(op->getOpResult(2).getType());
// Populate attributes.
SmallVector<NamedAttribute, 4> attributes;
// Activation will always be tanh.
attributes.push_back(rewriter.getNamedAttr("fused_activation_function",
rewriter.getStringAttr("TANH")));
// cell_clip.
attributes.push_back(
rewriter.getNamedAttr("cell_clip", rewriter.getF32FloatAttr(10.0)));
// proj_clip.
attributes.push_back(
rewriter.getNamedAttr("proj_clip", rewriter.getF32FloatAttr(0.0)));
// Will always be time-major.
attributes.push_back(
rewriter.getNamedAttr("time_major", rewriter.getBoolAttr(true)));
auto lstm_op = rewriter.create<TFL::UnidirectionalSequenceLSTMOp>(
op->getLoc(), result_types, inputs, attributes);
// Rewire the output.
op->getResult(2).replaceAllUsesWith(lstm_op.getResult());
op->erase();
return matchSuccess();
}
};
void LegalizeTF::runOnFunction() {
OwningRewritePatternList patterns;
auto* ctx = &getContext();
@ -521,11 +639,15 @@ void LegalizeTF::runOnFunction() {
// Add the generated patterns to the list.
populateWithGenerated(ctx, &patterns);
patterns.insert<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFMatMulOp,
ConvertTFMatrixDiagV2Op, ConvertTFMatrixDiagV3Op,
ConvertTFPackOp, ConvertTFReshapeOp, ConvertTFSplitOp,
ConvertTFSplitVOp, ConvertTFStridedSliceOp, ConvertTFUnpackOp,
ConvertTFAssertOp, ConvertTFReciprocalOp>(ctx);
patterns
.insert<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFMatMulOp,
ConvertTFMatrixDiagV2Op, ConvertTFMatrixDiagV3Op, ConvertTFPackOp,
ConvertTFReshapeOp, ConvertTFSplitOp, ConvertTFSplitVOp,
ConvertTFStridedSliceOp, ConvertTFUnpackOp, ConvertTFAssertOp,
ConvertTFReciprocalOp, ConvertTFRandomUniformOp>(ctx);
// Ophint python converter converted tf node pattern.
patterns.insert<LegalizeUnidirectionalSequenceLstm>(ctx);
applyPatternsGreedily(func, patterns);
}


@ -199,6 +199,22 @@ def : Pat<
(TFL_HardSwishOp $x),
[(EqualOperands $x, $y)]>;
// Matching HardSwish with an extra FakeQuant. These FakeQuant ops result from
// incorrect placement during quantization-aware training.
// TODO(b/149735743): We should correct the placement automatically.
def : Pat<
(TFL_MulOp (TFL_DequantizeOp (TFL_QuantizeOp
(TFL_MulOp
$x, (TFL_DequantizeOp (TFL_QuantizeOp (TFL_AddOp
$y,
(ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
TFL_AF_Relu6), $qattr2)),
TFL_AF_None), $qattr1)),
(ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
TFL_AF_None),
(TFL_HardSwishOp $x),
[(EqualOperands $x, $y)]>;
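// In effect this matches the same x * relu6(x + 3) * 0.166666666 structure as
// the pattern above, with quantize/dequantize pairs (left behind by the stray
// FakeQuant ops) wrapped around the inner add and mul results.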
// Constraint that the attribute value is less than 'n'
class ConstDoubleValueLessThan<string n> : Constraint<
CPred<"$0.isa<DenseElementsAttr>() && "


@ -91,6 +91,9 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFWhilePass();
// Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
std::unique_ptr<OpPassBase<ModuleOp>> CreateWhileOutlinePass();
// Verifies that the types used are supported by the TFLite runtime.
std::unique_ptr<OpPassBase<FuncOp>> CreateRuntimeTypeVerifyPass();
} // namespace TFL
} // namespace mlir


@ -190,3 +190,16 @@ def : Pat<(TF_ReshapeOp:$old_value
// parameters of the input, so we can remove the quantization ops.
def : Pat<(TF_RankOp (TFL_DequantizeOp (TFL_QuantizeOp $input, $qtype))),
(TF_RankOp $input)>;
// `k` is expected to be 0; other values are not currently supported.
def : Pat<(TF_MatrixSetDiagV2Op $input, $diagonal,
(ConstantOp ConstantAttr<I32ElementsAttr, "{0}">)),
(TF_MatrixSetDiagOp $input, $diagonal)>;
// `align` attribute can be ignored because we only support converting
// `MatrixSetDiagV3` to `MatrixSetDiag` with default `k` inputs.
def : Pat<(TF_MatrixSetDiagV3Op $input, $diagonal,
(ConstantOp ConstantAttr<I32ElementsAttr, "{0}">),
$align),
(TF_MatrixSetDiagOp $input, $diagonal)>;


@ -21,12 +21,20 @@ include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
// Quantize attribute $0 by using quantization parameter from %1.
def QuantizeByQuantizedType : NativeCodeCall<"quant::Quantize($0, $1.getValue())">;
def F32ElementsAttr : ElementsAttrBase<
CPred<"$_self.cast<ElementsAttr>().getType().getElementType().isF32()">, "float constant tensor">;
// Squash tfl.dequantize and tfl.quantize pairs.
// TODO(fengliuai): Compare the scale of input and output. This can also be
// squashed to a requantize op if the scales are different.
def : Pat<(TFL_QuantizeOp (TFL_DequantizeOp $in), $qt), (replaceWithValue $in)>;
// If the tfl.dequantize op wasn't fused, we shouldn't quantize the floating
// point constant.
def : Pat<(TFL_DequantizeOp
(TFL_QuantizeOp (ConstantOp F32ElementsAttr:$cst), $qt)),
(ConstantOp $cst)>;
// Quantize the value of a constant op if the quantization parameters have been
// propagated to the output.
def : Pat<(TFL_QuantizeOp


@ -0,0 +1,52 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/IR/OperationSupport.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
namespace mlir {
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.h.inc"
namespace TFL {
namespace {
// This pass verifies that the operand and result types are supported by the
// TFLite runtime.
class RuntimeTypeVerifyPass : public mlir::FunctionPass<RuntimeTypeVerifyPass> {
public:
explicit RuntimeTypeVerifyPass() {}
private:
void runOnFunction() override;
};
void RuntimeTypeVerifyPass::runOnFunction() {
getFunction().walk([&](TflRuntimeVerifyOpInterface op) {
if (failed(op.VerifyTflRuntimeTypes(op.getOperation())))
signalPassFailure();
});
}
} // namespace
// Verifies that the types used are supported by the TFLite runtime.
std::unique_ptr<OpPassBase<FuncOp>> CreateRuntimeTypeVerifyPass() {
return std::make_unique<RuntimeTypeVerifyPass>();
}
static PassRegistration<RuntimeTypeVerifyPass> pass(
"tfl-runtime-verify", "TFLite runtime verification");
} // namespace TFL
} // namespace mlir
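// Usage sketch: with the registration above, the pass can be exercised
// standalone, e.g. `tf-opt -tfl-runtime-verify input.mlir` (input.mlir is a
// placeholder path), matching the RUN line added to the .mlir tests in this
// patch.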


@ -168,6 +168,10 @@ std::string OpOrArgLocNameMapper::GetName(OpOrVal op_or_val) {
result.getResultNumber());
return std::string(result.getOwner()->getName().getStringRef());
}
// Use the ASM syntax for BlockArgument
if (auto arg = val.dyn_cast<mlir::BlockArgument>()) {
return "arg" + std::to_string(arg.getArgNumber());
}
return "";
}


@ -287,6 +287,7 @@ cc_library(
"transforms/materialize_mlir_passthrough_op.cc",
"transforms/optimize.cc",
"transforms/optimize_global_tensors.cc",
"transforms/parallel_execute_to_islands.cc",
"transforms/promote_resources_to_args.cc",
"transforms/raise_control_flow.cc",
"transforms/replicate_invariant_op_hoisting.cc",
@ -708,7 +709,6 @@ cc_library(
deps = [
":tensorflow_dialect_registration",
":tf_dialect_passes",
"@llvm-project//mlir:AllPassesAndDialects",
],
)
@ -913,7 +913,6 @@ cc_library(
"//tensorflow/core/platform:logging",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",


@ -41,11 +41,52 @@ limitations under the License.
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "mlir/Support/STLExtras.h" // TF:llvm-project
#include "mlir/Transforms/InliningUtils.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/platform/logging.h"
namespace mlir {
namespace tf_device {
//===----------------------------------------------------------------------===//
// TF Device Dialect Interfaces
//===----------------------------------------------------------------------===//
namespace {
struct TFInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
//===--------------------------------------------------------------------===//
// Analysis Hooks
//===--------------------------------------------------------------------===//
// Defines the legality of inlining TF Device operations.
bool isLegalToInline(Operation*, Region*, BlockAndValueMapping&) const final {
// For now, enable inlining all operations.
return true;
}
//===--------------------------------------------------------------------===//
// Transformation Hooks
//===--------------------------------------------------------------------===//
// Attempts to materialize a conversion for a type mismatch between a call
// from this dialect, and a callable region. This method should generate an
// operation that takes 'input' as the only operand, and produces a single
// result of 'resultType'. If a conversion can not be generated, nullptr
// should be returned.
// This is just re-using the same logic as the TensorFlow dialect right now.
Operation* materializeCallConversion(OpBuilder& builder, Value input,
Type result_type,
Location conversion_loc) const final {
if (!result_type.isa<TensorType>() || !input.getType().isa<TensorType>())
return nullptr;
return builder.create<TF::CastOp>(conversion_loc, result_type, input,
/*truncate=*/builder.getBoolAttr(false));
}
};
} // end anonymous namespace
TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context)
: Dialect(/*name=*/"tf_device", context) {
addOperations<
@ -54,6 +95,8 @@ TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context)
>();
addOperations<ParallelExecuteOp>();
addInterfaces<TFInlinerInterface>();
}
//===----------------------------------------------------------------------===//


@ -573,9 +573,9 @@ void Print(SwitchNOp switchn, OpAsmPrinter &p) {
ParseResult ParseSwitchNOp(OpAsmParser &parser, OperationState &result) {
// Parsing:
// %2:6 = tf_executor.SwitchN %0, %1 by 5 : tensor<??xf32>
// %2:6 = tf_executor.SwitchN %0, %1 of 5 : tensor<??xf32>
// Where the first operand is the data to replicate, the second is an i32
// indicating which output to populate, followed by the keyword `by` and the
// indicating which output to populate, followed by the keyword `of` and the
// number of outputs (+1 for the control token).
SmallVector<OpAsmParser::OperandType, 2> op_infos;
SmallVector<Type, 1> types;


@ -165,7 +165,7 @@ def TfExecutor_IslandOp : TfExecutor_Op<"island",
The `tf_executor.island` operation has a single region with a single block
attached (only functional control flow is allowed). The block is terminated
by a `tf_executor.yield` operation. The operands of the terminator
correspond to the result values of the `tf_executor.graph` operation. An
correspond to the result values of the `tf_executor.island` operation. An
extra result of type `!tf_executor.control` is always produced by every
`tf_executor.island`.
Within an island, execution semantics follow standard sequential behavior as
@ -299,7 +299,7 @@ def TfExecutor_SwitchNOp : TfExecutor_Op<"SwitchN",
.SetShapeFn(SwitchNShape);
For example:
%2:6 = tf_executor.SwitchN %0, %1 by 5 : tensor<??xf32>
%2:6 = tf_executor.SwitchN %0, %1 of 5 : tensor<??xf32>
Note: One additional result corresponds to the control output.
}];


@ -49,7 +49,7 @@ an output element, this operation computes \\(y = |x|\\).
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_AddOp : TF_Op<"Add", [NoSideEffect, ResultsBroadcastableShape]>,
def TF_AddOp : TF_Op<"Add", [NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic]>,
WithBroadcastableBinOpBuilder {
let summary = "Returns x + y element-wise.";
@ -98,7 +98,7 @@ Inputs must be of same size and shape.
let hasFolder = 1;
}
def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape]>,
def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic]>,
WithBroadcastableBinOpBuilder {
let summary = "Returns x + y element-wise.";
@ -508,8 +508,9 @@ Broadcasting is supported, so `value` may have any number of dimensions.
let extraClassDeclaration = [{
// TF_LayoutSensitiveInterface:
SmallVector<int64_t, 4> GetLayoutDependentArgs() { return {0}; }
SmallVector<int64_t, 4> GetLayoutDependentResults() { return {0}; }
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
LogicalResult UpdateDataFormat(StringRef data_format);
}];
}
@ -980,7 +981,7 @@ tf.conj(input) ==> [-2.25 - 4.75j, 3.25 - 5.75j]
let hasCanonicalizer = 1;
}
def TF_Conv2DOp : TF_Op<"Conv2D", [NoSideEffect]> {
def TF_Conv2DOp : TF_Op<"Conv2D", [NoSideEffect, TF_LayoutSensitiveInterface]> {
let summary = [{
Computes a 2-D convolution given 4-D `input` and `filter` tensors.
}];
@ -1030,6 +1031,13 @@ horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
let verifier = [{
return Verify(*this);
}];
let extraClassDeclaration = [{
// TF_LayoutSensitiveInterface:
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
LogicalResult UpdateDataFormat(StringRef data_format);
}];
}
def TF_Conv2DBackpropFilterOp : TF_Op<"Conv2DBackpropFilter", [NoSideEffect]> {
@ -2091,7 +2099,7 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors.
TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<3>;
}
def TF_FusedBatchNormV3Op : TF_Op<"FusedBatchNormV3", [NoSideEffect]> {
def TF_FusedBatchNormV3Op : TF_Op<"FusedBatchNormV3", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
let summary = "Batch normalization.";
let description = [{
@ -2122,6 +2130,13 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>;
let extraClassDeclaration = [{
// TF_FoldOperandsTransposeInterface:
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
}];
}
def TF_GatherOp : TF_Op<"Gather", [NoSideEffect]> {
@ -3392,6 +3407,130 @@ tf.matrix_diag(diagonal, k = -1, num_rows = 3, padding_value = 9)
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_MatrixSetDiagOp : TF_Op<"MatrixSetDiag", [NoSideEffect]> {
let summary = [{
Returns a batched matrix tensor with new batched diagonal values.
}];
let description = [{
Given `input` and `diagonal`, this operation returns a tensor with the
same shape and values as `input`, except for the main diagonal of the
innermost matrices. These will be overwritten by the values in `diagonal`.
The output is computed as follows:
Assume `input` has `k+1` dimensions `[I, J, K, ..., M, N]` and `diagonal` has
`k` dimensions `[I, J, K, ..., min(M, N)]`. Then the output is a
tensor of rank `k+1` with dimensions `[I, J, K, ..., M, N]` where:
* `output[i, j, k, ..., m, n] = diagonal[i, j, k, ..., n]` for `m == n`.
* `output[i, j, k, ..., m, n] = input[i, j, k, ..., m, n]` for `m != n`.
}];
let arguments = (ins
TF_Tensor:$input,
TF_Tensor:$diagonal
);
let results = (outs
TF_Tensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_MatrixSetDiagV2Op : TF_Op<"MatrixSetDiagV2", [NoSideEffect]> {
let summary = [{
Returns a batched matrix tensor with new batched diagonal values.
}];
let description = [{
Given `input` and `diagonal`, this operation returns a tensor with the
same shape and values as `input`, except for the specified diagonals of the
innermost matrices. These will be overwritten by the values in `diagonal`.
`input` has `r+1` dimensions `[I, J, ..., L, M, N]`. When `k` is scalar or
`k[0] == k[1]`, `diagonal` has `r` dimensions `[I, J, ..., L, max_diag_len]`.
Otherwise, it has `r+1` dimensions `[I, J, ..., L, num_diags, max_diag_len]`.
`num_diags` is the number of diagonals, `num_diags = k[1] - k[0] + 1`.
`max_diag_len` is the longest diagonal in the range `[k[0], k[1]]`,
`max_diag_len = min(M + min(k[1], 0), N + min(-k[0], 0))`
The output is a tensor of rank `r+1` with dimensions `[I, J, ..., L, M, N]`.
If `k` is scalar or `k[0] == k[1]`:
```
output[i, j, ..., l, m, n]
= diagonal[i, j, ..., l, n-max(k[1], 0)] ; if n - m == k[1]
input[i, j, ..., l, m, n] ; otherwise
```
Otherwise,
```
output[i, j, ..., l, m, n]
= diagonal[i, j, ..., l, diag_index, index_in_diag] ; if k[0] <= d <= k[1]
input[i, j, ..., l, m, n] ; otherwise
```
where `d = n - m`, `diag_index = k[1] - d`, and `index_in_diag = n - max(d, 0)`.
For example:
```
# The main diagonal.
input = np.array([[[7, 7, 7, 7], # Input shape: (2, 3, 4)
[7, 7, 7, 7],
[7, 7, 7, 7]],
[[7, 7, 7, 7],
[7, 7, 7, 7],
[7, 7, 7, 7]]])
diagonal = np.array([[1, 2, 3], # Diagonal shape: (2, 3)
[4, 5, 6]])
tf.matrix_set_diag(input, diagonal) ==> [[[1, 7, 7, 7], # Output shape: (2, 3, 4)
[7, 2, 7, 7],
[7, 7, 3, 7]],
[[4, 7, 7, 7],
[7, 5, 7, 7],
[7, 7, 6, 7]]]
# A superdiagonal (per batch).
tf.matrix_set_diag(input, diagonal, k = 1)
==> [[[7, 1, 7, 7], # Output shape: (2, 3, 4)
[7, 7, 2, 7],
[7, 7, 7, 3]],
[[7, 4, 7, 7],
[7, 7, 5, 7],
[7, 7, 7, 6]]]
# A band of diagonals.
diagonals = np.array([[[1, 2, 3], # Diagonal shape: (2, 2, 3)
[4, 5, 0]],
[[6, 1, 2],
[3, 4, 0]]])
tf.matrix_set_diag(input, diagonals, k = (-1, 0))
==> [[[1, 7, 7, 7], # Output shape: (2, 3, 4)
[4, 2, 7, 7],
[0, 5, 3, 7]],
[[6, 7, 7, 7],
[3, 1, 7, 7],
[7, 4, 2, 7]]]
```
}];
let arguments = (ins
TF_Tensor:$input,
TF_Tensor:$diagonal,
I32Tensor:$k
);
let results = (outs
TF_Tensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_MatrixSetDiagV3Op : TF_Op<"MatrixSetDiagV3", [NoSideEffect]> {
let summary = [{
Returns a batched matrix tensor with new batched diagonal values.
@ -3551,7 +3690,7 @@ retained with length 1.
>];
}
def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect]> {
def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
let summary = "Performs max pooling on the input.";
let description = [{
@ -3571,6 +3710,13 @@ def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect]> {
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
let extraClassDeclaration = [{
// TF_FoldOperandsTransposeInterface:
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
}];
}
def TF_MaxPoolGradOp : TF_Op<"MaxPoolGrad", [NoSideEffect]> {
@ -4714,7 +4860,7 @@ I.e., \\(y = 1 / x\\).
let hasCanonicalizer = 1;
}
def TF_ReluOp : TF_Op<"Relu", [NoSideEffect, SameOperandsAndResultType]> {
def TF_ReluOp : TF_Op<"Relu", [NoSideEffect, SameOperandsAndResultType, TF_LayoutAgnostic]> {
let summary = "Computes rectified linear: `max(features, 0)`.";
let description = [{
@ -6657,7 +6803,7 @@ variables.
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
}
def TF_TanhOp : TF_Op<"Tanh", [NoSideEffect, SameOperandsAndResultType]> {
def TF_TanhOp : TF_Op<"Tanh", [NoSideEffect, SameOperandsAndResultType, TF_LayoutAgnostic]> {
let summary = "Computes hyperbolic tangent of `x` element-wise.";
let description = [{


@ -58,6 +58,10 @@ TODO: Make invariants more structured so that we can reference them in ops.
def TF_OperandsSameAsResultsTypeOrRef : NativeOpTrait<
"TF::OperandsSameAsResultsTypeOrRef">;
// Layout agnostic operations do not depend on the operands' data layout (data
// format); for example, all element-wise operations are layout agnostic.
def TF_LayoutAgnostic : NativeOpTrait<"TF::LayoutAgnostic">;
//===----------------------------------------------------------------------===//
// TensorFlow op definitions
//===----------------------------------------------------------------------===//


@ -44,11 +44,17 @@ def TF_LayoutSensitiveInterface : OpInterface<"LayoutSensitiveInterface"> {
>,
InterfaceMethod<
[{Returns indices of layout dependent arguments.}],
"SmallVector<int64_t, 4>", "GetLayoutDependentArgs", (ins)
"SmallVector<unsigned, 4>", "GetLayoutDependentArgs", (ins)
>,
InterfaceMethod<
[{Returns indices of layout dependent results.}],
"SmallVector<int64_t, 4>", "GetLayoutDependentResults", (ins)
"SmallVector<unsigned, 4>", "GetLayoutDependentResults", (ins)
>,
InterfaceMethod<
[{Updates operation attributes and operands to account for the updated
data format. If data format is not supported, must return failure.}],
"LogicalResult", "UpdateDataFormat",
(ins "StringRef":$data_format)
>,
];
@ -57,4 +63,42 @@ def TF_LayoutSensitiveInterface : OpInterface<"LayoutSensitiveInterface"> {
}];
}
def TF_FoldOperandsTransposeInterface : OpInterface<"FoldOperandsTransposeInterface"> {
let description = [{
Operation supports folding operand(s) transposes into the operation itself.
(1) Operation might have layout dependent operands and results...
Example: MaxPool(Transpose($arg, $perm))
-> Transpose(MaxPool($arg, $perm))
(2) ... or it might have only layout dependent operands:
Example: Mean(Transpose($arg, $reduction_dims))
-> Mean($arg, Transpose($reduction_dims))
}];
let methods = [
InterfaceMethod<
[{Returns indices of layout dependent arguments.}],
"SmallVector<unsigned, 4>", "GetLayoutDependentArgs", (ins)
>,
InterfaceMethod<
[{Returns indices of layout dependent results.}],
"SmallVector<unsigned, 4>", "GetLayoutDependentResults", (ins)
>,
InterfaceMethod<
[{Updates operation attributes and operands to account for the folded
permutation. If folding of permutation is not possible, must return
failure.}],
"LogicalResult", "FoldOperandsPermutation",
(ins "ArrayRef<int64_t>":$permutation)
>,
];
let verify = [{
return VerifyFoldOperandsTransposeInterface($_op);
}];
}
#endif // TF_OP_INTERFACES


@ -292,6 +292,156 @@ static LogicalResult VerifyTypesCompatibility(
return success();
}
//===----------------------------------------------------------------------===//
// TF op helper functions to work with layout transformation.
//===----------------------------------------------------------------------===//
SmallVector<int64_t, 4> GetDataFormatPermutation(StringRef from, StringRef to) {
if (from == "NHWC" && to == "NCHW") {
return {0, 3, 1, 2};
} else if (from == "NCHW" && to == "NHWC") {
return {0, 2, 3, 1};
} else {
return {};
}
}
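// For example, GetDataFormatPermutation("NHWC", "NCHW") returns {0, 3, 1, 2};
// applied to a shape it maps [1, 64, 64, 8] (NHWC) to [1, 8, 64, 64] (NCHW).
// (The concrete sizes here are illustrative assumptions.)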
// Shuffle elements in `attr` according to the permutation. The optional
// `inner_size` allows shuffling array attributes created from rank-2 tensors
// on the outer dimension only.
ArrayAttr ShuffleArrayAttr(ArrayAttr attr, ArrayRef<int64_t> permutation,
int inner_size = 1) {
if (attr.size() == 0) return attr;
assert(attr.size() % inner_size == 0);
assert(attr.size() / inner_size == permutation.size());
SmallVector<Attribute, 8> values{attr.begin(), attr.end()};
SmallVector<Attribute, 8> shuffled(values.size());
for (size_t i = 0; i < permutation.size(); ++i) {
for (size_t j = 0; j < inner_size; ++j) {
shuffled[i * inner_size + j] = values[permutation[i] * inner_size + j];
}
}
return ArrayAttr::get(shuffled, attr.getContext());
}
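// Illustrative example (assumed values): with permutation {0, 3, 1, 2},
// strides [1, 2, 2, 1] become [1, 1, 2, 2]; with inner_size = 2, padding
// pairs [0,0, 1,1, 2,2, 0,0] are shuffled pair-wise to [0,0, 0,0, 1,1, 2,2].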
// Shuffle ranked tensor dimensions according to the permutation.
Type ShuffleRankedTensorType(Type type, ArrayRef<int64_t> permutation) {
if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
ArrayRef<int64_t> shape = ranked_type.getShape();
assert(permutation.size() == shape.size());
SmallVector<int64_t, 4> new_shape(permutation.size());
for (size_t i = 0; i < permutation.size(); ++i)
new_shape[i] = shape[permutation[i]];
return RankedTensorType::get(new_shape, ranked_type.getElementType());
}
return type;
}
static bool AreCancellablePermutations(DenseIntElementsAttr perm0,
DenseIntElementsAttr perm1) {
if (perm0.getNumElements() == 0 || perm1.getNumElements() == 0) return false;
if (perm0.getNumElements() != perm1.getNumElements()) return false;
SmallVector<int64_t, 8> perm0_values;
for (auto value : perm0.getIntValues())
perm0_values.push_back(value.getSExtValue());
SmallVector<int64_t, 8> perm1_values;
for (auto value : perm1.getIntValues())
perm1_values.push_back(value.getSExtValue());
for (int i = 0; i < perm0_values.size(); ++i) {
if (perm0_values[perm1_values[i]] != i) return false;
}
return true;
}
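// Worked example: perm0 = {0, 3, 1, 2} (NHWC -> NCHW) and perm1 = {0, 2, 3, 1}
// (NCHW -> NHWC) are cancellable: for every i, perm0[perm1[i]] == i, so the
// two transposes compose to the identity.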
// Default implementation of `LayoutSensitiveInterface::UpdateDataFormat` for
// layout sensitive operations that do not have any additional layout dependent
// attributes besides `data_format` string.
template <typename Op>
LogicalResult UpdateDataFormat(StringRef data_format, Op *op) {
auto perm = GetDataFormatPermutation(op->data_format(), data_format);
if (perm.empty()) return failure();
// Update data format attribute.
op->setAttr("data_format", StringAttr::get(data_format, op->getContext()));
// Update types for all layout sensitive results.
auto layout_sensitive = cast<LayoutSensitiveInterface>(op->getOperation());
for (unsigned idx : layout_sensitive.GetLayoutDependentResults()) {
OpResult result = op->getOperation()->getResult(idx);
result.setType(ShuffleRankedTensorType(result.getType(), perm));
}
return success();
}
// Default implementation for folding operand transpose into the operation.
// See `FoldOperandsTransposeInterface::FoldOperandsPermutation`.
template <typename Op>
LogicalResult FoldOperandsPermutation(
ArrayRef<int64_t> permutation, Op *op,
ArrayRef<std::pair<StringRef, ArrayAttr>> shuffle_attrs = {}) {
MLIRContext *context = op->template getParentOfType<ModuleOp>().getContext();
// We only support NHWC <-> NCHW permutations.
static constexpr std::array<int64_t, 4> kNchwToNhwc = {0, 2, 3, 1};
static constexpr std::array<int64_t, 4> kNhwcToNchw = {0, 3, 1, 2};
// Operation data format after folding `permutation`.
StringRef target_data_format = [&]() -> StringRef {
if (op->data_format() == "NHWC" && permutation.equals(kNchwToNhwc)) {
return "NCHW"; // cancel NCHW->NHWC operand permutation
} else if (op->data_format() == "NCHW" && permutation.equals(kNhwcToNchw)) {
return "NHWC"; // cancel NHWC->NCHW operand permutation
} else {
return "";
}
}();
if (target_data_format.empty()) return failure();
// To fold the operand `permutation` into the `op` we need to shuffle all
// layout-dependent attributes and types with the reverse permutation, and
// change the operation data format to `target_data_format`.
//
// Example:
// %1 = SomeOp(...) {data_format = NHWC}
// %2 = Transpose(%1) {permutation = NHWC->NCHW}
// %3 = Op(%2) {data_format = NCHW}
//
// To bypass %2 we have to shuffle the data format from NCHW back to NHWC,
// which is the reverse of the operand permutation (the function argument).
auto reverse_permutation =
GetDataFormatPermutation(op->data_format(), target_data_format);
if (reverse_permutation.empty()) return failure();
op->setAttr("data_format", StringAttr::get(target_data_format, context));
for (auto pair : shuffle_attrs) {
StringRef attr_name = pair.first;
ArrayAttr attr_value = pair.second;
op->setAttr(attr_name, ShuffleArrayAttr(attr_value, reverse_permutation));
}
auto fold = cast<FoldOperandsTransposeInterface>(op->getOperation());
for (unsigned idx : fold.GetLayoutDependentResults()) {
OpResult result = op->getOperation()->getResult(idx);
result.setType(
ShuffleRankedTensorType(result.getType(), reverse_permutation));
}
return success();
}
namespace {
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
} // namespace
@ -459,6 +609,15 @@ static LogicalResult Verify(BiasAddOp op) {
return success();
}
// TODO(ezhulenev): BiasAddOp is not really layout sensitive, it must only
// support folding operand transposes.
LogicalResult BiasAddOp::UpdateDataFormat(StringRef data_format) {
auto ranked = value().getType().dyn_cast<RankedTensorType>();
if (!ranked || ranked.getRank() != 4) return failure();
return ::mlir::TF::UpdateDataFormat(data_format, this);
}
//===----------------------------------------------------------------------===//
// BiasAddGradOp
//===----------------------------------------------------------------------===//
@ -817,6 +976,21 @@ static LogicalResult Verify(OpT op) {
return success();
}
LogicalResult Conv2DOp::UpdateDataFormat(StringRef data_format) {
auto perm = GetDataFormatPermutation(this->data_format(), data_format);
if (perm.empty()) return failure();
// Update data_format attribute and result types.
if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure();
// Update convolution attributes.
setAttr("dilations", ShuffleArrayAttr(dilations(), perm));
setAttr("strides", ShuffleArrayAttr(strides(), perm));
setAttr("explicit_paddings", ShuffleArrayAttr(explicit_paddings(), perm, 2));
return success();
}
//===----------------------------------------------------------------------===//
// Conv2dBackpropInputOp
//===----------------------------------------------------------------------===//
@ -1138,6 +1312,11 @@ static LogicalResult Verify(FusedBatchNormOp op) {
return success();
}
LogicalResult FusedBatchNormV3Op::FoldOperandsPermutation(
ArrayRef<int64_t> permutation) {
return ::mlir::TF::FoldOperandsPermutation(permutation, this);
}
//===----------------------------------------------------------------------===//
// GatherV2Op
//===----------------------------------------------------------------------===//
@ -1330,6 +1509,16 @@ void MaxOp::build(Builder *builder, OperationState &result, Value input,
build(builder, result, out_ty, input, reduction_indices, keep_dims);
}
//===----------------------------------------------------------------------===//
// MaxPoolOp
//===----------------------------------------------------------------------===//
LogicalResult MaxPoolOp::FoldOperandsPermutation(
ArrayRef<int64_t> permutation) {
return ::mlir::TF::FoldOperandsPermutation(
permutation, this, {{"strides", strides()}, {"ksize", ksize()}});
}
//===----------------------------------------------------------------------===//
// MaxPoolGradOp
//===----------------------------------------------------------------------===//
@ -1347,6 +1536,38 @@ static LogicalResult Verify(MaxPoolGradOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// MeanOp
//===----------------------------------------------------------------------===//
LogicalResult MeanOp::FoldOperandsPermutation(ArrayRef<int64_t> permutation) {
// Reduction indices must be defined by a constant operation.
auto reduction_op =
dyn_cast_or_null<TF::ConstOp>(reduction_indices().getDefiningOp());
if (!reduction_op) return failure();
auto reductions_value = reduction_op.value().dyn_cast<DenseElementsAttr>();
if (!reductions_value) return failure();
// Prepare new reduction indices according to operand permutation.
SmallVector<int64_t, 4> shuffled_reduction;
llvm::transform(reductions_value.getIntValues(),
std::back_inserter(shuffled_reduction),
[&](APInt idx) { return permutation[idx.getSExtValue()]; });
// Add a constant operation with the new reduction indices.
OpBuilder builder(getOperation());
auto type = mlir::RankedTensorType::get(shuffled_reduction.size(),
builder.getIntegerType(64));
auto values = mlir::DenseIntElementsAttr::get(type, shuffled_reduction);
auto shuffled_reduction_op = builder.create<TF::ConstOp>(getLoc(), values);
// Use new reduction indices.
setOperand(1, shuffled_reduction_op);
return success();
}
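// Illustrative sketch (assumed shapes, not from the patch): if the operand was
// produced by a Transpose with permutation {0, 2, 3, 1} (NCHW -> NHWC) and the
// op reduced NHWC dims [1, 2] (H, W), folding rewrites the reduction indices
// to [2, 3], which are H and W in the underlying NCHW value.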
//===----------------------------------------------------------------------===//
// NegOp
//===----------------------------------------------------------------------===//
@ -2723,23 +2944,46 @@ void TransposeOp::build(Builder *builder, OperationState &result, Value x,
perm);
}
OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
auto const_perm = dyn_cast_or_null<TF::ConstOp>(perm().getDefiningOp());
namespace {
if (!const_perm) {
return {};
}
OpFoldResult FoldIdentityTranspose(TransposeOp op) {
auto const_perm = dyn_cast_or_null<TF::ConstOp>(op.perm().getDefiningOp());
if (!const_perm) return {};
auto const_value = const_perm.value();
const auto &elements = const_value.getValues<APInt>();
for (auto it : llvm::enumerate(elements)) {
if (it.index() != it.value()) {
return {};
}
if (it.index() != it.value()) return {};
}
return x();
return op.x();
}
OpFoldResult FoldCancellableTranspose(TransposeOp op) {
// Operand is a TransposeOp.
auto transpose = dyn_cast_or_null<TF::TransposeOp>(op.x().getDefiningOp());
if (!transpose) return {};
// Permutations defined by constant operations.
auto perm0 = dyn_cast_or_null<TF::ConstOp>(op.perm().getDefiningOp());
auto perm1 = dyn_cast_or_null<TF::ConstOp>(transpose.perm().getDefiningOp());
if (!perm0 || !perm1) return {};
// With permutation indices that cancel each other
auto perm0_value = perm0.value().cast<DenseIntElementsAttr>();
auto perm1_value = perm1.value().cast<DenseIntElementsAttr>();
if (!AreCancellablePermutations(perm0_value, perm1_value)) return {};
return transpose.x();
}
} // namespace
OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
if (auto folded = FoldIdentityTranspose(*this)) return folded;
if (auto folded = FoldCancellableTranspose(*this)) return folded;
return {};
}
//===----------------------------------------------------------------------===//


@ -172,7 +172,7 @@ else_branch: A function that takes 'inputs' and returns a list of
}];
}
def TF_MeanOp : TF_Op<"Mean", [NoSideEffect]> {
def TF_MeanOp : TF_Op<"Mean", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
let summary = "Computes the mean of elements across dimensions of a tensor.";
let description = [{
@ -195,6 +195,13 @@ retained with length 1.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
let extraClassDeclaration = [{
// TF_FoldOperandsTransposeInterface:
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {}; }
LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
}];
}
def TF_LegacyCallOp : TF_Op<"LegacyCall",


@ -112,12 +112,26 @@ static LogicalResult VerifyIndexPath(Operation *op, NamedAttribute named_attr) {
return mlir::success();
}
// Return true if `type` is a tensor of `!tf.resource`. This is the type that is
// used to represent mutable variables on exported functions' bound inputs.
static bool IsResourceVarType(Type type) {
TensorType tensor_type = type.dyn_cast<TensorType>();
if (!tensor_type) return false;
return tensor_type.getElementType().isa<TF::ResourceType>();
static LogicalResult VerifyBoundInputArgType(Operation *op_for_diagnostics,
Type arg_type,
GlobalTensorOp global_tensor) {
if (global_tensor.is_mutable()) {
auto expected_type = RankedTensorType::get(
{}, TF::ResourceType::get({global_tensor.type().cast<TensorType>()},
arg_type.getContext()));
if (arg_type != expected_type) {
return op_for_diagnostics->emitError()
<< "mutable bound input with type " << arg_type
<< " expected to have type " << expected_type;
}
} else {
if (arg_type != global_tensor.type()) {
return op_for_diagnostics->emitError()
<< "bound input for immutable 'tf_saved_model.global_tensor' must "
"match the global tensor's type";
}
}
return success();
}
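// Illustrative example (assumed types): for a mutable global tensor of type
// tensor<2xf32>, the bound input argument must have type
// tensor<!tf.resource<tensor<2xf32>>>; for an immutable one, the argument type
// must be exactly tensor<2xf32>.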
LogicalResult TensorFlowSavedModelDialect::verifyRegionArgAttribute(
@ -137,20 +151,7 @@ LogicalResult TensorFlowSavedModelDialect::verifyRegionArgAttribute(
<< symbol_name << "'";
}
auto arg_type = cast<FuncOp>(op).getArgument(arg_index).getType();
if (global_tensor.is_mutable()) {
if (!IsResourceVarType(arg_type)) {
return op->emitError()
<< "bound inputs for mutable 'tf_saved_model.global_tensor's "
"must be tensors of '!tf.resource'";
}
} else {
if (arg_type != global_tensor.type()) {
return op->emitError() << "bound input for immutable "
"'tf_saved_model.global_tensor' must "
"match the global tensor's type";
}
}
return success();
return VerifyBoundInputArgType(op, arg_type, global_tensor);
}
if (named_attr.first == "tf_saved_model.index_path") {
return VerifyIndexPath(op, named_attr);


@ -68,6 +68,11 @@ class OperandsSameAsResultsTypeOrRef
}
};
// Layout agnostic operations do not depend on the operands' data layout (data
// format); for example, all element-wise operations are layout agnostic.
template <typename ConcreteType>
class LayoutAgnostic : public TraitBase<ConcreteType, LayoutAgnostic> {};
} // namespace TF
} // namespace OpTrait
} // namespace mlir


@ -21,23 +21,35 @@ limitations under the License.
namespace mlir {
namespace TF {
LogicalResult VerifyLayoutSensitiveInterface(Operation* op) {
auto layout_sensitive_interface = cast<LayoutSensitiveInterface>(op);
namespace {
if (!llvm::all_of(
layout_sensitive_interface.GetLayoutDependentArgs(),
[&](int64_t index) { return index < op->getNumOperands(); })) {
template <typename Interface>
LogicalResult VerifyLayoutDependentArgsAndResults(Operation* op,
Interface interface) {
auto valid_operand = [&](int64_t idx) { return idx < op->getNumOperands(); };
if (!llvm::all_of(interface.GetLayoutDependentArgs(), valid_operand)) {
return op->emitOpError("layout dependent argument index is out of bound");
}
if (!llvm::all_of(
layout_sensitive_interface.GetLayoutDependentResults(),
[&](int64_t index) { return index < op->getNumResults(); })) {
auto valid_result = [&](int64_t idx) { return idx < op->getNumResults(); };
if (!llvm::all_of(interface.GetLayoutDependentResults(), valid_result)) {
return op->emitOpError("layout dependent result index is out of bound");
}
return success();
}
} // namespace
LogicalResult VerifyLayoutSensitiveInterface(Operation* op) {
auto layout_sensitive_interface = cast<LayoutSensitiveInterface>(op);
return VerifyLayoutDependentArgsAndResults(op, layout_sensitive_interface);
}
LogicalResult VerifyFoldOperandsTransposeInterface(Operation* op) {
auto fold_operands_transpose = cast<FoldOperandsTransposeInterface>(op);
return VerifyLayoutDependentArgsAndResults(op, fold_operands_transpose);
}
} // namespace TF
} // namespace mlir


@ -29,6 +29,12 @@ namespace TF {
// [0, getNumOperands/getNumResults) range.
LogicalResult VerifyLayoutSensitiveInterface(Operation* op);
// Verifies correctness of ops implementing FoldOperandsTransposeInterface (see
// definition in tf_op_base.td):
// (1) Layout dependent arguments and results indices must be in
// [0, getNumOperands/getNumResults) range.
LogicalResult VerifyFoldOperandsTransposeInterface(Operation* op);
} // namespace TF
} // namespace mlir


@ -3,6 +3,10 @@
// All tests also test for idempotence.
// Test that external functions aren't processed (used to crash).
// CHECK-LABEL: func @unused_external_func
func @unused_external_func()
func @multiple_return(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> (tensor<*xi32>, tensor<*xi32>) {
%graph:2 = tf_executor.graph {
%island:3 = tf_executor.island {
@ -276,3 +280,67 @@ func @empty_island_multiple_data_results(%arg0: tensor<*xf32>, %arg1: tensor<*xi
}
return
}
// The following tests check that certain control dependencies between islands
// and certain tf_executor ops are added correctly.
// CHECK: %[[CONTROL:[^ ,]*]] = tf_executor.island wraps "tf.Print"
// CHECK: tf_executor.NextIteration.Sink [{{.*}}] {{.*}}, %[[CONTROL]]
func @next_iteration_sink_control_input() {
tf_executor.graph {
%source:3 = tf_executor.NextIteration.Source : tensor<*xi32>
%island:2 = tf_executor.island {
%const = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<*xi32>
%print = "tf.Print"(%const) : (tensor<*xi32>) -> (tensor<*xi32>)
tf_executor.yield %const : tensor<*xi32>
}
tf_executor.NextIteration.Sink[%source#1] %island#0 : tensor<*xi32>
tf_executor.fetch %island#0 : tensor<*xi32>
}
return
}
// CHECK: %[[CONTROL:[^ ,]*]] = tf_executor.island wraps "tf.Print"
// CHECK: tf_executor.LoopCond {{.*}}, %[[CONTROL]]
func @loop_cond_control_input() {
tf_executor.graph {
%island:2 = tf_executor.island {
%const = "tf.Const"() {value = dense<1> : tensor<i1>} : () -> tensor<*xi1>
%print = "tf.Print"(%const) : (tensor<*xi1>) -> (tensor<*xi1>)
tf_executor.yield %const : tensor<*xi1>
}
%loop_cond:2 = tf_executor.LoopCond %island#0 : tensor<*xi1>
tf_executor.fetch %loop_cond#0 : tensor<*xi1>
}
return
}
// CHECK: %[[CONTROL:[^ ,]*]] = tf_executor.island wraps "tf.Print"
// CHECK: tf_executor.Enter {{.*}}, %[[CONTROL]]
func @enter_control_input() {
tf_executor.graph {
%island:2 = tf_executor.island {
%const = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<*xi32>
%print = "tf.Print"(%const) : (tensor<*xi32>) -> (tensor<*xi32>)
tf_executor.yield %const : tensor<*xi32>
}
%enter:2 = tf_executor.Enter %island#0 frame "some/frame" : tensor<*xi32>
tf_executor.fetch %enter#0 : tensor<*xi32>
}
return
}
// CHECK: %[[CONTROL:[^ ,]*]] = tf_executor.island wraps "tf.Print"
// CHECK: tf_executor.SwitchN {{.*}}, {{.*}} of {{[0-9]*}} (%[[CONTROL]])
func @switchn_control_input(%arg1: tensor<i32>) {
tf_executor.graph {
%island:2 = tf_executor.island {
%const = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<*xi32>
%print = "tf.Print"(%const) : (tensor<*xi32>) -> (tensor<*xi32>)
tf_executor.yield %const : tensor<*xi32>
}
%switchn:4 = tf_executor.SwitchN %island#0, %arg1 of 3: tensor<*xi32>
tf_executor.fetch %switchn#0 : tensor<*xi32>
}
return
}


@ -383,6 +383,28 @@ func @nonIdentityTranspose(%arg0: tensor<2x3x4x5x6xf32>) -> tensor<2x3x4x6x5xf32
// CHECK: return %1
}
// CHECK-LABEL: @cancellableTranspose
func @cancellableTranspose(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32> {
%0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
%1 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
%2 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
%3 = "tf.Transpose"(%2, %1) : (tensor<1x8x4x4xf32>, tensor<4xi32>) -> tensor<1x4x4x8xf32>
return %3 : tensor<1x4x4x8xf32>
// CHECK: return %arg0
}
// CHECK-LABEL: @nonCancellableTranspose
func @nonCancellableTranspose(%arg0: tensor<1x4x4x8xf32>) -> tensor<4x1x4x8xf32> {
%0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
%1 = "tf.Const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
%2 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
%3 = "tf.Transpose"(%2, %1) : (tensor<1x8x4x4xf32>, tensor<4xi32>) -> tensor<4x1x4x8xf32>
return %3 : tensor<4x1x4x8xf32>
// CHECK: return %3
}
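The two cases above reduce to permutation composition: after transposing with perm p1 and then p2, output axis i reads original axis p1[p2[i]], so the pair folds away exactly when that composition is the identity. A standalone check, not part of the pass:

#include <array>
#include <cstdio>

int main() {
  const std::array<int, 4> p1 = {0, 3, 1, 2};  // first tf.Transpose perm
  const std::array<int, 4> p2 = {0, 2, 3, 1};  // second tf.Transpose perm
  for (int i = 0; i < 4; ++i) {
    std::printf("%d ", p1[p2[i]]);  // prints 0 1 2 3: identity, so cancels
  }
  std::printf("\n");
  return 0;
}

With p2 = {2, 0, 3, 1} from @nonCancellableTranspose the composition is {1, 0, 2, 3}, not the identity, so both transposes must stay.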
// CHECK-LABEL: func @addN
func @addN(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: return %arg0


@ -9,5 +9,7 @@ func @device_test(%arg0: tensor<3x1xf32>) -> (tensor<3x3xf32>) {
%1 = "tf.MatMul"(%arg0, %0) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"], device = "", transpose_a = false, transpose_b = false} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
// CHECK: device = "cpu"
%2 = "tf.Relu"(%1) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"], device = "cpu"} : (tensor<3x3xf32>) -> tensor<3x3xf32>
return %2 : tensor<3x3xf32>
// CHECK: device = "gpu"
%3 = "tf.Relu"(%2) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"]} : (tensor<3x3xf32>) -> tensor<3x3xf32>
return %3 : tensor<3x3xf32>
}


@ -0,0 +1,57 @@
// RUN: tf-opt %s -tf-executor-tpu-v1-island-coarsening | FileCheck %s --dump-input=fail
// Test that islands with a function call are merged if the call is to a function
// that contains ops with the same attribute.
// CHECK-LABEL: func @control_input
func @control_input(%arg0 : tensor<i1>) -> tensor<i32> {
%0:6 = tf_executor.graph {
%1:2 = tf_executor.island wraps "tf.opA"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i1>) -> tensor<i32>
%2:2 = tf_executor.island wraps "tf.While"(%1#0) {name = "A", body = @while_body_with_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
%3:2 = tf_executor.island wraps "tf.While"(%1#0) {name = "B", body = @while_body_with_wrong_cluster_attr, cond = @while_cond_with_wrong_cluster_attr, is_stateless = false, parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
%4:2 = tf_executor.island wraps "tf.While"(%1#0) {name = "C", body = @while_body_without_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
%6:2 = tf_executor.island wraps "tf.While"(%1#0) {name = "D", body = @while_body_without_cluster_attr, cond = @while_cond_without_cluster_attr, is_stateless = false, parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
%5:2 = tf_executor.island wraps "tf.While"(%1#0) {name = "E", body = @while_body_with_cluster_attr, cond = @while_cond_without_cluster_attr, is_stateless = false, parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
// CHECK: "tf.opA"
// CHECK-NOT: island
// CHECK: name = "A"
// CHECK-NOT: island
// CHECK: name = "C"
// CHECK-NOT: island
// CHECK: name = "E"
// CHECK: island {{.*}}name = "B"
// CHECK: island {{.*}}name = "D"
tf_executor.fetch %1#0, %2#0, %3#0, %4#0, %5#0, %6#0 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
}
return %0#0 : tensor<i32>
}
func @while_body_with_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
%0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
func @while_cond_with_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
%0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
}
func @while_body_with_wrong_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
%0 = "some.op"(%arg0) {_tpu_replicate = "wrong_cluster"} : (tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
func @while_cond_with_wrong_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
%0 = "some.op"(%arg0) {_tpu_replicate = "wrong_cluster"} : (tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
}
func @while_body_without_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
%0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
func @while_cond_without_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
%0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
}


@ -0,0 +1,44 @@
// RUN: tf-opt %s -tf-executor-tpu-v1-island-inlining | FileCheck %s --dump-input=fail
// CHECK-NOT: tf.PartitionedCall
// CHECK-NOT: module @_tpu_v1_compat_outlined
module {
func @control_input(%arg0: tensor<i1>) -> tensor<i32> {
%0:4 = tf_executor.graph {
%outputs:4, %control = tf_executor.island wraps "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @_tpu_v1_compat_outlined::@_tpu_v1_compat_outlined_func0} : (tensor<i1>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>)
tf_executor.fetch %outputs#0, %outputs#1, %outputs#2, %outputs#3 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
}
return %0#0 : tensor<i32>
}
module @_tpu_v1_compat_outlined {
func @_tpu_v1_compat_outlined_func0(%arg0: tensor<i1>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
"tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "device", num_replicas = 1 : i64, topology = "topology"} : () -> ()
%0 = "tf.opA"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i1>) -> tensor<i32>
%1 = "tf.While"(%0) {body = @while_body_with_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "A", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
%2 = "tf.While"(%0) {body = @while_body_without_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "C", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
%3 = "tf.While"(%0) {body = @while_body_with_cluster_attr, cond = @while_cond_without_cluster_attr, is_stateless = false, name = "E", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
return %0, %1, %2, %3 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
}
func @while_body_with_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
%0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
func @while_cond_with_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
%0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
}
func @while_body_without_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
%0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
func @while_cond_without_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
%0 = "tf.PartionedCalledOp"(%arg0) {f = @callee_func} : (tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
}
func @callee_func(%arg0: tensor<i32>) -> tensor<i1> {
%0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
}
}
}


@ -0,0 +1,48 @@
// RUN: tf-opt %s -tf-executor-tpu-v1-island-outlining | FileCheck %s --dump-input=fail
// CHECK: func @control_input
// CHECK-NOT: func @
// CHECK-LABEL: module @_tpu_v1_compat_outlined
// CHECK: @_tpu_v1_compat_outlined_func0
// CHECK: func @while_body_with_cluster_attr
// CHECK: func @while_cond_with_cluster_attr
// CHECK: func @while_body_without_cluster_attr
// CHECK: func @while_cond_without_cluster_attr
// CHECK: func @callee_func
module {
func @control_input(%arg0: tensor<i1>) -> tensor<i32> {
%0:4 = tf_executor.graph {
%outputs:4, %control = tf_executor.island {
"tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
%1 = "tf.opA"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i1>) -> tensor<i32>
%2 = "tf.While"(%1) {body = @while_body_with_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "A", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
%3 = "tf.While"(%1) {body = @while_body_without_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "C", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
%4 = "tf.While"(%1) {body = @while_body_with_cluster_attr, cond = @while_cond_without_cluster_attr, is_stateless = false, name = "E", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
tf_executor.yield %1, %2, %3, %4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
}
tf_executor.fetch %outputs#0, %outputs#1, %outputs#2, %outputs#3 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
}
return %0#0 : tensor<i32>
}
func @while_body_with_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
%0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
func @while_cond_with_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
%0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
}
func @while_body_without_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
%0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
func @while_cond_without_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
%0 = "tf.PartionedCalledOp"(%arg0) { f = @callee_func} : (tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
}
func @callee_func(%arg0: tensor<i32>) -> tensor<i1> {
%0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
}
}


@ -1,41 +1,24 @@
// RUN: tf-opt %s -tf-layout-assignment=force-data-format=NCHW -verify-diagnostics | FileCheck %s
// RUN: tf-opt %s -tf-layout-optimization=force-data-format=NCHW -verify-diagnostics | FileCheck %s --dump-input=always
// CHECK-LABEL: func @transposeBiasAdd
func @transposeBiasAdd(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
func @transposeBiasAdd(%arg0: tensor<1x8x4x4xf32>, %arg1: tensor<8xf32>) -> tensor<1x8x4x4xf32> {
// Check that BiasAdd was converted to the forced data format, and that layout
// dependent arguments and results are passed through transpose nodes.
// Convert input: NCHW -> NHWC
%0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x8x4x4xf32>, tensor<4xi64>) -> tensor<1x4x4x8xf32>
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
// CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%[[ARG_TRANSPOSE]], %arg1) {data_format = "NCHW"} {{.*}} tensor<1x8x4x4xf32>
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[BIAS_ADD]], %[[RES_PERM]])
// CHECK: return %[[RES_TRANSPOSE]]
%0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NHWC"} : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
// Compute in NHWC
%2 = "tf.BiasAdd"(%1, %arg1) {data_format = "NHWC"} : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
return %0 : tensor<1x4x4x8xf32>
}
// Convert result back: NHWC -> NCHW
%3 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
%4 = "tf.Transpose"(%2, %3) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
// CHECK-LABEL: func @transposeBiasAddWithDefaultAttr
func @transposeBiasAddWithDefaultAttr(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
// Check that BiasAdd is computed in NCHW format, and that all redundant
// transpose operations are removed from the function.
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
// CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%[[ARG_TRANSPOSE]], %arg1) {data_format = "NCHW"} {{.*}} tensor<1x8x4x4xf32>
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[BIAS_ADD]], %[[RES_PERM]])
// CHECK: return %[[RES_TRANSPOSE]]
%0 = "tf.BiasAdd"(%arg0, %arg1) : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
// CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW"} {{.*}} tensor<1x8x4x4xf32>
// CHECK: return %[[BIAS_ADD]]
return %0 : tensor<1x4x4x8xf32>
}
// CHECK-LABEL: func @transposeBiasWithUnknownShape
func @transposeBiasWithUnknownShape(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<8xf32>) -> tensor<*xf32> {
// CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%[[ARG_TRANSPOSE]], %arg1) {data_format = "NCHW"} {{.*}} tensor<*xf32>
%0 = "tf.BiasAdd"(%arg0, %arg1) : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
return %4 : tensor<1x8x4x4xf32>
}


@ -0,0 +1,75 @@
// RUN: tf-opt %s -tf-layout-assignment=force-data-format=NCHW -verify-diagnostics | FileCheck %s --dump-input=always
// CHECK-LABEL: func @transposeBiasAdd
func @transposeBiasAdd(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
// Check that BiasAdd was converted to the forced data format, and that layout
// dependent arguments and results are passed through transpose nodes.
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
// CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%[[ARG_TRANSPOSE]], %arg1) {data_format = "NCHW"} {{.*}} tensor<1x8x4x4xf32>
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[BIAS_ADD]], %[[RES_PERM]])
// CHECK: return %[[RES_TRANSPOSE]]
%0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NHWC"} : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
return %0 : tensor<1x4x4x8xf32>
}
// CHECK-LABEL: func @transposeBiasAddWithDefaultAttr
func @transposeBiasAddWithDefaultAttr(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
// CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%[[ARG_TRANSPOSE]], %arg1) {data_format = "NCHW"} {{.*}} tensor<1x8x4x4xf32>
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[BIAS_ADD]], %[[RES_PERM]])
// CHECK: return %[[RES_TRANSPOSE]]
%0 = "tf.BiasAdd"(%arg0, %arg1) : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
return %0 : tensor<1x4x4x8xf32>
}
// CHECK-LABEL: func @transposeBiasWithUnknownShape
func @transposeBiasWithUnknownShape(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<8xf32>) -> tensor<*xf32> {
// CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%[[ARG_TRANSPOSE]], %arg1) {data_format = "NCHW"} {{.*}} tensor<*xf32>
%0 = "tf.BiasAdd"(%arg0, %arg1) : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @transposeConv2D
func @transposeConv2D(%input: tensor<1x32x32x3xf32>, %filter: tensor<1x1x3x8xf32>) -> tensor<1x32x32x8xf32> {
// IMPORTANT: Tensor shapes do not match convolution parameters (stride,
// dilations, etc.). This test only verifies that changing the convolution
// data layout will update all the attributes.
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
// CHECK: %[[CONV2D:[0-9]*]] = "tf.Conv2D"(%[[ARG_TRANSPOSE]], %arg1)
// CHECK-SAME: data_format = "NCHW"
// CHECK-SAME: dilations = [1, 4, 2, 3]
// CHECK-SAME: explicit_paddings = [1, 2, 7, 8, 3, 4, 5, 6]
// CHECK-SAME: padding = "EXPLICIT"
// CHECK-SAME: strides = [5, 8, 6, 7]
// CHECK-SAME: (tensor<1x3x32x32xf32>, tensor<1x1x3x8xf32>) -> tensor<1x8x32x32xf32>
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[CONV2D]], %[[RES_PERM]])
// CHECK: return %[[RES_TRANSPOSE]]
%0 = "tf.Conv2D"(%input, %filter)
{
data_format = "NHWC",
dilations = [1, 2, 3, 4],
explicit_paddings = [1, 2, 3, 4, 5, 6, 7, 8],
padding = "EXPLICIT",
strides = [5, 6, 7, 8]
} : (tensor<1x32x32x3xf32>, tensor<1x1x3x8xf32>) -> tensor<1x32x32x8xf32>
return %0 : tensor<1x32x32x8xf32>
}
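The attribute updates checked above all follow one rule: each per-dimension attribute is re-indexed by the data-format permutation ({0, 3, 1, 2} for NHWC to NCHW), with explicit_paddings moving as (low, high) pairs. A standalone sketch of that rule; PermuteAttr is an illustrative name, not the pass's actual API:

#include <array>
#include <cstdint>
#include <cstdio>

// Re-index a per-dimension attribute by a data-format permutation.
std::array<int64_t, 4> PermuteAttr(const std::array<int64_t, 4>& attr,
                                   const std::array<int, 4>& perm) {
  std::array<int64_t, 4> out{};
  for (int i = 0; i < 4; ++i) out[i] = attr[perm[i]];
  return out;
}

int main() {
  const std::array<int, 4> nhwc_to_nchw = {0, 3, 1, 2};
  const auto dilations = PermuteAttr({1, 2, 3, 4}, nhwc_to_nchw);  // 1 4 2 3
  const auto strides = PermuteAttr({5, 6, 7, 8}, nhwc_to_nchw);    // 5 8 6 7
  for (int i = 0; i < 4; ++i)
    std::printf("%lld %lld\n", static_cast<long long>(dilations[i]),
                static_cast<long long>(strides[i]));
  return 0;
}

The same re-indexing explains the MaxPool fold later in this diff: ksize {1, 3, 3, 1} and strides {1, 2, 2, 1} become {1, 1, 3, 3} and {1, 1, 2, 2} under the NHWC-to-NCHW permutation.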


@ -0,0 +1,35 @@
// RUN: tf-opt %s -tf-layout-assignment=force-data-format=NHWC -verify-diagnostics | FileCheck %s --dump-input=always
// CHECK-LABEL: func @transposeConv2D
func @transposeConv2D(%input: tensor<1x3x32x32xf32>, %filter: tensor<1x1x3x8xf32>) -> tensor<1x8x32x32xf32> {
// IMPORTANT: Tensor shapes do not match convolution parameters (stride,
// dilations, etc.). This test only verifies that changing the convolution
// data layout will update all the attributes.
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
// CHECK: %[[CONV2D:[0-9]*]] = "tf.Conv2D"(%[[ARG_TRANSPOSE]], %arg1)
// CHECK-SAME: data_format = "NHWC"
// CHECK-SAME: dilations = [1, 3, 4, 2]
// CHECK-SAME: explicit_paddings = [1, 2, 5, 6, 7, 8, 3, 4]
// CHECK-SAME: padding = "EXPLICIT"
// CHECK-SAME: strides = [5, 7, 8, 6]
// CHECK-SAME: (tensor<1x32x32x3xf32>, tensor<1x1x3x8xf32>) -> tensor<1x32x32x8xf32>
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[CONV2D]], %[[RES_PERM]])
// CHECK: return %[[RES_TRANSPOSE]]
%0 = "tf.Conv2D"(%input, %filter)
{
data_format = "NCHW",
dilations = [1, 2, 3, 4],
explicit_paddings = [1, 2, 3, 4, 5, 6, 7, 8],
padding = "EXPLICIT",
strides = [5, 6, 7, 8]
} : (tensor<1x3x32x32xf32>, tensor<1x1x3x8xf32>) -> tensor<1x8x32x32xf32>
return %0 : tensor<1x8x32x32xf32>
}


@ -0,0 +1,67 @@
// RUN: tf-opt %s -tf-move-transposes=direction=begin -verify-diagnostics | FileCheck %s --dump-input=always
// CHECK-LABEL: func @move_across_single_op
func @move_across_single_op(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
// CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%[[ARG_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
// CHECK: return %[[TANH]]
%0 = "tf.Tanh"(%arg0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
%1 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
%2 = "tf.Transpose"(%0, %1) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
return %2 : tensor<1x8x4x4xf32>
}
// CHECK-LABEL: func @move_across_multiple_ops
func @move_across_multiple_ops(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
// CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%[[ARG_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
// CHECK: %[[RELU:[0-9]*]] = "tf.Relu"(%[[TANH]]) {{.*}} tensor<1x8x4x4xf32>
// CHECK: return %[[RELU]]
%0 = "tf.Tanh"(%arg0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
%1 = "tf.Relu"(%0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
%2 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
%3 = "tf.Transpose"(%1, %2) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
return %3 : tensor<1x8x4x4xf32>
}
// CHECK-LABEL: func @move_across_multi_operand_op
func @move_across_multi_operand_op(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
// CHECK: %[[ARG0_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
// CHECK: %[[ARG1_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg1, %[[ARG_PERM]])
// CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%[[ARG0_TRANSPOSE]], %[[ARG1_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
// CHECK: return %[[ADD]]
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1x4x4x8xf32>, tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
%1 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
%2 = "tf.Transpose"(%0, %1) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
return %2 : tensor<1x8x4x4xf32>
}
// CHECK-LABEL: func @move_with_multiple_uses
func @move_with_multiple_uses(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
// CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%[[ARG_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
// CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%[[TANH]], %[[TANH]]) {{.*}} tensor<1x8x4x4xf32>
// CHECK: return %[[ADD]]
%0 = "tf.Tanh"(%arg0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
%1 = "tf.AddV2"(%0, %0) : (tensor<1x4x4x8xf32>, tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
%2 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
%3 = "tf.Transpose"(%1, %2) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
return %3 : tensor<1x8x4x4xf32>
}


@ -0,0 +1,120 @@
// RUN: tf-opt %s -tf-move-transposes=direction=end -verify-diagnostics | FileCheck %s --dump-input=always
// CHECK-LABEL: func @move_across_single_op
func @move_across_single_op(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
// CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%arg0) {{.*}} tensor<1x4x4x8xf32>
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[TANH]], %[[RES_PERM]]) {{.*}} tensor<1x8x4x4xf32>
// CHECK: return %[[RES_TRANSPOSE]]
%0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
%2 = "tf.Tanh"(%1) : (tensor<1x8x4x4xf32>) -> tensor<1x8x4x4xf32>
return %2 : tensor<1x8x4x4xf32>
}
// CHECK-LABEL: func @move_across_multiple_ops
func @move_across_multiple_ops(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
// CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%arg0) {{.*}} tensor<1x4x4x8xf32>
// CHECK: %[[RELU:[0-9]*]] = "tf.Relu"(%[[TANH]]) {{.*}} tensor<1x4x4x8xf32>
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[RELU]], %[[RES_PERM]])
// CHECK: return %[[RES_TRANSPOSE]]
%0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
%2 = "tf.Tanh"(%1) : (tensor<1x8x4x4xf32>) -> tensor<1x8x4x4xf32>
%3 = "tf.Relu"(%2) : (tensor<1x8x4x4xf32>) -> tensor<1x8x4x4xf32>
return %3 : tensor<1x8x4x4xf32>
}
// CHECK-LABEL: func @move_across_multi_operand_op
func @move_across_multi_operand_op(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
// CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%arg0, %arg1) {{.*}} tensor<1x4x4x8xf32>
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[ADD]], %[[RES_PERM]])
// CHECK: return %[[RES_TRANSPOSE]]
%0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
%2 = "tf.Transpose"(%arg1, %0) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
%3 = "tf.AddV2"(%1, %2) : (tensor<1x8x4x4xf32>, tensor<1x8x4x4xf32>) -> tensor<1x8x4x4xf32>
return %3 : tensor<1x8x4x4xf32>
}
// CHECK-LABEL: func @fold_into_max_pool
func @fold_into_max_pool(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x56x56x64xf32> {
// MaxPool operand transpose must be folded into the op and MaxPool
// must use NCHW data format with updated kernel size and strides.
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
// CHECK: %[[MAX_POOL:[0-9]*]] = "tf.MaxPool"(%arg0) {data_format = "NCHW", ksize = [1, 1, 3, 3], padding = "SAME", strides = [1, 1, 2, 2]} : (tensor<1x64x112x112xf32>) -> tensor<1x64x56x56xf32>
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[MAX_POOL]], %[[RES_PERM]])
// CHECK: return %[[RES_TRANSPOSE]]
// Transpose NCHW -> NHWC
%0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x64x112x112xf32>, tensor<4xi64>) -> tensor<1x112x112x64xf32>
// Compute MaxPool in NHWC format
%2 = "tf.MaxPool"(%1)
{
data_format = "NHWC", ksize = [1, 3, 3, 1],
padding = "SAME", strides = [1, 2, 2, 1]
} : (tensor<1x112x112x64xf32>) -> tensor<1x56x56x64xf32>
return %2 : tensor<1x56x56x64xf32>
}
// CHECK-LABEL: func @fold_into_mean
func @fold_into_mean(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x64xf32> {
// CHECK: %[[RED_IDX:[0-9]*]] = "tf.Const"() {value = dense<[2, 3]> : tensor<2xi64>}
// CHECK: %[[MEAN:[0-9]*]] = "tf.Mean"(%arg0, %[[RED_IDX]])
// CHECK-SAME: (tensor<1x64x112x112xf32>, tensor<2xi64>) -> tensor<1x64xf32>
// CHECK: return %[[MEAN]]
// Transpose NCHW -> NHWC
%0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x64x112x112xf32>, tensor<4xi64>) -> tensor<1x112x112x64xf32>
// Compute Mean over spatial dimensions in NHWC format.
%2 = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
%3 = "tf.Mean"(%1, %2) : (tensor<1x112x112x64xf32>, tensor<2xi64>) -> tensor<1x64xf32>
return %3 : tensor<1x64xf32>
}
// CHECK-LABEL: func @fold_into_fused_batch_norm
func @fold_into_fused_batch_norm(%arg0: tensor<1x64x112x112xf32>, %arg1: tensor<64xf32>) -> tensor<1x112x112x64xf32> {
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
// CHECK: "tf.FusedBatchNormV3"(%arg0, {{.*}} {data_format = "NCHW"
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%y, %[[RES_PERM]])
// CHECK: return %[[RES_TRANSPOSE]]
// Transpose NCHW -> NHWC
%0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x64x112x112xf32>, tensor<4xi64>) -> tensor<1x112x112x64xf32>
// Compute FusedBatchNormV3 in NHWC format
%2, %batch_mean, %batch_var, %reserve_1, %reserve_2, %reserve_3
= "tf.FusedBatchNormV3"(%1, %arg1, %arg1, %arg1, %arg1)
{
data_format = "NHWC",
epsilon = 1.001 : f32,
exponential_avg_factor = 1.0 : f32,
is_training = false
}
: (tensor<1x112x112x64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
-> (tensor<1x112x112x64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
return %2#0 : tensor<1x112x112x64xf32>
}


@ -0,0 +1,194 @@
// RUN: tf-opt %s -tf-parallel-execute-to-islands | FileCheck %s --dump-input=fail
// CHECK-LABEL: func @check_regions_to_islands
func @check_regions_to_islands() {
tf_executor.graph {
tf_executor.island() {
"tf_device.parallel_execute"() ({
tf_device.return
},
{
tf_device.return
}) {} : () -> ()
tf_executor.yield
}
tf_executor.fetch
}
return
}
// CHECK: %[[ISLAND_INPUT_CTL:[a-z_0-9]*]] = tf_executor.island {
// CHECK-NEXT: tf_executor.yield
// CHECK: %[[ISLAND_1_CTL:[a-z_0-9]*]] = tf_executor.island(%[[ISLAND_INPUT_CTL]]) {
// CHECK: tf_executor.yield
// CHECK: %[[ISLAND_2_CTL:[a-z_0-9]*]] = tf_executor.island(%[[ISLAND_INPUT_CTL]]) {
// CHECK: tf_executor.yield
// CHECK: %{{.*}} = tf_executor.island(%[[ISLAND_1_CTL]], %[[ISLAND_2_CTL]]) {
// CHECK-NEXT: tf_executor.yield
// CHECK-LABEL: func @check_regions_to_islands_with_inputs
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
func @check_regions_to_islands_with_inputs(%arg0 : tensor<i1>) {
tf_executor.graph {
%1:2 = tf_executor.island {
%2 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
tf_executor.yield %2 : tensor<i1>
}
tf_executor.island() {
"tf_device.parallel_execute"() ({
%3 = "tf.opB"(%1#0) : (tensor<i1>) -> tensor<i1>
tf_device.return %3 : tensor<i1>
},
{
%5 = "tf.opC"(%1#0) : (tensor<i1>) -> tensor<i32>
tf_device.return %5 : tensor<i32>
}) {} : () -> (tensor<i1>, tensor<i32>)
tf_executor.yield
}
tf_executor.fetch
}
return
}
// CHECK: %[[INPUT_A:[a-z_0-9]*]], %{{.*}} = tf_executor.island {
// CHECK-NEXT: %[[OP_A_OUTPUT:[a-z_0-9]*]] = "tf.opA"(%[[ARG_0]]) : (tensor<i1>) -> tensor<i1>
// CHECK-NEXT: tf_executor.yield %[[OP_A_OUTPUT]] : tensor<i1>
// CHECK: %[[INPUT_0:[a-z_0-9]*]], %[[INPUT_CONTROL:[a-z_0-9]*]] = tf_executor.island {
// CHECK-NEXT: tf_executor.yield %[[INPUT_A]] : tensor<i1>
// CHECK: %[[ISLAND_1_OUTPUT:[a-z_0-9]*]], %[[ISLAND_1_CTL:[a-z_0-9]*]] = tf_executor.island {
// CHECK-NEXT: %[[OP_B_OUTPUT:[a-z_0-9]*]] = "tf.opB"(%[[INPUT_0]]) : (tensor<i1>) -> tensor<i1>
// CHECK: tf_executor.yield %[[OP_B_OUTPUT]] : tensor<i1>
// CHECK: %[[ISLAND_2_OUTPUT:[a-z_0-9]*]], %[[ISLAND_2_CTL:[a-z_0-9]*]] = tf_executor.island {
// CHECK-NEXT: %[[OP_C_OUTPUT:[a-z_0-9]*]] = "tf.opC"(%outputs_0) : (tensor<i1>) -> tensor<i32>
// CHECK: tf_executor.yield %[[OP_C_OUTPUT]] : tensor<i32>
// CHECK: %{{.*}} = tf_executor.island(%[[ISLAND_1_CTL]], %[[ISLAND_2_CTL]]) {
// CHECK-NEXT: tf_executor.yield
// CHECK-LABEL: func @check_input_sink_island_forwards_control_inputs
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
func @check_input_sink_island_forwards_control_inputs(%arg0 : tensor<i1>) {
tf_executor.graph {
%1:2 = tf_executor.island {
%2 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
tf_executor.yield %2 : tensor<i1>
}
%7 = tf_executor.ControlTrigger {}
%8 = tf_executor.ControlTrigger {}
tf_executor.island(%7, %8) {
"tf_device.parallel_execute"() ({
%3 = "tf.opB"(%1#0) : (tensor<i1>) -> tensor<i1>
tf_device.return %3 : tensor<i1>
},
{
%5 = "tf.opC"() : () -> tensor<i32>
tf_device.return %5 : tensor<i32>
}) {} : () -> (tensor<i1>, tensor<i32>)
tf_executor.yield
}
tf_executor.fetch
}
return
}
// CHECK: %[[INPUT_A:[a-z_0-9]*]], %{{.*}} = tf_executor.island {
// CHECK-NEXT: %[[OP_A_OUTPUT:[a-z_0-9]*]] = "tf.opA"(%[[ARG_0]]) : (tensor<i1>) -> tensor<i1>
// CHECK-NEXT: tf_executor.yield %[[OP_A_OUTPUT]] : tensor<i1>
// CHECK: %[[CT_0:[0-9]*]] = tf_executor.ControlTrigger
// CHECK: %[[CT_1:[0-9]*]] = tf_executor.ControlTrigger
// CHECK: %[[INPUT_0:[a-z_0-9]*]], %[[INPUT_CONTROL:[a-z_0-9]*]] = tf_executor.island(%[[CT_0]], %[[CT_1]]) {
// CHECK-NEXT: tf_executor.yield %[[INPUT_A]] : tensor<i1>
// CHECK: %[[ISLAND_1_OUTPUT:[a-z_0-9]*]], %[[ISLAND_1_CTL:[a-z_0-9]*]] = tf_executor.island {
// CHECK-NEXT: %[[OP_B_OUTPUT:[a-z_0-9]*]] = "tf.opB"(%[[INPUT_0]]) : (tensor<i1>) -> tensor<i1>
// CHECK: tf_executor.yield %[[OP_B_OUTPUT]] : tensor<i1>
// CHECK: %[[ISLAND_2_OUTPUT:[a-z_0-9]*]], %[[ISLAND_2_CTL:[a-z_0-9]*]] = tf_executor.island(%[[INPUT_CONTROL]]) {
// CHECK-NEXT: %[[OP_C_OUTPUT:[a-z_0-9]*]] = "tf.opC"() : () -> tensor<i32>
// CHECK: tf_executor.yield %[[OP_C_OUTPUT]] : tensor<i32>
// CHECK: %{{.*}} = tf_executor.island(%[[ISLAND_1_CTL]], %[[ISLAND_2_CTL]]) {
// CHECK-NEXT: tf_executor.yield
// CHECK-LABEL: func @check_control_dep_added_when_region_does_not_have_inputs
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
func @check_control_dep_added_when_region_does_not_have_inputs(%arg0 : tensor<i1>) {
tf_executor.graph {
%1:2 = tf_executor.island {
%2 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
tf_executor.yield %2 : tensor<i1>
}
%7:3 = tf_executor.island() {
%8:2 = "tf_device.parallel_execute"() (
{
%3 = "tf.opB"() : () -> tensor<i1>
tf_device.return %3 : tensor<i1>
},
{
%5 = "tf.opC"(%1#0) : (tensor<i1>) -> tensor<i32>
tf_device.return %5 : tensor<i32>
}
) {} : () -> (tensor<i1>, tensor<i32>)
tf_executor.yield %8#0, %8#1 : tensor<i1>, tensor<i32>
}
tf_executor.island {
"tf.opD"(%7#0, %7#1) : (tensor<i1>, tensor<i32>) -> ()
tf_executor.yield
}
tf_executor.fetch
}
return
}
// CHECK: %[[INPUT_A:[a-z_0-9]*]], %{{.*}} = tf_executor.island {
// CHECK-NEXT: %[[OP_A_OUTPUT:[a-z_0-9]*]] = "tf.opA"(%[[ARG_0]]) : (tensor<i1>) -> tensor<i1>
// CHECK-NEXT: tf_executor.yield %[[OP_A_OUTPUT]] : tensor<i1>
// CHECK: %[[INPUT_0:[a-z_0-9]*]], %[[INPUT_CTL:[a-z_0-9]*]] = tf_executor.island {
// CHECK-NEXT: tf_executor.yield %[[INPUT_A]] : tensor<i1>
// CHECK: %[[ISLAND_1_OUTPUT:[a-z_0-9]*]], %{{.*}} = tf_executor.island(%[[INPUT_CTL]]) {
// CHECK-NEXT: %[[OP_B_OUTPUT:[a-z_0-9]*]] = "tf.opB"() : () -> tensor<i1>
// CHECK: tf_executor.yield %[[OP_B_OUTPUT]] : tensor<i1>
// CHECK: %[[ISLAND_2_OUTPUT:[a-z_0-9]*]], %{{.*}} = tf_executor.island {
// CHECK-NEXT: %[[OP_C_OUTPUT:[a-z_0-9]*]] = "tf.opC"(%outputs_0) : (tensor<i1>) -> tensor<i32>
// CHECK: tf_executor.yield %[[OP_C_OUTPUT]] : tensor<i32>
// CHECK: %{{.*}} = tf_executor.island {
// CHECK-NEXT: tf_executor.yield %[[ISLAND_1_OUTPUT]], %[[ISLAND_2_OUTPUT]]
// CHECK-LABEL: func @check_output_barrier_correctly_forwards_outputs
func @check_output_barrier_correctly_forwards_outputs(%arg0 : tensor<i1>) -> tensor<i1> {
%0 = tf_executor.graph {
%1:2 = tf_executor.island {
%2 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
tf_executor.yield %2 : tensor<i1>
}
%8:3 = tf_executor.island() {
%7:2 = "tf_device.parallel_execute"() ({
%3 = "tf.opB"() : () -> tensor<i1>
tf_device.return %3 : tensor<i1>
},
{
%5 = "tf.opC"(%1#0) : (tensor<i1>) -> tensor<i32>
tf_device.return %5 : tensor<i32>
}) {} : () -> (tensor<i1>, tensor<i32>)
tf_executor.yield %7#0, %7#1 : tensor<i1>, tensor<i32>
}
tf_executor.fetch %8#0 : tensor<i1>
}
return %0 : tensor<i1>
}
// CHECK: %[[INPUT_A:[a-z_0-9]*]], %{{.*}} = tf_executor.island {
// CHECK-NEXT: %[[OP_A_OUTPUT:[a-z_0-9]*]] = "tf.opA"(%[[ARG_0]]) : (tensor<i1>) -> tensor<i1>
// CHECK-NEXT: tf_executor.yield %[[OP_A_OUTPUT]] : tensor<i1>
// CHECK: %[[INPUT_0:[a-z_0-9]*]], %[[INPUT_CTL:[a-z_0-9]*]] = tf_executor.island {
// CHECK-NEXT: tf_executor.yield %[[INPUT_A]] : tensor<i1>
// CHECK: %[[ISLAND_1_OUTPUT:[a-z_0-9]*]], %{{.*}} = tf_executor.island(%[[INPUT_CTL]]) {
// CHECK-NEXT: %[[OP_B_OUTPUT:[a-z_0-9]*]] = "tf.opB"() : () -> tensor<i1>
// CHECK: tf_executor.yield %[[OP_B_OUTPUT]] : tensor<i1>
// CHECK: %[[ISLAND_2_OUTPUT:[a-z_0-9]*]], %{{.*}} = tf_executor.island {
// CHECK-NEXT: %[[OP_C_OUTPUT:[a-z_0-9]*]] = "tf.opC"(%[[INPUT_0]]) : (tensor<i1>) -> tensor<i32>
// CHECK: tf_executor.yield %[[OP_C_OUTPUT]] : tensor<i32>
// CHECK: %[[OUTPUT_SINK_OUTPUT:[a-z_0-9]*]]:2, %[[OUTPUT_SINK_CTL:[a-z_0-9]*]] = tf_executor.island {
// CHECK-NEXT: tf_executor.yield %[[ISLAND_1_OUTPUT]], %[[ISLAND_2_OUTPUT]] : tensor<i1>, tensor<i32>


@ -542,3 +542,116 @@ func @if_else(%arg0: tensor<*x!tf.resource<tensor<4xf32>>>, %arg1: tensor<*x!tf.
-> (tensor<*x!tf.resource<tensor<4xf32>>>) {
return %arg1 : tensor<*x!tf.resource<tensor<4xf32>>>
}
// -----
// Tests that the pass lifts resources on two partitioned call ops sharing the
// same callee. The lifting should clone the callee then modify the clone.
// CHECK-LABEL: @launch_with_partitioned_call
func @launch_with_partitioned_call() -> tensor<f32> {
// CHECK: %[[VH:.*]] = "tf.VarHandleOp"()
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
// CHECK: %[[CONST:.*]] = "tf.Const"()
%1 = "tf.Const"() {value = dense<10.0> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VH]])
// CHECK: %[[LAUNCH:.*]] = "tf_device.launch"()
%2 = "tf_device.launch"() ( {
// CHECK: %[[PC0:.*]] = "tf.PartitionedCall"(%[[CONST]], %[[READ]], %[[CONST]])
// CHECK-SAME: f = @callee_resource_lifted
%3 = "tf.PartitionedCall"(%1, %0, %1) {f = @callee, config = "", config_proto = "", executor_type = ""}
: (tensor<f32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<f32>
// CHECK: %[[PC1:.*]] = "tf.PartitionedCall"(%[[CONST]], %[[READ]], %[[CONST]])
// CHECK-SAME: f = @callee_resource_lifted
%4 = "tf.PartitionedCall"(%1, %0, %1) {f = @callee, config = "", config_proto = "", executor_type = ""}
: (tensor<f32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<f32>
// CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[PC0]], %[[PC1]])
%5 = "tf.AddV2"(%3, %4) : (tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: tf_device.return %[[ADD]] : tensor<f32>
tf_device.return %5 : tensor<f32>
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<f32>
return %2 : tensor<f32>
}
// CHECK: @callee(%[[OA0:.*]]: tensor<f32>, %[[OA1:.*]]: tensor<*x!tf.resource<tensor<f32>>>, %[[OA2:.*]]: tensor<f32>) -> tensor<f32>
func @callee(%arg0: tensor<f32>, %arg1: tensor<*x!tf.resource<tensor<f32>>>, %arg2: tensor<f32>) -> tensor<f32> {
// CHECK: "tf.ReadVariableOp"(%[[OA1]])
%0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
%1 = "tf.AddV2"(%0, %arg0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
%2 = "tf.AddV2"(%1, %arg2) : (tensor<f32>, tensor<f32>) -> tensor<f32>
return %2 : tensor<f32>
}
// CHECK: func @callee_resource_lifted(%[[A0:.*]]: tensor<f32>, %[[A1:.*]]: tensor<f32>, %[[A2:.*]]: tensor<f32>) -> tensor<f32>
// CHECK-NEXT: %[[ADD0:.*]] = "tf.AddV2"(%[[A1]], %[[A0]])
// CHECK-NEXT: %[[ADD1:.*]] = "tf.AddV2"(%[[ADD0]], %[[A2]])
// CHECK-NEXT: return %[[ADD1]]
// -----
// Tests that the pass lifts resources on two stateful partitioned call ops
// sharing the same callee. The lifting should clone the callee then modify the
// clone.
// CHECK-LABEL: @launch_with_stateful_partitioned_call
func @launch_with_stateful_partitioned_call() -> () {
// CHECK: %[[VH0:.*]] = "tf.VarHandleOp"()
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
// CHECK: %[[VH1:.*]] = "tf.VarHandleOp"()
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource<tensor<f32>>>
// CHECK: %[[CONST:.*]] = "tf.Const"()
%2 = "tf.Const"() {value = dense<10.0> : tensor<f32>} : () -> tensor<f32>
// CHECK-DAG: %[[READ0:.*]] = "tf.ReadVariableOp"(%[[VH0]])
// CHECK-DAG: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[VH1]])
// CHECK: %[[LAUNCH:.*]] = "tf_device.launch"()
"tf_device.launch"() ( {
// CHECK: %[[PC0:.*]] = "tf.StatefulPartitionedCall"(%[[READ0]], %[[READ1]], %[[CONST]])
// CHECK-SAME: f = @callee_resource_lifted
%3 = "tf.StatefulPartitionedCall"(%0, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""}
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
// CHECK: %[[PC1:.*]] = "tf.StatefulPartitionedCall"(%[[PC0]], %[[READ1]], %[[CONST]])
// CHECK-SAME: f = @callee_resource_lifted
%4 = "tf.StatefulPartitionedCall"(%3, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""}
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
// CHECK: tf_device.return %[[PC1]] : tensor<f32>
tf_device.return
// CHECK: {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<f32>
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
// CHECK: "tf.AssignVariableOp"(%[[VH0]], %[[LAUNCH]])
return
}
// CHECK: @callee(%[[OA0:.*]]: tensor<*x!tf.resource<tensor<f32>>>, %[[OA1:.*]]: tensor<*x!tf.resource<tensor<f32>>>, %[[OA2:.*]]: tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
func @callee(%arg0: tensor<*x!tf.resource<tensor<f32>>>, %arg1: tensor<*x!tf.resource<tensor<f32>>>, %arg2: tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>> {
// CHECK: "tf.ReadVariableOp"(%[[OA1]])
%0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
%1 = "tf.AddV2"(%0, %arg2) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"tf.AssignVariableOp"(%arg0, %1) {dtype = i32} : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
return %arg0 : tensor<*x!tf.resource<tensor<f32>>>
}
// CHECK: func @callee_resource_lifted(%[[A0:.*]]: tensor<f32>, %[[A1:.*]]: tensor<f32>, %[[A2:.*]]: tensor<f32>) -> tensor<f32>
// CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[A1]], %[[A2]])
// CHECK-NEXT: return %[[ADD]]
// -----
// Tests that the pass reports error on called function that has resource output
// which doesn't alias an input.
func @launch_with_stateful_partitioned_call() -> () {
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource<tensor<f32>>>
%2 = "tf.Const"() {value = dense<10.0> : tensor<f32>} : () -> tensor<f32>
"tf_device.launch"() ( {
%3 = "tf.StatefulPartitionedCall"(%0, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""}
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
%4 = "tf.StatefulPartitionedCall"(%3, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""}
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
tf_device.return
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
return
}
// expected-error @+1 {{Unsupported function call: resource return value does not alias an input.}}
func @callee(%arg0: tensor<*x!tf.resource<tensor<f32>>>, %arg1: tensor<*x!tf.resource<tensor<f32>>>, %arg2: tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>> {
%0 = "tf._Unknown_"() : () -> tensor<*x!tf.resource<tensor<f32>>>
return %0 : tensor<*x!tf.resource<tensor<f32>>>
}
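The expected error implies a simple structural check: a resource-typed return value is liftable only if it aliases a function input, i.e. it is a block argument of the entry block. A hedged sketch of such a check; the function name and the IsResourceType predicate are assumptions:

static mlir::LogicalResult CheckResourceReturnsAliasInputs(mlir::FuncOp func) {
  auto return_op = llvm::cast<mlir::ReturnOp>(func.front().getTerminator());
  for (mlir::Value operand : return_op.getOperands()) {
    if (!IsResourceType(operand.getType())) continue;  // assumed predicate
    // Anything that is not a function argument (e.g. the tf._Unknown_ result
    // above) cannot be rewritten into value form.
    if (!operand.dyn_cast<mlir::BlockArgument>())
      return func.emitError(
          "Unsupported function call: resource return value does not alias "
          "an input.");
  }
  return mlir::success();
}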


@ -45,6 +45,17 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
return %1 : tensor<*xf32>
}
// CHECK-LABEL: func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<?xf32>
func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
br ^bb1
^bb1:
// CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK: return %[[IDENTITY]] : tensor<?xf32>
%ret = "tf.Identity"(%arg0) : (tensor<?xf32>) -> tensor<*xf32>
return %ret : tensor<*xf32>
}
// Tests the case where an inference opportunity relies on folding.
// CHECK-LABEL: func @simple_folding


@ -46,7 +46,7 @@ class TestModule(tf.Module):
# CHECK: "tf_saved_model.global_tensor"() {sym_name = "[[CONST:[a-zA-Z_0-9]+]]", tf_saved_model.exported_names = [], type = tensor<f32>, value = dense<4.300000e+01> : tensor<f32>} : () -> ()
# CHECK: func {{@[a-zA-Z_0-9]+}}(
# CHECK-SAME: %arg0: tensor<f32> {tf_saved_model.index_path = [0]},
# CHECK-SAME: %arg1: tensor<*x!tf.resource> {tf_saved_model.bound_input = @[[VAR]]},
# CHECK-SAME: %arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @[[VAR]]},
# CHECK-SAME: %arg2: tensor<f32> {tf_saved_model.bound_input = @[[CONST]]}) -> (
# CHECK-SAME: tensor<f32> {tf_saved_model.index_path = []})
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["some_function"]


@ -46,7 +46,7 @@ class TestModule(tf.Module):
#
# CHECK: func {{@[a-zA-Z_0-9]+}}(
# CHECK-SAME: %arg0: tensor<f32> {tf_saved_model.index_path = [0]},
# CHECK-SAME: %arg1: tensor<*x!tf.resource> {tf_saved_model.bound_input = {{@[a-zA-Z_0-9]+}}}
# CHECK-SAME: %arg1: tensor<!tf.resource<{{.*}}>> {tf_saved_model.bound_input = {{@[a-zA-Z_0-9]+}}}
# CHECK-SAME: ) -> (
# CHECK-SAME: tensor<f32> {tf_saved_model.index_path = [0]},
# CHECK-SAME: tensor<f32> {tf_saved_model.index_path = [1]})
@ -55,7 +55,7 @@ class TestModule(tf.Module):
#
# CHECK: func {{@[a-zA-Z_0-9]+}}(
# CHECK-SAME: %arg0: tensor<f32> {tf_saved_model.index_path = [0]},
# CHECK-SAME: %arg1: tensor<*x!tf.resource> {tf_saved_model.bound_input = {{@[a-zA-Z_0-9]+}}}
# CHECK-SAME: %arg1: tensor<!tf.resource<{{.*}}>> {tf_saved_model.bound_input = {{@[a-zA-Z_0-9]+}}}
# CHECK-SAME: ) -> (
# CHECK-SAME: tensor<f32> {tf_saved_model.index_path = [0]},
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [1]})


@ -25,8 +25,8 @@ module attributes {tf_saved_model.semantics} {
// CHECK: tf_saved_model.global_tensor
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<1.0> : tensor<f32> } : () -> ()
// CHECK: func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v})
func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v})
// CHECK: func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
attributes {tf_saved_model.exported_names = ["f"]} {
// CHECK-NOT: tf.Const
return


@ -26,7 +26,7 @@ module attributes {tf_saved_model.semantics} {
func @__concrete_function_run_computation(
%arg0: tensor<f32> {tf_saved_model.index_path = [0, "foo"]},
%arg1: tensor<1x64xf32> {tf_saved_model.bound_input = @some_constant},
%arg2: tensor<*x!tf.resource> {tf_saved_model.bound_input = @some_variable}
%arg2: tensor<!tf.resource<tensor<?x64xf32>>> {tf_saved_model.bound_input = @some_variable}
) -> (
tensor<f32> {tf_saved_model.index_path = [0, "bar"]}
) attributes { tf_saved_model.exported_names = ["some_func"] }


@ -219,8 +219,8 @@ module attributes {tf_saved_model.semantics} {
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.0> : tensor<f32> } : () -> ()
// expected-error@+1 {{duplicate 'tf_saved_model.bound_input' binding}}
func @f(
%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v},
%arg1: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}
%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v},
%arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}
) attributes {tf_saved_model.exported_names = ["f"]} {
return
}
@ -232,9 +232,9 @@ module attributes {tf_saved_model.semantics} {
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<?xf32>, value = dense<1.> : tensor<1xf32> } : () -> ()
// expected-error@+1 {{can only apply 'tf_saved_model' argument attributes to exported functions}}
func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v})
func @f(%arg0: tensor<!tf.resource<tensor<?xf32>>> {tf_saved_model.bound_input = @v})
-> (tensor<?xf32> {tf_saved_model.index_path = []}) {
%0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>) -> tensor<?xf32>
%0 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<?xf32>>>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
}
@ -244,7 +244,7 @@ module attributes {tf_saved_model.semantics} {
module attributes {tf_saved_model.semantics} {
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<?xf32>, value = dense<1.> : tensor<1xf32> } : () -> ()
// expected-error@+1 {{bound inputs for mutable 'tf_saved_model.global_tensor's must be tensors of '!tf.resource'}}
// expected-error@+1 {{mutable bound input with type 'tensor<f32>' expected to have type 'tensor<!tf.resource<tensor<?xf32>>>'}}
func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @v})
attributes {tf_saved_model.exported_names = ["f"]} {
return
@ -257,7 +257,7 @@ module attributes {tf_saved_model.semantics} {
"tf_saved_model.global_tensor"() { sym_name = "v", type = tensor<1xf32>, value = dense<1.> : tensor<1xf32> } : () -> ()
// expected-error@+1 {{bound input for immutable 'tf_saved_model.global_tensor' must match the global tensor's type}}
func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v})
func @f(%arg0: tensor<!tf.resource<tensor<1xf32>>> {tf_saved_model.bound_input = @v})
attributes {tf_saved_model.exported_names = ["f"]} {
return
}


@ -14,10 +14,10 @@ module attributes {tf_saved_model.semantics} {
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
// CHECK: func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @v})
func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
attributes {tf_saved_model.exported_names = ["f"]} {
// CHECK-NOT: tf.ReadVariableOp
%val = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>) -> tensor<f32>
%val = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
// CHECK: return %arg0
return %val : tensor<f32>
}
@ -35,12 +35,12 @@ module attributes {tf_saved_model.semantics} {
// CHECK-SAME: } : () -> ()
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
// CHECK: func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v})
func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v})
// CHECK: func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
attributes {tf_saved_model.exported_names = ["f"]} {
%c0 = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
// CHECK: tf.AssignVariableOp
"tf.AssignVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor<f32>) -> ()
"tf.AssignVariableOp"(%arg0, %c0) : (tensor<!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
return
}
@ -57,10 +57,10 @@ module attributes {tf_saved_model.semantics} {
// CHECK-SAME: } : () -> ()
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", tf_saved_model.exported_names = ["v"], type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
// CHECK: func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v})
func @f(%arg0: tensor<*x!tf.resource> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
// CHECK: func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
attributes {tf_saved_model.exported_names = ["f"]} {
%val = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>) -> tensor<f32>
%val = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
return %val : tensor<f32>
}

Some files were not shown because too many files have changed in this diff.