Merge branch 'master' into add_xnnpack_to_label_image_again

Commit 9d022cb885 by Koan-Sin Tan, 2020-04-01 09:16:25 +08:00
973 changed files with 26545 additions and 11286 deletions

View File

@ -19,10 +19,10 @@
# Compiler options:
# cuda_clang: Use clang when building CUDA code.
# c++17: Build with C++17 options
# C++1z: Build with C++17 options
# c++1z: Build with C++17 options
# avx_linux: Build with avx instruction set on linux.
# avx2_linux: Build with avx2 instruction set on linux.
# arch_native_linux: Build with instruction sets available to the host machine on linux
# native_arch_linux: Build with instruction sets available to the host machine on linux
# avx_win: Build with avx instruction set on windows
# avx2_win: Build with avx2 instruction set on windows
#

View File

@ -88,6 +88,9 @@ TensorFlow coding style.
submitting PRs to fix one typo, one warning, etc. We recommend fixing the
same issue at the file level at least (e.g.: fix all typos in a file, fix
all compiler warnings in a file, etc.)
* Tests should follow the
[testing best practices](https://www.tensorflow.org/community/contribute/tests)
guide.
#### License

View File

@ -135,6 +135,7 @@ Build Type | Status
* [TensorFlow Examples](https://github.com/tensorflow/examples)
* [TensorFlow in Practice from Coursera](https://www.coursera.org/specializations/tensorflow-in-practice)
* [TensorFlow: Data and Deployment from Coursera](https://www.coursera.org/specializations/tensorflow-data-and-deployment)
* [Getting Started with TensorFlow 2 from Coursera](https://www.coursera.org/learn/getting-started-with-tensor-flow2)
* [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187)
* [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190)
* [TensorFlow Blog](https://blog.tensorflow.org)

View File

@ -41,6 +41,11 @@ import sys as _sys
from tensorflow.python.tools import module_util as _module_util
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
# Make sure code inside the TensorFlow codebase can use tf2.enabled() at import.
_os.environ['TF2_BEHAVIOR'] = '1'
from tensorflow.python import tf2 as _tf2
_tf2.enable()
# API IMPORTS PLACEHOLDER
# WRAPPER_PLACEHOLDER

View File

@ -683,7 +683,11 @@ TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type,
tensorflow::Tensor tensor(dtype, tensorflow::TensorShape({}));
std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len);
return TFE_TensorHandle::CreateLocalHandle(tensor, status);
status->status = tensorflow::Status::OK();
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(
tensorflow::TensorHandle::CreateLocalHandle(tensor))};
}
namespace {
@ -703,7 +707,8 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
} while (0);
// New server created for new server_def. Unused if updating server_def.
tensorflow::EagerContext* context = ctx->context;
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::GrpcServer* grpc_server =
dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
if (grpc_server == nullptr) {
@ -822,8 +827,7 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
for (int i = 0; i < num_inputs; ++i) {
node_def.add_input("dummy_input");
}
tensorflow::down_cast<tensorflow::OperationInterface*>(
tfe_op->operation.get())
OperationFromInterface(tfe_op->operation)
->Attrs()
.FillAttrValueMap(node_def.mutable_attr());
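
Throughout this change, explicit tensorflow::down_cast calls on the interface wrappers are replaced with helpers such as ContextFromInterface, TensorHandleFromInterface, and OperationFromInterface, whose definitions are not part of the hunks shown here. As a rough, hypothetical sketch of what these unwrapping helpers might look like (the real signatures, accessor names, and header locations in the tree may differ):

#include <memory>

#include "tensorflow/c/eager/context_interface.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/platform/casts.h"

namespace tensorflow {

// Assumed helpers: recover the concrete eager objects from the abstract
// interface wrappers stored in TFE_Context / TFE_TensorHandle / TFE_Op.
inline EagerContext* ContextFromInterface(
    const std::unique_ptr<AbstractContextInterface>& context) {
  return down_cast<ContextInterface*>(context.get())->Context();
}

inline TensorHandle* TensorHandleFromInterface(
    const std::unique_ptr<AbstractTensorHandleInterface>& handle) {
  return down_cast<TensorHandleInterface*>(handle.get())->Handle();
}

inline EagerOperation* OperationFromInterface(
    const std::unique_ptr<AbstractOperationInterface>& operation) {
  return down_cast<OperationInterface*>(operation.get())->Operation();
}

}  // namespace tensorflow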

View File

@ -28,6 +28,9 @@ tf_cuda_library(
"c_api_debug.cc",
"c_api_experimental.h",
"c_api_internal.h",
"c_api_unified_experimental.h",
"context_interface.cc",
"context_interface.h",
"operation_interface.cc",
"operation_interface.h",
"tensor_handle_interface.h",
@ -62,6 +65,7 @@ tf_cuda_library(
"//tensorflow/core/platform:errors",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/types:variant",
],
}) + select({
"//tensorflow:with_xla_support": [
@ -95,6 +99,8 @@ filegroup(
srcs = [
"c_api_experimental.h",
"c_api_internal.h",
"c_api_unified_experimental.h",
"context_interface.h",
"dlpack.h",
"operation_interface.h",
"tensor_handle_interface.h",
@ -109,6 +115,8 @@ tf_cuda_library(
name = "c_api_internal",
srcs = [
"c_api_experimental.h",
"c_api_unified_experimental.h",
"context_interface.h",
"operation_interface.h",
"tensor_handle_interface.h",
],
@ -206,7 +214,6 @@ tf_cuda_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/platform:casts",
"@com_google_absl//absl/strings",
],
)
@ -215,8 +222,12 @@ tf_cuda_library(
name = "c_api_experimental",
srcs = [
"c_api_experimental.cc",
"c_api_unified_experimental.cc",
],
hdrs = [
"c_api_experimental.h",
"c_api_unified_experimental.h",
],
hdrs = ["c_api_experimental.h"],
copts = tf_copts() + tfe_xla_copts(),
visibility = ["//visibility:public"],
deps = select({
@ -242,6 +253,7 @@ tf_cuda_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/types:variant",
],
}) + select({
"//tensorflow:with_xla_support": [
@ -293,6 +305,30 @@ tf_cuda_cc_test(
],
)
tf_cuda_cc_test(
name = "c_api_unified_experimental_test",
size = "small",
srcs = [
"c_api_unified_experimental_test.cc",
],
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + ["nomac"],
deps = [
":c_api",
":c_api_experimental",
":c_api_test_util",
"//tensorflow/c:c_test_util",
"//tensorflow/cc/profiler",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@com_google_absl//absl/strings",
],
)
tf_cc_test(
name = "custom_device_test",
size = "small",

View File

@ -305,7 +305,9 @@ tensorflow::Status CreateRemoteContexts(
server_def.default_session_config());
std::vector<bool> filtered_device_mask;
ctx->context->FilterDevicesForRemoteWorkers(
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->FilterDevicesForRemoteWorkers(
remote_worker, base_request.cluster_device_attributes(),
&filtered_device_mask);
DCHECK_EQ(filtered_device_mask.size(),
@ -388,7 +390,9 @@ tensorflow::Status UpdateRemoteContexts(
}
std::vector<bool> filtered_device_mask;
ctx->context->FilterDevicesForRemoteWorkers(
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->FilterDevicesForRemoteWorkers(
remote_worker, base_request.cluster_device_attributes(),
&filtered_device_mask);
DCHECK_EQ(filtered_device_mask.size(), cluster_device_count);
@ -467,7 +471,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
// New server created for new server_def. Unused if updating server_def.
std::unique_ptr<tensorflow::ServerInterface> new_server;
tensorflow::EagerContext* context = ctx->context;
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::GrpcServer* grpc_server;
if (reset_context) {
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
@ -696,14 +701,16 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr.get());
return new TFE_Context{new tensorflow::EagerContext(
opts->session_options.options,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
opts->device_placement_policy),
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
/*device_mgr_owned*/ true, r,
tensorflow::GetDefaultCustomKernelCreator())};
return new TFE_Context{std::make_unique<tensorflow::ContextInterface>(
new tensorflow::EagerContext(
opts->session_options.options,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
opts->device_placement_policy),
static_cast<tensorflow::ContextMirroringPolicy>(
opts->mirroring_policy),
opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
/*device_mgr_owned*/ true, r,
tensorflow::GetDefaultCustomKernelCreator()))};
}
TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
@ -714,20 +721,24 @@ TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr);
return new TFE_Context{new tensorflow::EagerContext(
opts->session_options.options,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
opts->device_placement_policy),
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
opts->async, opts->lazy_remote_inputs_copy, device_mgr,
/*device_mgr_owned*/ false, r,
tensorflow::GetDefaultCustomKernelCreator())};
return new TFE_Context{std::make_unique<tensorflow::ContextInterface>(
new tensorflow::EagerContext(
opts->session_options.options,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
opts->device_placement_policy),
static_cast<tensorflow::ContextMirroringPolicy>(
opts->mirroring_policy),
opts->async, opts->lazy_remote_inputs_copy, device_mgr,
/*device_mgr_owned*/ false, r,
tensorflow::GetDefaultCustomKernelCreator()))};
}
void TFE_DeleteContext(TFE_Context* ctx) {
// context->RefCountIsOne() should be true here.
// TODO(iga): Remove EagerContext refcounting.
ctx->context->Unref();
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->Unref();
delete ctx;
}
@ -739,7 +750,9 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
}
void TFE_ContextClearCaches(TFE_Context* ctx) {
ctx->context->ClearCachesAndThreadExecutors();
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->ClearCachesAndThreadExecutors();
}
// Set server_def on the context, possibly updating it.
@ -769,8 +782,10 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
device_filters[i] = tdf.second.device_filters(i);
}
const string remote_worker = remote_prefix + std::to_string(task_index);
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
status->status =
ctx->context->SetRemoteDeviceFilters(remote_worker, device_filters);
context->SetRemoteDeviceFilters(remote_worker, device_filters);
}
}
}
@ -789,11 +804,13 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
"TFE_ContextSetServerDef not supported on mobile");
#else // !defined(IS_MOBILE_PLATFORM)
tensorflow::ServerDef server_def;
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
if (!server_def.ParseFromArray(proto, proto_len)) {
status->status = tensorflow::errors::InvalidArgument(
"Invalid tensorflow.ServerDef protocol buffer");
return;
} else if (ctx->context->GetContextId() ==
} else if (context->GetContextId() ==
tensorflow::EagerContext::kInvalidContextId) {
status->status = tensorflow::errors::InvalidArgument(
"Trying to update a context with invalid context id.");
@ -817,7 +834,8 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
"TFE_ContextSetServerDef not supported on mobile");
return false;
#else // !defined(IS_MOBILE_PLATFORM)
tensorflow::EagerContext* context = ctx->context;
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::GrpcServer* grpc_server =
static_cast<tensorflow::GrpcServer*>(context->GetServer());
@ -872,13 +890,17 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
#if defined(IS_MOBILE_PLATFORM)
status->status = tensorflow::Status::OK();
#else // !defined(IS_MOBILE_PLATFORM)
status->status = ctx->context->SyncExecutors();
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
status->status = context->SyncExecutors();
#endif // !IS_MOBILE_PLATFORM
}
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
ctx->context->SetThreadLocalDevicePlacementPolicy(
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->SetThreadLocalDevicePlacementPolicy(
static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
}
@ -887,15 +909,20 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy(
// safe to call this function from the async EagerExecutor threads.
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
return static_cast<TFE_ContextDevicePlacementPolicy>(
ctx->context->GetDevicePlacementPolicy());
context->GetDevicePlacementPolicy());
}
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
tensorflow::Tensor tensor;
status->status = tensorflow::TF_TensorToTensor(t, &tensor);
if (!status->status.ok()) return nullptr;
return TFE_TensorHandle::CreateLocalHandle(tensor, status);
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(
tensorflow::TensorHandle::CreateLocalHandle(tensor))};
}
void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
@ -1050,10 +1077,12 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
}
return new TFE_TensorHandle{
std::unique_ptr<AbstractTensorHandleInterface>(h->handle->Copy())};
std::unique_ptr<tensorflow::AbstractTensorHandleInterface>(
h->handle->Copy())};
}
AbstractTensorHandleInterface* tensorflow::TensorHandleInterface::Copy() {
tensorflow::AbstractTensorHandleInterface*
tensorflow::TensorHandleInterface::Copy() {
handle_->Ref();
return new TensorHandleInterface(handle_);
}
@ -1069,13 +1098,22 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
return nullptr;
}
return h->handle->Resolve(&status->status);
std::unique_ptr<tensorflow::AbstractTensorInterface> t =
h->handle->Resolve(&status->status);
if (t == nullptr) {
return nullptr;
}
tensorflow::Tensor tensor = tensorflow::TensorFromInterface(t);
return tensorflow::TF_TensorFromTensor(tensor, &status->status);
}
TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
std::unique_ptr<tensorflow::AbstractTensorInterface>
tensorflow::TensorHandleInterface::Resolve(Status* status) {
if (!IsValid(status)) {
return nullptr;
}
if (VariantDeviceIsCustom(handle_->device())) {
tensorflow::CustomDevice* custom_device =
absl::get<tensorflow::CustomDevice*>(handle_->device());
@ -1104,7 +1142,7 @@ TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
h_cpu->Unref();
return nullptr;
}
TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, status);
auto retval = std::make_unique<tensorflow::TensorInterface>(*t);
h_cpu->Unref();
return retval;
} else {
@ -1131,7 +1169,7 @@ TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
if (!status->ok()) return nullptr;
}
}
return tensorflow::TF_TensorFromTensor(tensor, status);
return std::make_unique<tensorflow::TensorInterface>(std::move(tensor));
}
}
@ -1142,8 +1180,7 @@ void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
return nullptr;
}
tensorflow::TensorHandle* handle =
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle();
tensorflow::TensorHandleFromInterface(h->handle);
if (VariantDeviceIsCustom(handle->device())) {
const tensorflow::Tensor* t;
status->status = handle->Tensor(&t);
@ -1178,7 +1215,8 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
void (*deallocator)(void* data, size_t len, void* arg),
void* deallocator_arg, TF_Status* status) {
tensorflow::Device* device = nullptr;
tensorflow::EagerContext* context = ctx->context;
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
status->status = context->FindDeviceFromName(device_name, &device);
tensorflow::CustomDevice* custom_device = nullptr;
if (!status->status.ok()) {
@ -1203,19 +1241,17 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
tensorflow::Tensor t(static_cast<tensorflow::DataType>(dtype),
tensorflow::TensorShape(dimvec), buf);
buf->Unref();
tensorflow::TensorHandle* ret_handle;
if (custom_device == nullptr) {
status->status = tensorflow::TensorHandle::CreateLocalHandle(
std::move(t), device, device, context, &ret_handle);
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(
tensorflow::TensorHandle::CreateLocalHandle(std::move(t), device,
device, context))};
} else {
status->status = tensorflow::TensorHandle::CreateLocalHandle(
std::move(t), custom_device, context, &ret_handle);
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(
tensorflow::TensorHandle::CreateLocalHandle(
std::move(t), custom_device, context))};
}
if (!status->status.ok()) {
return nullptr;
}
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(ret_handle)};
}
// This function will block till the operation that produces `h` has
@ -1229,9 +1265,7 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
return 0;
}
tensorflow::TensorHandle* handle =
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle();
tensorflow::TensorHandleFromInterface(h->handle);
if (handle->IsRemote()) {
status->status = tensorflow::errors::InvalidArgument(
"TFE_TensorHandleDeviceMemorySize may not be called on a remote tensor "
@ -1248,8 +1282,7 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status) {
std::unique_ptr<TFE_Op> new_op(
new TFE_Op{std::make_unique<tensorflow::OperationInterface>(ctx)});
std::unique_ptr<TFE_Op> new_op(new TFE_Op{ctx->context->CreateOperation()});
status->status = new_op->operation->Reset(op_or_function_name, nullptr);
if (!status->status.ok()) {
new_op.reset();
@ -1285,8 +1318,8 @@ void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
TF_Status* status) {
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>> handles(
num_inputs);
absl::FixedArray<std::unique_ptr<tensorflow::AbstractTensorHandleInterface>>
handles(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
handles[i].reset(inputs[i]->handle->Copy());
}
@ -1383,7 +1416,10 @@ void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
TF_Status* status) {
status->status = op->operation->SetAttrTensor(attr_name, tensor);
tensorflow::Tensor t;
status->status = TF_TensorToTensor(tensor, &t);
status->status = op->operation->SetAttrTensor(
attr_name, std::make_unique<tensorflow::TensorInterface>(t));
}
void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
@ -1480,8 +1516,8 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>> handles(
*num_retvals);
absl::FixedArray<std::unique_ptr<tensorflow::AbstractTensorHandleInterface>>
handles(*num_retvals);
status->status = op->operation->Execute(&handles, num_retvals);
if (!status->status.ok()) {
return;
@ -1497,17 +1533,15 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
TF_Status* status) {
tensorflow::TensorHandle* handle = nullptr;
tensorflow::Device* device;
tensorflow::EagerContext* context = ctx->context;
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
status->status = context->FindDeviceFromName(device_name, &device);
if (!status->status.ok()) {
tensorflow::CustomDevice* dev;
status->status = context->FindCustomDeviceFromName(device_name, &dev);
if (status->status.ok()) {
status->status = dev->CopyTensorToDevice(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
h->handle.get())
->Handle(),
&handle);
tensorflow::TensorHandleFromInterface(h->handle), &handle);
if (status->status.ok()) {
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
@ -1524,10 +1558,7 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
status->status = context->FindCustomDeviceFromName(handle_device_name, &dev);
if (status->status.ok()) {
status->status = dev->CopyTensorFromDevice(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
h->handle.get())
->Handle(),
device_name, &handle);
tensorflow::TensorHandleFromInterface(h->handle), device_name, &handle);
if (status->status.ok()) {
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
@ -1537,9 +1568,8 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
// Handle regular case.
status->status = tensorflow::EagerCopyToDevice(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle(),
context, &context->Executor(), device, false, &handle);
tensorflow::TensorHandleFromInterface(h->handle), context,
&context->Executor(), device, false, &handle);
if (status->status.ok()) {
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
@ -1556,41 +1586,56 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
return;
}
status->status = ctx->context->AddFunctionDef(function_def);
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
status->status = context->AddFunctionDef(function_def);
}
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
TF_Status* status) {
status->status = ctx->context->AddFunctionDef(function->fdef);
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
status->status = context->AddFunctionDef(function->fdef);
}
void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
TF_Status* status) {
status->status = ctx->context->RemoveFunction(name);
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
status->status = context->RemoveFunction(name);
}
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
return ctx->context->FindFunctionDef(name) != nullptr;
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
return context->FindFunctionDef(name) != nullptr;
}
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
ctx->context->SetShouldStoreGraphs(true);
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->SetShouldStoreGraphs(true);
}
void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
ctx->context->SetShouldStoreGraphs(false);
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->SetShouldStoreGraphs(false);
}
} // extern "C"
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
TF_Status* status) {
return TFE_TensorHandle::CreateLocalHandle(t, status);
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(
tensorflow::TensorHandle::CreateLocalHandle(t))};
}
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
TF_Status* status) {
tensorflow::EagerContext* context = ctx->context;
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
status->status = context->Executor().WaitForAllPendingNodes();
if (!status->status.ok()) return;
tensorflow::mutex_lock ml(*context->MetadataMu());
@ -1611,21 +1656,27 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
}
} // namespace
void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context->StartStep(); }
void TFE_ContextStartStep(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->StartStep();
}
void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context->EndStep(); }
void TFE_ContextEndStep(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->EndStep();
}
void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs) {
auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
op->operation.get());
*attrs = TFE_OpAttrs(&operation->Attrs(), op->operation->Name().c_str());
tensorflow::EagerOperation* operation = OperationFromInterface(op->operation);
*attrs = TFE_OpAttrs(&operation->Attrs(), operation->Name().c_str());
}
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
tensorflow::AttrValueMap m;
attrs->attributes->FillAttrValueMap(&m);
auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
op->operation.get());
tensorflow::EagerOperation* operation = OperationFromInterface(op->operation);
tensorflow::AttrBuilder* destination = operation->MutableAttrs();
for (auto attribute : m) {
destination->Set(attribute.first, attribute.second);
@ -1721,9 +1772,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
TFE_TensorHandle* result_handle =
device_.copy_tensor_to_device(context_, &tensor_handle, &status, info_);
if (!status.status.ok()) return status.status;
*result = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
result_handle->handle.get())
->Handle();
*result = tensorflow::TensorHandleFromInterface(result_handle->handle);
(*result)->Ref();
delete result_handle;
return status.status;
@ -1740,9 +1789,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
context_, &tensor_handle, target_device_name.c_str(), &status, info_);
if (!status.status.ok()) return status.status;
*result = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
result_handle->handle.get())
->Handle();
*result = tensorflow::TensorHandleFromInterface(result_handle->handle);
(*result)->Ref();
delete result_handle;
return status.status;
@ -1766,9 +1813,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
&attributes, num_retvals, outputs.data(), &status, info_);
if (status.status.ok()) {
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
outputs[i]->handle.get())
->Handle();
retvals[i] = tensorflow::TensorHandleFromInterface(outputs[i]->handle);
retvals[i]->Ref();
delete outputs[i];
}
@ -1793,6 +1838,8 @@ void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
TF_Status* status) {
auto custom_device =
std::make_unique<CustomDeviceAPI>(ctx, device, device_info, device_name);
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
status->status =
ctx->context->RegisterCustomDevice(device_name, std::move(custom_device));
context->RegisterCustomDevice(device_name, std::move(custom_device));
}

View File

@ -54,36 +54,32 @@ extern "C" {
TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
TFE_TensorHandle* h, TF_Status* status) {
return h->handle->TensorDebugInfo(&status->status);
}
TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
Status* status) {
tensorflow::TensorHandle* handle = TensorHandleFromInterface(h->handle);
const tensorflow::Tensor* tensor;
*status = handle_->Tensor(&tensor);
if (!status->ok()) {
status->status = handle->Tensor(&tensor);
if (!status->status.ok()) {
return nullptr;
}
#ifdef TENSORFLOW_EAGER_USE_XLA
tensorflow::Device* device = absl::get<Device*>(handle_->device());
auto* device = absl::get<tensorflow::Device*>(handle->device());
// If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
tensorflow::XlaDevice* xla_device =
dynamic_cast<tensorflow::XlaDevice*>(device);
auto* xla_device = dynamic_cast<tensorflow::XlaDevice*>(device);
if (xla_device != nullptr) {
tensorflow::XlaDevice::PaddedShapeFn shape_fn =
xla_device->metadata().padded_shape_fn();
xla::Shape padded_shape;
*status = shape_fn(*tensor, &padded_shape);
if (!status->ok()) {
status->status = shape_fn(*tensor, &padded_shape);
if (!status->status.ok()) {
return nullptr;
}
if (VLOG_IS_ON(3)) {
std::vector<int64> shape_to_log = TensorShapeAsVector(*handle_, status);
if (!status->ok()) {
std::vector<int64> shape_to_log =
TensorShapeAsVector(*handle, &status->status);
if (!status->status.ok()) {
// Ignore the status here as we are simply logging.
*status = tensorflow::Status::OK();
status->status = tensorflow::Status::OK();
} else {
VLOG(3) << "Fully padded shape of ["
<< absl::StrJoin(shape_to_log, ", ") << "] is "
@ -96,7 +92,7 @@ TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
// Currently, the only case of XlaTensor containing a tuple shape is to
// represent 64 bit ints, doubles, and complex numbers (we don't support
// 64bit complex numbers).
*status = tensorflow::errors::InvalidArgument(
status->status = tensorflow::errors::InvalidArgument(
"XlaTensors should only contain tuples of size 2. Shape: ",
padded_shape.DebugString());
return nullptr;
@ -108,13 +104,13 @@ TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
const xla::Shape& shape1 =
xla::ShapeUtil::GetTupleElementShape(padded_shape, 1);
if (shape0.IsTuple() || shape1.IsTuple()) {
*status = tensorflow::errors::InvalidArgument(
status->status = tensorflow::errors::InvalidArgument(
"XlaTensors should not contain nested tuples. Shape: ",
padded_shape.DebugString());
return nullptr;
}
if (!xla::ShapeUtil::Equal(shape0, shape1)) {
*status = tensorflow::errors::InvalidArgument(
status->status = tensorflow::errors::InvalidArgument(
"Subshapes of XlaTensors should be the same. Shape: ",
padded_shape.DebugString());
return nullptr;
@ -139,15 +135,15 @@ TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
dev_dims.push_back(padded_shape.dimensions(dim_index));
}
}
*status = tensorflow::Status::OK();
status->status = tensorflow::Status::OK();
return new TFE_TensorDebugInfo(dev_dims);
}
#endif // TENSORFLOW_EAGER_USE_XLA
// If the tensor is not an XLA tensor, the device shape is
// the same as regular tensor shape.
std::vector<int64> dev_dims = TensorShapeAsVector(*handle_, status);
if (!status->ok()) {
std::vector<int64> dev_dims = TensorShapeAsVector(*handle, &status->status);
if (!status->status.ok()) {
return nullptr;
}
return new TFE_TensorDebugInfo(dev_dims);

View File

@ -31,6 +31,7 @@ using tensorflow::string;
void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status) {
if (op_to_reset) {
op_to_reset->operation->Clear();
status->status =
op_to_reset->operation->Reset(op_or_function_name, raw_device_name);
} else {
@ -40,11 +41,15 @@ void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
}
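
The added operation->Clear() call resets any per-op state, which is what makes reusing a single TFE_Op object across calls safe. Purely as an illustrative sketch (assuming an existing TFE_Context* ctx; error handling omitted):

TF_Status* status = TF_NewStatus();
TFE_Op* op = TFE_NewOp(ctx, "Add", status);                   // first used as "Add"
// ... add inputs, TFE_Execute(...) ...
TFE_OpReset(op, "Mul", /*raw_device_name=*/nullptr, status);  // wipe state, reuse as "Mul"
// ... add new inputs, TFE_Execute(...) ...
TFE_DeleteOp(op);
TF_DeleteStatus(status);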
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
ctx->context->SetShouldStoreGraphs(true);
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->SetShouldStoreGraphs(true);
}
void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
ctx->context->SetShouldStoreGraphs(false);
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->SetShouldStoreGraphs(false);
}
void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
@ -474,7 +479,9 @@ void TFE_ContextOptionsSetMirroringPolicy(TFE_ContextOptions* options,
void TFE_ContextSetThreadLocalMirroringPolicy(
TFE_Context* ctx, TFE_ContextMirroringPolicy policy) {
ctx->context->SetThreadLocalMirroringPolicy(
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->SetThreadLocalMirroringPolicy(
static_cast<tensorflow::ContextMirroringPolicy>(policy));
}
@ -483,8 +490,9 @@ void TFE_ContextSetThreadLocalMirroringPolicy(
// safe to call this function from the async EagerExecutor threads.
extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
TFE_Context* ctx) {
return static_cast<TFE_ContextMirroringPolicy>(
ctx->context->GetMirroringPolicy());
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
return static_cast<TFE_ContextMirroringPolicy>(context->GetMirroringPolicy());
}
void TFE_ContextOptionsSetLazyRemoteInputsCopy(TFE_ContextOptions* options,
@ -492,6 +500,10 @@ void TFE_ContextOptionsSetLazyRemoteInputsCopy(TFE_ContextOptions* options,
options->lazy_remote_inputs_copy = lazy_copy;
}
void TFE_ContextOptionsSetTfrt(TFE_ContextOptions* options, bool use_tfrt) {
options->use_tfrt = use_tfrt;
}
TFE_CancellationManager* TFE_NewCancellationManager() {
return new TFE_CancellationManager;
}
@ -514,7 +526,11 @@ void TFE_DeleteCancellationManager(
void TFE_OpSetCancellationManager(TFE_Op* op,
TFE_CancellationManager* cancellation_manager,
TF_Status* status) {
status->status = op->operation->SetCancellationManager(cancellation_manager);
tensorflow::EagerOperation* operation =
tensorflow::OperationFromInterface(op->operation);
operation->SetCancellationManager(
&cancellation_manager->cancellation_manager);
status->status = tensorflow::Status::OK();
}
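
TFE_OpSetCancellationManager now unwraps the operation to the underlying EagerOperation and attaches the manager directly, always reporting OK. A short usage sketch with the existing cancellation C API (op and status are assumed to come from the usual TFE_NewOp / TF_NewStatus flow, not from this diff):

TFE_CancellationManager* cm = TFE_NewCancellationManager();
TFE_OpSetCancellationManager(op, cm, status);  // attach before executing the op
// ... start executing; then, typically from another thread:
TFE_CancellationManagerStartCancel(cm);        // request cancellation of pending work
TFE_DeleteCancellationManager(cm);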
TFE_Executor* TFE_NewExecutor(bool is_async) {
@ -537,16 +553,22 @@ void TFE_ExecutorClearError(TFE_Executor* executor) {
}
void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
ctx->context->SetExecutorForThread(executor->executor());
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->SetExecutorForThread(executor->executor());
}
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
return new TFE_Executor(&ctx->context->Executor());
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
return new TFE_Executor(&context->Executor());
}
void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
ctx->context->HostCPU()->parsed_name());
context->HostCPU()->parsed_name());
auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
void* data = tensorflow::port::Malloc(str.length());
str.copy(static_cast<char*>(data), str.length(), 0);
@ -565,7 +587,9 @@ void TFE_TensorHandleEnableImplicitMirroring(TFE_TensorHandle* h,
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
TF_Buffer* buf, TF_Status* status) {
auto* function_def = ctx->context->FindFunctionDef(function_name);
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
auto* function_def = context->FindFunctionDef(function_name);
if (function_def == nullptr) {
status->status = tensorflow::errors::NotFound(
"Unable to find FunctionDef with name: ", function_name);

View File

@ -296,6 +296,10 @@ TF_CAPI_EXPORT extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetLazyRemoteInputsCopy(
TFE_ContextOptions*, bool lazy_copy);
// Sets whether to use TFRT
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetTfrt(TFE_ContextOptions*,
bool use_tfrt);
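
For illustration, a minimal sketch of opting into the new backend when creating a context (whether TFRT is actually honored depends on the build; error handling omitted):

TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetTfrt(opts, /*use_tfrt=*/true);  // request the TFRT backend
TFE_Context* ctx = TFE_NewContext(opts, status);
TFE_DeleteContextOptions(opts);
// ... use ctx ...
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);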
// -----------------------------------------------------------------------------
// Cancellation APIs.

View File

@ -455,6 +455,7 @@ TEST(CAPI, TensorHandleOnDeviceMemory) {
TFE_DeleteTensorHandle(copy_aliased); // Note that this will delete copy.
TFE_DeleteTensorHandle(on_host);
}
TF_DeleteDeviceList(devices);
TF_DeleteTensor(m_data);
TFE_DeleteTensorHandle(m);
TFE_DeleteContext(ctx);

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/context_interface.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/device_factory.h"
@ -59,25 +60,16 @@ struct TFE_ContextOptions {
TFE_ContextMirroringPolicy mirroring_policy{TFE_MIRRORING_NONE};
// If true, lazily copy the remote inputs of a function to the target devices.
bool lazy_remote_inputs_copy = true;
// If true, use TFRT backend
bool use_tfrt = false;
};
struct TFE_Context {
tensorflow::EagerContext* context;
std::unique_ptr<tensorflow::AbstractContextInterface> context;
};
struct TFE_TensorHandle {
static TFE_TensorHandle* CreateLocalHandle(const class tensorflow::Tensor& t,
TF_Status* s) {
tensorflow::TensorHandle* handle;
s->status = tensorflow::TensorHandle::CreateLocalHandle(t, &handle);
if (!s->status.ok()) {
return nullptr;
}
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
}
std::unique_ptr<AbstractTensorHandleInterface> handle;
std::unique_ptr<tensorflow::AbstractTensorHandleInterface> handle;
};
struct TFE_TensorDebugInfo {
@ -89,7 +81,7 @@ struct TFE_TensorDebugInfo {
};
struct TFE_Op {
std::unique_ptr<AbstractOperationInterface> operation;
std::unique_ptr<tensorflow::AbstractOperationInterface> operation;
};
struct TFE_MonitoringCounterCell {

View File

@ -184,9 +184,7 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
// TODO(gjn): Add support for waiting on async local mirrors
if (!async) {
auto remote_arg = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
h1_task2->handle.get())
->Handle();
auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle);
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
matmul->operation.get());
// The input handles should never change since they have been mirrored.

View File

@ -409,13 +409,8 @@ void TensorHandleSilentCopy(bool async,
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
// Validate if the input was replaced with a different TensorHandle
auto arg0 = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
hcpu->handle.get())
->Handle();
auto arg1 = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
hgpu->handle.get())
->Handle();
auto arg0 = tensorflow::TensorHandleFromInterface(hcpu->handle);
auto arg1 = tensorflow::TensorHandleFromInterface(hgpu->handle);
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
matmul->operation.get());

View File

@ -0,0 +1,261 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/mutex.h"
using tensorflow::string;
// =============================================================================
// Unified Execution APIs for Eager and tracing backends.
// =============================================================================
typedef void (*ExecuteOperation)(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs,
TF_OutputList* o, TF_ExecutionContext* ctx,
TF_Status* s);
struct TF_ExecutionContext {
explicit TF_ExecutionContext() {}
absl::variant<TFE_Context*, TF_GraphContext*> ctx;
ExecuteOperation execution_callback;
};
struct TF_AbstractTensor {
absl::variant<TFE_TensorHandle*, TF_GraphTensor*> t;
};
struct TF_AbstractOp {
string op_type;
string op_name;
};
TF_ExecutionContext* TF_NewExecutionContext() {
return new TF_ExecutionContext();
}
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete c; }
TF_AbstractOp* TF_NewAbstractOp() {
TF_AbstractOp* op = new TF_AbstractOp;
return op;
}
void TF_DeleteAbstractOp(TF_AbstractOp* op) { delete op; }
TF_AbstractTensor* TF_NewAbstractTensor() {
TF_AbstractTensor* t = new TF_AbstractTensor;
return t;
}
void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { delete t; }
struct TF_GraphContext {
TF_Graph* graph;
// TODO(srbs): Handle captures.
};
TF_GraphContext* TF_NewGraphContext(TF_Graph* g) {
auto ctx = new TF_GraphContext;
ctx->graph = g;
return ctx;
}
void TF_DeleteGraphContext(TF_GraphContext* ctx) { delete ctx; }
struct TF_GraphTensor {
TF_Output output;
TF_GraphContext* ctx;
};
TF_GraphTensor* TF_NewGraphTensor(TF_GraphContext* ctx, TF_Output output,
TF_Status* s) {
TF_GraphTensor* t = new TF_GraphTensor;
t->output = output;
t->ctx = ctx;
return t;
}
TF_Output TF_GraphTensorToOutput(const TF_GraphTensor* const t, TF_Status* s) {
return t->output;
}
void TF_DeleteGraphTensor(TF_GraphTensor* t) { delete t; }
void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t,
TF_Status* s) {
at->t = t;
}
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
TF_Status* s) {
if (!absl::holds_alternative<TFE_TensorHandle*>(at->t)) {
string msg = absl::StrCat("Not an eager tensor handle.",
reinterpret_cast<uintptr_t>(at));
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return nullptr;
}
return absl::get<TFE_TensorHandle*>(at->t);
}
void TF_AbstractTensorSetGraphTensor(TF_AbstractTensor* at, TF_GraphTensor* t,
TF_Status* s) {
at->t = t;
}
TF_GraphTensor* TF_AbstractTensorGetGraphTensor(TF_AbstractTensor* at,
TF_Status* s) {
if (!absl::holds_alternative<TF_GraphTensor*>(at->t)) {
string msg = absl::StrCat("Not a graph tensor handle.");
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return nullptr;
}
return absl::get<TF_GraphTensor*>(at->t);
}
bool IsEagerTensor(const TF_AbstractTensor* const t) {
return absl::holds_alternative<TFE_TensorHandle*>(t->t);
}
struct TF_OutputList {
std::vector<TF_AbstractTensor*> outputs;
int expected_num_outputs = -1;
};
TF_OutputList* TF_NewOutputList() { return new TF_OutputList; }
void TF_DeleteOutputList(TF_OutputList* o) { delete o; }
void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs,
TF_Status* s) {
o->expected_num_outputs = num_outputs;
}
int TF_OutputListNumOutputs(TF_OutputList* o) { return o->outputs.size(); }
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) {
return o->outputs[i];
}
void ExecuteOperationEager(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_ExecutionContext* ctx, TF_Status* s) {
auto* tfe_op =
TFE_NewOp(absl::get<TFE_Context*>(ctx->ctx), op->op_type.c_str(), s);
if (TF_GetCode(s) != TF_OK) return;
for (int i = 0; i < num_inputs; ++i) {
if (!IsEagerTensor(inputs[i])) {
TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor.");
return;
}
TFE_OpAddInput(tfe_op, absl::get<TFE_TensorHandle*>(inputs[i]->t), s);
if (TF_GetCode(s) != TF_OK) return;
}
if (o->expected_num_outputs == -1) {
string msg =
"The number of outputs must be provided in eager mode. Use "
"TF_OutputListSetNumOutputs.";
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return;
}
tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals;
int num_retvals = o->expected_num_outputs;
retvals.resize(num_retvals);
TFE_Execute(tfe_op, retvals.data(), &num_retvals, s);
TFE_DeleteOp(tfe_op);
if (TF_GetCode(s) != TF_OK) {
return;
}
o->outputs.clear();
o->outputs.reserve(num_retvals);
for (int i = 0; i < num_retvals; ++i) {
auto* t = TF_NewAbstractTensor();
t->t = retvals[i];
o->outputs.push_back(t);
}
}
TF_GraphContext* GetGraphContext(TF_AbstractTensor const* t) {
return absl::get<TF_GraphTensor*>(t->t)->ctx;
}
void ExecuteOperationGraph(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_ExecutionContext* ctx, TF_Status* s) {
TF_GraphContext* graph_ctx = absl::get<TF_GraphContext*>(ctx->ctx);
TF_Graph* g = graph_ctx->graph;
auto* tf_opdesc =
TF_NewOperation(g, op->op_type.c_str(), op->op_name.c_str());
for (int i = 0; i < num_inputs; ++i) {
auto* input = inputs[i];
if (IsEagerTensor(input)) {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Capturing eager tensors is not supported yet.");
return;
} else {
if (GetGraphContext(input) != graph_ctx) {
TF_SetStatus(
s, TF_INVALID_ARGUMENT,
"Capturing tensors from other graphs is not supported yet.");
return;
}
TF_AddInput(tf_opdesc, absl::get<TF_GraphTensor*>(input->t)->output);
}
}
auto* operation = TF_FinishOperation(tf_opdesc, s);
if (TF_GetCode(s) != TF_OK) return;
int num_outputs = TF_OperationNumOutputs(operation);
o->outputs.clear();
o->outputs.reserve(num_outputs);
for (int i = 0; i < num_outputs; ++i) {
auto* t = TF_NewAbstractTensor();
TF_GraphTensor* output_t = TF_NewGraphTensor(graph_ctx, {operation, i}, s);
if (TF_GetCode(s) != TF_OK) {
return;
}
t->t = output_t;
o->outputs.push_back(t);
}
}
void TF_ExecutionContextSetEagerContext(TF_ExecutionContext* context,
TFE_Context* eager_context,
TF_Status* s) {
context->ctx = eager_context;
context->execution_callback = &ExecuteOperationEager;
}
void TF_ExecutionContextSetGraphContext(TF_ExecutionContext* context,
TF_GraphContext* graph_context,
TF_Status* s) {
context->ctx = graph_context;
context->execution_callback = &ExecuteOperationGraph;
}
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
TF_Status* s) {
op->op_type = op_type;
}
void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
TF_Status* s) {
op->op_name = op_name;
}
void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_ExecutionContext* ctx, TF_Status* s) {
ctx->execution_callback(op, num_inputs, inputs, o, ctx, s);
}

View File

@ -0,0 +1,119 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_
#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#ifdef __cplusplus
extern "C" {
#endif
// =============================================================================
// Unified Execution APIs for Eager and tracing backends.
// =============================================================================
// -----------------------------------------------------------------------------
// Core APIs
// -----------------------------------------------------------------------------
// A TF_ExecutionContext stores knowledge about how to execute an operation.
// E.g. it could know whether we're in eager mode or in graph mode, keeps track
// of gradient tapes, etc.
typedef struct TF_ExecutionContext TF_ExecutionContext;
// A TF_AbstractTensor is an input to an operation. E.g. it could be a union
// type of eager and graph tensors.
typedef struct TF_AbstractTensor TF_AbstractTensor;
// A TF_AbstractOp is the metadata we need to execute an operation. E.g. this
// could contain the op type and other attributes.
typedef struct TF_AbstractOp TF_AbstractOp;
TF_ExecutionContext* TF_NewExecutionContext();
void TF_DeleteExecutionContext(TF_ExecutionContext*);
TF_AbstractOp* TF_NewAbstractOp();
void TF_DeleteAbstractOp(TF_AbstractOp*);
TF_AbstractTensor* TF_NewAbstractTensor();
void TF_DeleteAbstractTensor(TF_AbstractTensor*);
// -----------------------------------------------------------------------------
// APIs for Eager and graph modes
// -----------------------------------------------------------------------------
// Keeps track of the current graph and other state e.g. captures etc.
typedef struct TF_GraphContext TF_GraphContext;
TF_GraphContext* TF_NewGraphContext(TF_Graph*);
void TF_DeleteGraphContext(TF_GraphContext*);
// `eager_context` must outlive `context`.
void TF_ExecutionContextSetEagerContext(TF_ExecutionContext* context,
TFE_Context* eager_context, TF_Status*);
// `graph_context` must outlive `context`.
void TF_ExecutionContextSetGraphContext(TF_ExecutionContext* context,
TF_GraphContext* graph_context,
TF_Status*);
// TODO(srbs): Add APIs for specifying attrs etc.
// `op_type` must outlive `op`.
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
TF_Status* s);
// `op_name` must outlive `op`.
void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
TF_Status* s);
// Wrapper for TF_Output but contains a pointer to TF_GraphContext as well.
typedef struct TF_GraphTensor TF_GraphTensor;
TF_GraphTensor* TF_NewGraphTensor(TF_GraphContext* c, TF_Output t,
TF_Status* s);
TF_Output TF_GraphTensorToOutput(const TF_GraphTensor* const t, TF_Status* s);
void TF_DeleteGraphTensor(TF_GraphTensor* t);
// `t` must outlive `at`.
void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t,
TF_Status* s);
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
TF_Status* s);
// `t` must outlive `at`.
void TF_AbstractTensorSetGraphTensor(TF_AbstractTensor* at, TF_GraphTensor* t,
TF_Status* s);
TF_GraphTensor* TF_AbstractTensorGetGraphTensor(TF_AbstractTensor* at,
TF_Status* s);
// TF_OutputList just lets us not specify the number of outputs of an operation
// beforehand. This forces a memory allocation in the runtime, which is bad, but
// it allows for generic code.
typedef struct TF_OutputList TF_OutputList;
TF_OutputList* TF_NewOutputList();
void TF_DeleteOutputList(TF_OutputList* o);
void TF_OutputListSetNumOutputs(TF_OutputList* o, int, TF_Status*);
int TF_OutputListNumOutputs(TF_OutputList* o);
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i);
// TF_ExecuteOperation will, if in eager mode, execute, if in graph mode, maybe
// capture some inputs and then add a node in the graph, and after
// execution/node creation it'll go and record things that happened in any tape
// which happens to be active.
void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_ExecutionContext* ctx, TF_Status* s);
#ifdef __cplusplus
} /* end extern "C" */
#endif
#endif // TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_
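
For orientation, a condensed eager-mode pass through this header; it follows the same flow as the TestBasicEager test added later in this change (TestScalarTensorHandle is that test's helper, and all status checks are omitted):

TF_Status* s = TF_NewStatus();
TF_ExecutionContext* ctx = TF_NewExecutionContext();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* eager_ctx = TFE_NewContext(opts, s);
TFE_DeleteContextOptions(opts);
TF_ExecutionContextSetEagerContext(ctx, eager_ctx, s);  // route execution to the eager backend

TF_AbstractTensor* at = TF_NewAbstractTensor();
TFE_TensorHandle* t = TestScalarTensorHandle(2.0f);      // test helper, see the test below
TF_AbstractTensorSetEagerTensor(at, t, s);

TF_AbstractOp* op = TF_NewAbstractOp();
TF_AbstractOpSetOpType(op, "Add", s);

TF_AbstractTensor* inputs[2] = {at, at};
TF_OutputList* o = TF_NewOutputList();
TF_OutputListSetNumOutputs(o, 1, s);  // eager mode needs the output count up front
TF_ExecuteOperation(op, 2, inputs, o, ctx, s);
// TF_OutputListGet(o, 0) now wraps the eager result handle (2.0f + 2.0f = 4.0f).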

View File

@ -0,0 +1,204 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include <string.h>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/cc/profiler/profiler.h"
#include "tensorflow/core/lib/monitoring/collection_registry.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
using tensorflow::string;
namespace tensorflow {
namespace {
TEST(UnifedCAPI, TestBasicEager) {
TF_ExecutionContext* ctx = TF_NewExecutionContext();
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* eager_ctx = TFE_NewContext(opts, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
// Enter the eager context.
TF_ExecutionContextSetEagerContext(ctx, eager_ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract input tensor.
TFE_TensorHandle* t = TestScalarTensorHandle(2.0f);
TF_AbstractTensor* at = TF_NewAbstractTensor();
TF_AbstractTensorSetEagerTensor(at, t, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract operation.
auto* op = TF_NewAbstractOp();
TF_AbstractOpSetOpType(op, "Add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {at, at};
TF_OutputList* o = TF_NewOutputList();
TF_OutputListSetNumOutputs(o, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Execute.
TF_ExecuteOperation(op, 2, inputs, o, ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Clean up operation and inputs.
TF_DeleteAbstractOp(op);
TF_DeleteAbstractTensor(at);
TFE_DeleteTensorHandle(t);
// Verify the results.
ASSERT_EQ(1, TF_OutputListNumOutputs(o));
TF_AbstractTensor* result = TF_OutputListGet(o, 0);
TFE_TensorHandle* result_t =
TF_AbstractTensorGetEagerTensor(result, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_Tensor* result_tensor = TFE_TensorHandleResolve(result_t, status.get());
float* result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, 4.0);
TF_DeleteTensor(result_tensor);
TF_DeleteAbstractTensor(result);
TFE_DeleteTensorHandle(result_t);
TF_DeleteOutputList(o);
TFE_DeleteContext(eager_ctx);
TF_DeleteExecutionContext(ctx);
}
TEST(UnifedCAPI, TestBasicGraph) {
TF_ExecutionContext* ctx = TF_NewExecutionContext();
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
// Enter a graph context.
TF_Graph* g = TF_NewGraph();
TF_GraphContext* graph_context = TF_NewGraphContext(g);
TF_ExecutionContextSetGraphContext(ctx, graph_context, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph.
auto* placeholder_op = TF_NewOperation(g, "Placeholder", "Placeholder");
TF_SetAttrType(placeholder_op, "dtype", TF_FLOAT);
auto* operation = TF_FinishOperation(placeholder_op, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_Output placeholder_t = {operation, 0};
TF_GraphTensor* graph_t =
TF_NewGraphTensor(graph_context, placeholder_t, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractTensor* t = TF_NewAbstractTensor();
TF_AbstractTensorSetGraphTensor(t, graph_t, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract operation.
auto* op = TF_NewAbstractOp();
TF_AbstractOpSetOpType(op, "Add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(op, "my_add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {t, t};
TF_OutputList* o = TF_NewOutputList();
// Execute.
TF_ExecuteOperation(op, 2, inputs, o, ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Clean up operation and inputs.
TF_DeleteAbstractOp(op);
TF_DeleteAbstractTensor(t);
TF_DeleteGraphTensor(graph_t);
TF_AbstractTensor* result = TF_OutputListGet(o, 0);
TF_GraphTensor* result_graph_tensor =
TF_AbstractTensorGetGraphTensor(result, status.get());
TF_DeleteAbstractTensor(result);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_Output result_output =
TF_GraphTensorToOutput(result_graph_tensor, status.get());
TF_DeleteGraphTensor(result_graph_tensor);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
string fn_name = "double";
TF_Function* f = TF_GraphToFunction(
g, fn_name.c_str(), 0, -1, nullptr, 1, &placeholder_t, 1, &result_output,
nullptr, nullptr, fn_name.c_str(), status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an eager context to run the function.
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* eager_ctx = TFE_NewContext(opts, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
// Build the abstract op to run the function.
TFE_ContextAddFunction(eager_ctx, f, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOp* fn_op = TF_NewAbstractOp();
TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract input tensor.
TFE_TensorHandle* input_eager = TestScalarTensorHandle(2.0f);
TF_AbstractTensor* input_t = TF_NewAbstractTensor();
TF_AbstractTensorSetEagerTensor(input_t, input_eager, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Enter the eager context.
TF_ExecutionContextSetEagerContext(ctx, eager_ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_OutputListSetNumOutputs(o, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_ExecuteOperation(fn_op, 1, &input_t, o, ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(1, TF_OutputListNumOutputs(o));
TF_AbstractTensor* final_result = TF_OutputListGet(o, 0);
TFE_TensorHandle* final =
TF_AbstractTensorGetEagerTensor(final_result, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_Tensor* f_t = TFE_TensorHandleResolve(final, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
float* f_value = static_cast<float*>(TF_TensorData(f_t));
ASSERT_EQ(*f_value, 4.0);
TF_DeleteOutputList(o);
TF_DeleteAbstractOp(fn_op);
TF_DeleteAbstractTensor(input_t);
TFE_DeleteTensorHandle(input_eager);
TF_DeleteAbstractTensor(final_result);
TFE_DeleteTensorHandle(final);
TF_DeleteTensor(f_t);
TF_DeleteFunction(f);
TF_DeleteGraphContext(graph_context);
TF_DeleteGraph(g);
TFE_DeleteContext(eager_ctx);
TF_DeleteExecutionContext(ctx);
}
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,143 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/context_interface.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/framework/tensor_interface.h"
#include "tensorflow/core/platform/casts.h"
namespace tensorflow {
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateInt64Scalar(
int64 value) {
return std::make_unique<TensorInterface>(Tensor(value));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateUint64Scalar(
uint64 value) {
return std::make_unique<TensorInterface>(Tensor(value));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateInt32Scalar(
int32 value) {
return std::make_unique<TensorInterface>(Tensor(value));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateFloatScalar(
float value) {
return std::make_unique<TensorInterface>(Tensor(value));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateDoubleScalar(
double value) {
return std::make_unique<TensorInterface>(Tensor(value));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateHalfScalar(
Eigen::half value) {
return std::make_unique<TensorInterface>(Tensor(value));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateStringScalar(
tstring value) {
return std::make_unique<TensorInterface>(Tensor(value));
}
std::unique_ptr<AbstractTensorInterface>
ContextInterface::CreateComplex128Scalar(complex128 value) {
return std::make_unique<TensorInterface>(Tensor(value));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateBoolScalar(
bool value) {
return std::make_unique<TensorInterface>(Tensor(value));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateInt64Tensor(
absl::Span<const int64> dim_sizes) {
return std::make_unique<TensorInterface>(
Tensor(DT_INT64, TensorShape(dim_sizes)));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateUint64Tensor(
absl::Span<const int64> dim_sizes) {
return std::make_unique<TensorInterface>(
Tensor(DT_UINT64, TensorShape(dim_sizes)));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateInt32Tensor(
absl::Span<const int64> dim_sizes) {
return std::make_unique<TensorInterface>(
Tensor(DT_INT32, TensorShape(dim_sizes)));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateFloatTensor(
absl::Span<const int64> dim_sizes) {
return std::make_unique<TensorInterface>(
Tensor(DT_FLOAT, TensorShape(dim_sizes)));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateDoubleTensor(
absl::Span<const int64> dim_sizes) {
return std::make_unique<TensorInterface>(
Tensor(DT_DOUBLE, TensorShape(dim_sizes)));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateHalfTensor(
absl::Span<const int64> dim_sizes) {
return std::make_unique<TensorInterface>(
Tensor(DT_HALF, TensorShape(dim_sizes)));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateStringTensor(
absl::Span<const int64> dim_sizes) {
return std::make_unique<TensorInterface>(
Tensor(DT_STRING, TensorShape(dim_sizes)));
}
std::unique_ptr<AbstractTensorInterface>
ContextInterface::CreateComplex128Tensor(absl::Span<const int64> dim_sizes) {
return std::make_unique<TensorInterface>(
Tensor(DT_COMPLEX128, TensorShape(dim_sizes)));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateBoolTensor(
absl::Span<const int64> dim_sizes) {
return std::make_unique<TensorInterface>(
Tensor(DT_BOOL, TensorShape(dim_sizes)));
}
std::unique_ptr<AbstractTensorHandleInterface>
ContextInterface::CreateLocalHandle(
const std::unique_ptr<AbstractTensorInterface> t) {
Tensor tensor = tensorflow::down_cast<TensorInterface*>(t.get())->Tensor();
return std::make_unique<TensorHandleInterface>(
TensorHandle::CreateLocalHandle(std::move(tensor), /*d=*/ctx_->HostCPU(),
/*op_device=*/nullptr, ctx_));
}
std::unique_ptr<AbstractOperationInterface>
ContextInterface::CreateOperation() {
return std::make_unique<tensorflow::OperationInterface>(ctx_);
}
void ContextInterface::ListDevices(
std::vector<tensorflow::DeviceAttributes>* devices) {
ctx_->ListDevices(devices);
}
} // namespace tensorflow

View File

@ -0,0 +1,155 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
#define TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
#include <memory>
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/tensor_interface.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/tstring.h"
namespace tensorflow {
// Abstract interface to a context.
//
// A context is responsible for creating key objects such as Tensors,
// TensorHandles & Operations.
class AbstractContextInterface {
public:
virtual ~AbstractContextInterface() {}
// Scalar creation functions
virtual std::unique_ptr<AbstractTensorInterface> CreateInt64Scalar(
int64 value) = 0;
virtual std::unique_ptr<AbstractTensorInterface> CreateUint64Scalar(
uint64 value) = 0;
virtual std::unique_ptr<AbstractTensorInterface> CreateInt32Scalar(
int32 value) = 0;
virtual std::unique_ptr<AbstractTensorInterface> CreateFloatScalar(
float value) = 0;
virtual std::unique_ptr<AbstractTensorInterface> CreateDoubleScalar(
double value) = 0;
virtual std::unique_ptr<AbstractTensorInterface> CreateHalfScalar(
Eigen::half value) = 0;
virtual std::unique_ptr<AbstractTensorInterface> CreateStringScalar(
tstring value) = 0;
virtual std::unique_ptr<AbstractTensorInterface> CreateComplex128Scalar(
complex128 value) = 0;
virtual std::unique_ptr<AbstractTensorInterface> CreateBoolScalar(
bool value) = 0;
// Tensor creation functions
virtual std::unique_ptr<AbstractTensorInterface> CreateInt64Tensor(
absl::Span<const int64> dim_sizes) = 0;
virtual std::unique_ptr<AbstractTensorInterface> CreateUint64Tensor(
absl::Span<const int64> dim_sizes) = 0;
virtual std::unique_ptr<AbstractTensorInterface> CreateInt32Tensor(
absl::Span<const int64> dim_sizes) = 0;
virtual std::unique_ptr<AbstractTensorInterface> CreateFloatTensor(
absl::Span<const int64> dim_sizes) = 0;
virtual std::unique_ptr<AbstractTensorInterface> CreateDoubleTensor(
absl::Span<const int64> dim_sizes) = 0;
virtual std::unique_ptr<AbstractTensorInterface> CreateHalfTensor(
absl::Span<const int64> dim_sizes) = 0;
virtual std::unique_ptr<AbstractTensorInterface> CreateStringTensor(
absl::Span<const int64> dim_sizes) = 0;
virtual std::unique_ptr<AbstractTensorInterface> CreateComplex128Tensor(
absl::Span<const int64> dim_sizes) = 0;
virtual std::unique_ptr<AbstractTensorInterface> CreateBoolTensor(
absl::Span<const int64> dim_sizes) = 0;
// Create a handle to wrap and manage a Tensor
virtual std::unique_ptr<AbstractTensorHandleInterface> CreateLocalHandle(
const std::unique_ptr<AbstractTensorInterface> t) = 0;
// Create an operation to perform op execution
virtual std::unique_ptr<AbstractOperationInterface> CreateOperation() = 0;
// List attributes of available devices
virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
};
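// Example (illustrative sketch, not a call made by this header): a caller
// holding an AbstractContextInterface* `ctx` could drive the virtual API
// above end to end, e.g.
//   std::unique_ptr<AbstractTensorInterface> t = ctx->CreateFloatScalar(2.0f);
//   std::unique_ptr<AbstractTensorHandleInterface> h =
//       ctx->CreateLocalHandle(std::move(t));
//   std::unique_ptr<AbstractOperationInterface> op = ctx->CreateOperation();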
// TODO(gjn): Try to move these all to EagerContext and make it implement
// AbstractContextInterface. Currently, this is not so straightforward because
// of various BUILD file dependencies.
class ContextInterface : public AbstractContextInterface {
public:
explicit ContextInterface(EagerContext* ctx) : ctx_(ctx) {}
~ContextInterface() override {}
std::unique_ptr<AbstractTensorInterface> CreateInt64Scalar(
int64 value) override;
std::unique_ptr<AbstractTensorInterface> CreateUint64Scalar(
uint64 value) override;
std::unique_ptr<AbstractTensorInterface> CreateInt32Scalar(
int32 value) override;
std::unique_ptr<AbstractTensorInterface> CreateFloatScalar(
float value) override;
std::unique_ptr<AbstractTensorInterface> CreateDoubleScalar(
double value) override;
std::unique_ptr<AbstractTensorInterface> CreateHalfScalar(
Eigen::half value) override;
std::unique_ptr<AbstractTensorInterface> CreateStringScalar(
tensorflow::tstring value) override;
std::unique_ptr<AbstractTensorInterface> CreateComplex128Scalar(
tensorflow::complex128 value) override;
std::unique_ptr<AbstractTensorInterface> CreateBoolScalar(
bool value) override;
std::unique_ptr<AbstractTensorInterface> CreateInt64Tensor(
absl::Span<const int64> dim_sizes) override;
std::unique_ptr<AbstractTensorInterface> CreateUint64Tensor(
absl::Span<const int64> dim_sizes) override;
std::unique_ptr<AbstractTensorInterface> CreateInt32Tensor(
absl::Span<const int64> dim_sizes) override;
std::unique_ptr<AbstractTensorInterface> CreateFloatTensor(
absl::Span<const int64> dim_sizes) override;
std::unique_ptr<AbstractTensorInterface> CreateDoubleTensor(
absl::Span<const int64> dim_sizes) override;
std::unique_ptr<AbstractTensorInterface> CreateHalfTensor(
absl::Span<const int64> dim_sizes) override;
std::unique_ptr<AbstractTensorInterface> CreateStringTensor(
absl::Span<const int64> dim_sizes) override;
std::unique_ptr<AbstractTensorInterface> CreateComplex128Tensor(
absl::Span<const int64> dim_sizes) override;
std::unique_ptr<AbstractTensorInterface> CreateBoolTensor(
absl::Span<const int64> dim_sizes) override;
std::unique_ptr<AbstractTensorHandleInterface> CreateLocalHandle(
const std::unique_ptr<AbstractTensorInterface> t) override;
std::unique_ptr<AbstractOperationInterface> CreateOperation() override;
void ListDevices(std::vector<DeviceAttributes>* devices) override;
// For runtime specific APIs, provide ability to get the underlying context.
EagerContext* Context() const { return ctx_; }
private:
EagerContext* ctx_;
};
inline EagerContext* ContextFromInterface(
const std::unique_ptr<AbstractContextInterface>& context) {
return down_cast<ContextInterface*>(context.get())->Context();
}
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_

View File

@ -20,7 +20,6 @@ limitations under the License.
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
@ -47,9 +46,7 @@ const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
return nullptr;
}
tensorflow::TensorHandle* handle =
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle();
tensorflow::TensorHandleFromInterface(h->handle);
if (handle->IsRemote()) {
status->status = tensorflow::errors::InvalidArgument(
"DLPack doesn't support remote tensor");
@ -289,9 +286,8 @@ void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
return static_cast<void*>(dlm_tensor);
}
TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status,
TFE_Context* ctx) {
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlm);
DLTensor* dl_tensor = &dlmt->dl_tensor;
absl::optional<std::string> device_name =

View File

@ -30,7 +30,8 @@ TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h,
// Converts DLPack (DLManagedTensor*) to eager tensor handle.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm,
TF_Status* status);
TF_Status* status,
TFE_Context* ctx);
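// Example call pattern for the updated signature (illustrative only; `dlm`,
// `status`, and `ctx` are placeholders for a DLManagedTensor*, a TF_Status*,
// and a caller-owned TFE_Context*):
//   TFE_TensorHandle* handle = TFE_HandleFromDLPack(dlm, status, ctx);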
// Calls the destructor of DLManagedTensor, used in the destructor of PyCapsule.
TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr);

View File

@ -26,8 +26,7 @@ limitations under the License.
namespace tensorflow {
OperationInterface::OperationInterface(TFE_Context* ctx)
: operation_(ctx->context) {}
OperationInterface::OperationInterface(EagerContext* ctx) : operation_(ctx) {}
const string& OperationInterface::DeviceName() const {
absl::variant<Device*, CustomDevice*> variant_device =
@ -99,9 +98,8 @@ Status OperationInterface::SetAttrFunction(
AttrValue attr_value;
NameAttrList* func = attr_value.mutable_func();
func->set_name(value->Name());
OperationInterface* value_operation =
tensorflow::down_cast<OperationInterface*>(value.get());
value_operation->operation_.Attrs().FillAttrValueMap(func->mutable_attr());
EagerOperation* value_operation = OperationFromInterface(value);
value_operation->Attrs().FillAttrValueMap(func->mutable_attr());
operation_.MutableAttrs()->Set(attr_name, attr_value);
return Status::OK();
}
@ -116,10 +114,9 @@ Status OperationInterface::SetAttrFunctionName(const char* attr_name,
return Status::OK();
}
Status OperationInterface::SetAttrTensor(const char* attr_name,
TF_Tensor* tensor) {
Tensor t;
TF_RETURN_IF_ERROR(TF_TensorToTensor(tensor, &t));
Status OperationInterface::SetAttrTensor(
const char* attr_name, std::unique_ptr<AbstractTensorInterface> tensor) {
Tensor t = TensorFromInterface(tensor);
operation_.MutableAttrs()->Set(attr_name, t);
return Status::OK();
}
@ -209,11 +206,10 @@ Status OperationInterface::SetAttrFunctionList(const char* attr_name,
int num_values) {
std::unique_ptr<NameAttrList[]> funcs(new NameAttrList[num_values]);
for (int i = 0; i < num_values; i++) {
auto value_operation =
tensorflow::down_cast<OperationInterface*>(value[i]->operation.get());
funcs[i].set_name(value_operation->operation_.Name());
value_operation->operation_.Attrs().FillAttrValueMap(
funcs[i].mutable_attr());
EagerOperation* value_operation =
OperationFromInterface(value[i]->operation);
funcs[i].set_name(value_operation->Name());
value_operation->Attrs().FillAttrValueMap(funcs[i].mutable_attr());
}
operation_.MutableAttrs()->Set(
attr_name, gtl::ArraySlice<const NameAttrList>(funcs.get(), num_values));
@ -267,8 +263,7 @@ Status OperationInterface::OutputLength(const char* output_name, int* length) {
Status OperationInterface::AddInput(
const std::unique_ptr<AbstractTensorHandleInterface>& input) {
TensorHandle* h =
tensorflow::down_cast<TensorHandleInterface*>(input.get())->Handle();
TensorHandle* h = TensorHandleFromInterface(input);
operation_.AddInput(h);
return operation_.MaybeInferSingleInputAttrs(h);
}
@ -277,8 +272,7 @@ Status OperationInterface::AddInputList(
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
inputs) {
for (auto& input : inputs) {
TensorHandle* h =
tensorflow::down_cast<TensorHandleInterface*>(input.get())->Handle();
TensorHandle* h = TensorHandleFromInterface(input);
operation_.AddInput(h);
}
return operation_.InferInputListAttrs(inputs.size());
@ -297,13 +291,6 @@ Status OperationInterface::Execute(
return Status::OK();
}
Status OperationInterface::SetCancellationManager(
TFE_CancellationManager* cancellation_manager) {
operation_.SetCancellationManager(
&cancellation_manager->cancellation_manager);
return Status::OK();
}
Status OperationInterface::SetUseXla(bool enable) {
operation_.SetUseXla(enable);
return Status::OK();

View File

@ -19,9 +19,14 @@ limitations under the License.
#include "absl/container/fixed_array.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/framework/tensor_interface.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
// Abstract interface to an operation.
class AbstractOperationInterface {
@ -29,90 +34,67 @@ class AbstractOperationInterface {
virtual ~AbstractOperationInterface() {}
virtual void Clear() = 0;
virtual tensorflow::Status Reset(const char* op,
const char* raw_device_name) = 0;
virtual Status Reset(const char* op, const char* raw_device_name) = 0;
virtual const tensorflow::string& Name() const = 0;
virtual const tensorflow::string& DeviceName() const = 0;
virtual tensorflow::Status SetDeviceName(const char* name) = 0;
virtual const string& Name() const = 0;
virtual const string& DeviceName() const = 0;
virtual Status SetDeviceName(const char* name) = 0;
virtual tensorflow::Status AddInput(
virtual Status AddInput(
const std::unique_ptr<AbstractTensorHandleInterface>& input) = 0;
virtual tensorflow::Status AddInputList(
virtual Status AddInputList(
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
inputs) = 0;
virtual tensorflow::Status Execute(
virtual Status Execute(
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
int* num_retvals) = 0;
virtual const tensorflow::OpDef* OpDef() const = 0;
virtual tensorflow::Status SetAttrString(const char* attr_name,
const char* data, size_t length) = 0;
virtual tensorflow::Status SetAttrInt(const char* attr_name,
int64_t value) = 0;
virtual tensorflow::Status SetAttrFloat(const char* attr_name,
float value) = 0;
virtual tensorflow::Status SetAttrBool(const char* attr_name, bool value) = 0;
virtual tensorflow::Status SetAttrType(const char* attr_name,
TF_DataType value) = 0;
virtual tensorflow::Status SetAttrShape(const char* attr_name,
const int64_t* dims,
const int num_dims) = 0;
virtual tensorflow::Status SetAttrFunction(
virtual Status SetAttrString(const char* attr_name, const char* data,
size_t length) = 0;
virtual Status SetAttrInt(const char* attr_name, int64_t value) = 0;
virtual Status SetAttrFloat(const char* attr_name, float value) = 0;
virtual Status SetAttrBool(const char* attr_name, bool value) = 0;
virtual Status SetAttrType(const char* attr_name, TF_DataType value) = 0;
virtual Status SetAttrShape(const char* attr_name, const int64_t* dims,
const int num_dims) = 0;
virtual Status SetAttrFunction(
const char* attr_name,
const std::unique_ptr<AbstractOperationInterface>& value) = 0;
virtual tensorflow::Status SetAttrFunctionName(const char* attr_name,
const char* value,
size_t length) = 0;
virtual tensorflow::Status SetAttrTensor(const char* attr_name,
TF_Tensor* tensor) = 0;
virtual tensorflow::Status SetAttrStringList(const char* attr_name,
const void* const* values,
const size_t* lengths,
int num_values) = 0;
virtual tensorflow::Status SetAttrFloatList(const char* attr_name,
const float* values,
int num_values) = 0;
virtual tensorflow::Status SetAttrIntList(const char* attr_name,
const int64_t* values,
int num_values) = 0;
virtual tensorflow::Status SetAttrTypeList(const char* attr_name,
const TF_DataType* values,
int num_values) = 0;
virtual tensorflow::Status SetAttrBoolList(const char* attr_name,
const unsigned char* values,
int num_values) = 0;
virtual tensorflow::Status SetAttrShapeList(const char* attr_name,
const int64_t** dims,
const int* num_dims,
int num_values) = 0;
virtual tensorflow::Status SetAttrFunctionList(const char* attr_name,
const TFE_Op** value,
int num_values) = 0;
virtual Status SetAttrFunctionName(const char* attr_name, const char* value,
size_t length) = 0;
virtual Status SetAttrTensor(
const char* attr_name,
std::unique_ptr<AbstractTensorInterface> tensor) = 0;
virtual Status SetAttrStringList(const char* attr_name,
const void* const* values,
const size_t* lengths, int num_values) = 0;
virtual Status SetAttrFloatList(const char* attr_name, const float* values,
int num_values) = 0;
virtual Status SetAttrIntList(const char* attr_name, const int64_t* values,
int num_values) = 0;
virtual Status SetAttrTypeList(const char* attr_name,
const TF_DataType* values, int num_values) = 0;
virtual Status SetAttrBoolList(const char* attr_name,
const unsigned char* values,
int num_values) = 0;
virtual Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
const int* num_dims, int num_values) = 0;
virtual Status SetAttrFunctionList(const char* attr_name,
const TFE_Op** value, int num_values) = 0;
virtual tensorflow::Status InputLength(const char* input_name,
int* length) = 0;
virtual tensorflow::Status OutputLength(const char* output_name,
int* length) = 0;
virtual Status InputLength(const char* input_name, int* length) = 0;
virtual Status OutputLength(const char* output_name, int* length) = 0;
// Experimental
virtual tensorflow::Status SetUseXla(bool enable) {
return tensorflow::errors::Unimplemented("SetUseXla not implemented");
}
virtual tensorflow::Status SetCancellationManager(
TFE_CancellationManager* cancellation_manager) {
return tensorflow::errors::Unimplemented(
"SetCancellationManager not implemented");
}
virtual Status SetUseXla(bool enable) = 0;
};
namespace tensorflow {
class OpDef;
class OperationInterface : public AbstractOperationInterface {
public:
explicit OperationInterface(TFE_Context* ctx);
explicit OperationInterface(EagerContext* ctx);
~OperationInterface() override{};
void Clear() override { operation_.Clear(); }
@ -149,7 +131,9 @@ class OperationInterface : public AbstractOperationInterface {
const std::unique_ptr<AbstractOperationInterface>& value) override;
Status SetAttrFunctionName(const char* attr_name, const char* data,
size_t length) override;
Status SetAttrTensor(const char* attr_name, TF_Tensor* tensor) override;
Status SetAttrTensor(
const char* attr_name,
std::unique_ptr<AbstractTensorInterface> tensor) override;
Status SetAttrStringList(const char* attr_name, const void* const* values,
const size_t* lengths, int num_values) override;
Status SetAttrFloatList(const char* attr_name, const float* values,
@ -169,20 +153,25 @@ class OperationInterface : public AbstractOperationInterface {
Status OutputLength(const char* output_name, int* length) override;
Status SetUseXla(bool enable) override;
Status SetCancellationManager(
TFE_CancellationManager* cancellation_manager) override;
// TODO(gjn): Remove once TFE_InferShapes is removed
const tensorflow::AttrBuilder& Attrs() const { return operation_.Attrs(); }
tensorflow::AttrBuilder* MutableAttrs() { return operation_.MutableAttrs(); }
const AttrBuilder& Attrs() const { return operation_.Attrs(); }
AttrBuilder* MutableAttrs() { return operation_.MutableAttrs(); }
const TensorHandle* GetInput(int i) const { return operation_.Inputs()[i]; }
EagerOperation* Operation() { return &operation_; }
private:
const tensorflow::OpDef* GetOpDef(Status* status);
EagerOperation operation_;
};
inline EagerOperation* OperationFromInterface(
const std::unique_ptr<AbstractOperationInterface>& operation) {
return down_cast<OperationInterface*>(operation.get())->Operation();
}
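// Example (illustrative sketch): call sites unwrap the interface through this
// helper instead of down_cast'ing by hand, e.g.
//   EagerOperation* op = OperationFromInterface(op_interface);
//   op->Attrs().FillAttrValueMap(func->mutable_attr());
// where `op_interface` and `func` are placeholders for a caller-owned
// std::unique_ptr<AbstractOperationInterface> and a NameAttrList*.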
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_

View File

@ -15,10 +15,13 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
#define TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/tensor_interface.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
// Abstract interface to a TensorHandle.
//
@ -34,24 +37,22 @@ class AbstractTensorHandleInterface {
virtual ~AbstractTensorHandleInterface() {}
// Check if the handle is in a valid initialized state.
virtual bool IsValid(tensorflow::Status* status) const = 0;
virtual bool IsValid(Status* status) const = 0;
// Returns tensor dtype.
virtual TF_DataType DataType() const = 0;
// Returns number of dimensions.
virtual int NumDims(tensorflow::Status* status) const = 0;
virtual int NumDims(Status* status) const = 0;
// Returns number of elements across all dimensions.
virtual int64_t NumElements(tensorflow::Status* status) const = 0;
virtual int64_t NumElements(Status* status) const = 0;
// Returns size of specified dimension
virtual int64_t Dim(int dim_index, tensorflow::Status* status) const = 0;
virtual int64_t Dim(int dim_index, Status* status) const = 0;
// Returns the device which created the handle.
virtual const char* DeviceName(tensorflow::Status* status) const = 0;
virtual const char* DeviceName(Status* status) const = 0;
// Returns the device where the tensor was placed.
virtual const char* BackingDeviceName(tensorflow::Status* status) const = 0;
virtual const char* BackingDeviceName(Status* status) const = 0;
// Returns a tensor for the handle. If tensor is remote, it will be copied.
virtual TF_Tensor* Resolve(tensorflow::Status* status) = 0;
// Returns debug information about the tensor.
virtual TFE_TensorDebugInfo* TensorDebugInfo(tensorflow::Status* status) = 0;
virtual std::unique_ptr<AbstractTensorInterface> Resolve(Status* status) = 0;
// Return a copy of the handle.
virtual AbstractTensorHandleInterface* Copy() = 0;
@ -65,8 +66,9 @@ class AbstractTensorHandleInterface {
virtual void EnableImplicitMirroring() = 0;
};
namespace tensorflow {
// TODO(gjn): Try to move these all to TensorHandle and make it implement
// AbstractTensorHandleInterface. Currently, this is not so straightforward
// because of various BUILD file dependencies.
class TensorHandleInterface : public AbstractTensorHandleInterface {
public:
explicit TensorHandleInterface(TensorHandle* h) : handle_(h) {}
@ -80,21 +82,24 @@ class TensorHandleInterface : public AbstractTensorHandleInterface {
const char* DeviceName(Status* status) const override;
const char* BackingDeviceName(Status* status) const override;
TF_Tensor* Resolve(Status* status) override;
TFE_TensorDebugInfo* TensorDebugInfo(Status* status) override;
std::unique_ptr<AbstractTensorInterface> Resolve(Status* status) override;
AbstractTensorHandleInterface* Copy() override;
void EnableImplicitMirroring() override;
// TODO(gjn): This is not a very generic interface, but is needed for specific
// use cases.
// For runtime specific APIs, provide ability to get the underlying handle.
TensorHandle* Handle() { return handle_; }
private:
TensorHandle* handle_;
};
inline TensorHandle* TensorHandleFromInterface(
const std::unique_ptr<AbstractTensorHandleInterface>& handle) {
return down_cast<TensorHandleInterface*>(handle.get())->Handle();
}
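// Example (illustrative sketch): callers unwrap a handle the same way, e.g.
//   TensorHandle* h = TensorHandleFromInterface(handle_interface);
// where `handle_interface` is a placeholder for a caller-owned
// std::unique_ptr<AbstractTensorHandleInterface>.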
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_

View File

@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/error.h"
#include "tensorflow/core/platform/status.h"
using ::tensorflow::IOError;
using ::tensorflow::Status;

View File

@ -17,7 +17,7 @@ limitations under the License.
#define TENSORFLOW_C_TF_STATUS_HELPER_H_
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {

View File

@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {

View File

@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_C_TF_STATUS_INTERNAL_H_
#define TENSORFLOW_C_TF_STATUS_INTERNAL_H_
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/status.h"
// Internal structures used by the status C API. These are likely to change
// and should not be depended on.

View File

@ -381,7 +381,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
->ToTensor(dst);
}
Status TensorInterface::ToTensor(Tensor* dst) const {
Status TensorInterface::ToTensor(tensorflow::Tensor* dst) const {
if (tensor_.dtype() == DT_RESOURCE) {
if (tensor_.dims() != 0) {
return InvalidArgument(
@ -389,7 +389,7 @@ Status TensorInterface::ToTensor(Tensor* dst) const {
"shape ",
tensor_.shape().DebugString());
}
*dst = Tensor(tensorflow::DT_RESOURCE, tensor_.shape());
*dst = tensorflow::Tensor(tensorflow::DT_RESOURCE, tensor_.shape());
if (!dst->scalar<tensorflow::ResourceHandle>()().ParseFromString(
string(static_cast<const char*>(Data()), ByteSize()))) {
return InvalidArgument(
@ -414,7 +414,7 @@ Status TensorInterface::ToTensor(Tensor* dst) const {
const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
const char* limit = input + src_size;
*dst = Tensor(tensor_.dtype(), tensor_.shape());
*dst = tensorflow::Tensor(tensor_.dtype(), tensor_.shape());
auto dstarray = dst->flat<tstring>();
for (tensorflow::int64 i = 0; i < num_elements; ++i) {
tensorflow::uint64 offset =

View File

@ -31,7 +31,7 @@ limitations under the License.
// passed to or returned from C functions *by pointer*. Otherwise, changes to
// its internal structure will break the C API's binary interface.
typedef struct TF_Tensor {
std::unique_ptr<AbstractTensorInterface> tensor;
std::unique_ptr<tensorflow::AbstractTensorInterface> tensor;
} TF_Tensor;
class TF_ManagedBuffer : public tensorflow::TensorBuffer {

View File

@ -124,7 +124,7 @@ Status LoadSavedModel(const SessionOptions& session_options,
/// the export directory definitely does not contain a SavedModel. If the method
/// returns `true`, the export directory may contain a SavedModel but provides
/// no guarantee that it can be loaded.
bool MaybeSavedModelDirectory(const string& export_dir);
bool MaybeSavedModelDirectory(const std::string& export_dir);
} // namespace tensorflow

View File

@ -115,11 +115,10 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/core:core_cpu",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
],
alwayslink = 1,
)

View File

@ -12,7 +12,7 @@ load("@bazel_skylib//lib:paths.bzl", "paths")
_default_test_file_exts = ["mlir", ".pbtxt", ".td"]
_default_driver = "@llvm-project//mlir:run_lit.sh"
_default_size = "small"
_default_tags = ["no_rocm"]
_default_tags = []
# These are patterns which we should never match, for tests, subdirectories, or
# test input data files.

View File

@ -17,8 +17,7 @@ package_group(
name = "friends",
includes = ["//third_party/mlir:subpackages"],
packages = [
"//learning/brain/experimental/mlir/...",
"//learning/brain/google/xla/...",
"//learning/brain/mlir/...",
"//tensorflow/compiler/mlir/...",
],
)
@ -205,8 +204,6 @@ cc_library(
cc_library(
name = "tensorflow_lite",
srcs = [
"experimental/estimators/estimator.h",
"experimental/estimators/gpu_estimator.h.inc",
"ir/tfl_ops.cc",
"ir/tfl_ops.cc.inc",
"ir/tfl_ops.h.inc",
@ -216,7 +213,6 @@ cc_library(
"utils/attribute_utils.cc",
],
hdrs = [
"experimental/estimators/hardware.h",
"ir/tfl_ops.h",
"transforms/passes.h",
"utils/attribute_utils.h",
@ -226,6 +222,7 @@ cc_library(
deps = [
":tensorflow_lite_ops_inc_gen",
":validators",
"//tensorflow/compiler/mlir/lite/experimental/estimators:cost_estimators",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/lite/schema:schema_fbs",
"@llvm-project//llvm:support",

View File

@ -0,0 +1,15 @@
package(
default_visibility = [
"//visibility:public",
],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "cost_estimators",
textual_hdrs = [
"estimator.h",
"gpu_estimators.h",
"hardware.h",
],
)

View File

@ -1,72 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATOR_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATOR_H_
// tfl.average_pool_2d
template <>
class TFLiteCostEstimator<AveragePool2DOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.conv_2d
template <>
class TFLiteCostEstimator<Conv2DOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
// TODO(renjieliu): We probably need to check for dynamic weights.
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.depthwise_conv_2d
template <>
class TFLiteCostEstimator<DepthwiseConv2DOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.max_pool_2d
template <>
class TFLiteCostEstimator<MaxPool2DOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATOR_H_

View File

@ -0,0 +1,231 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_
// tfl.add
template <>
class TFLiteCostEstimator<AddOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.average_pool_2d
template <>
class TFLiteCostEstimator<AveragePool2DOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.concatenation
template <>
class TFLiteCostEstimator<ConcatenationOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
// TODO(renjieliu): We probably need to check for dynamic weights.
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.conv_2d
template <>
class TFLiteCostEstimator<Conv2DOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
// TODO(renjieliu): We probably need to check for dynamic weights.
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.depthwise_conv_2d
template <>
class TFLiteCostEstimator<DepthwiseConv2DOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.fully_connected
template <>
class TFLiteCostEstimator<FullyConnectedOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
// TODO(renjieliu): we need to check for dynamic weights.
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.logistic
template <>
class TFLiteCostEstimator<LogisticOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.max_pool_2d
template <>
class TFLiteCostEstimator<MaxPool2DOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.mirror_pad
template <>
class TFLiteCostEstimator<MirrorPadOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.maximum
template <>
class TFLiteCostEstimator<MaximumOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.minimum
template <>
class TFLiteCostEstimator<MinimumOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.mul
template <>
class TFLiteCostEstimator<MulOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.relu
template <>
class TFLiteCostEstimator<ReluOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.relu6
template <>
class TFLiteCostEstimator<Relu6Op, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.reshape
template <>
class TFLiteCostEstimator<ReshapeOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.softmax
template <>
class TFLiteCostEstimator<SoftmaxOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
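// Example (illustrative sketch): a pass could consult these specializations as
//   bool supported = TFLiteCostEstimator<AddOp, hardware::GPU>::IsSupported(op);
//   double cost = TFLiteCostEstimator<AddOp, hardware::GPU>::GetCost(op);
// where `op` is a placeholder for an mlir::Operation* of the matching kind.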
#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_

View File

@ -54,7 +54,7 @@ class TensorFlowLiteDialect : public Dialect {
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc"
// Include all specialized estimators below this line
#include "tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimator.h.inc"
#include "tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h"
} // end namespace TFL
} // end namespace mlir

View File

@ -139,8 +139,7 @@ def TFL_Uint8 : UI<8>;
def TFL_Int32Or64 : SignlessIntOfWidths<[32, 64]>;
def TFL_BoolTensor : TFL_TensorOf<[I1]>;
def TFL_FpOrI32OrI64Tensor : TFL_TensorOf<[AnyFloat, TFL_Int32Or64]>;
def TFL_FpTensor : TFL_TensorOf<[AnyFloat]>;
def TFL_FpTensor : TFL_TensorOf<[F32]>;
def TFL_I32OrI64Tensor : TFL_TensorOf<[TFL_Int32Or64]>;
def TFL_I32Tensor : TFL_TensorOf<[I32]>;
def TFL_I64Tensor : TFL_TensorOf<[I64]>;
@ -324,9 +323,9 @@ class TFL_ConvOp<string mnemonic, string opSummary, int index> :
}];
let arguments = (
ins AnyTensor:$input,
AnyTensor:$filter,
TFL_TensorOfOrNone<[AnyType]>:$bias,
ins TFL_TensorOf<[F32, QI8, QUI8]>:$input,
TFL_TensorOf<[F32, QI8, QUI8]>:$filter,
TFL_TensorOfOrNone<[F32, I32]>:$bias,
I32Attr:$dilation_h_factor,
I32Attr:$dilation_w_factor,
TFL_AFAttr:$fused_activation_function,
@ -335,7 +334,7 @@ class TFL_ConvOp<string mnemonic, string opSummary, int index> :
I32Attr:$stride_w
);
let results = (outs AnyTensor:$output);
let results = (outs TFL_TensorOf<[F32, QI8, QUI8]>:$output);
let hasOptions = 0b1;
}
@ -361,7 +360,10 @@ an output element, this operation computes \\(y = |x|\\).
let hasFolder = 1;
}
def TFL_AddOp : TFL_Op<"add", [ResultsBroadcastableShape, NoSideEffect, Commutative]> {
def TFL_AddOp : TFL_Op<"add", [ResultsBroadcastableShape,
NoSideEffect,
Commutative,
TFL_GpuTargetOp]> {
let summary = "Addition operator";
let description = [{
@ -394,11 +396,11 @@ def TFL_AddNOp : TFL_Op<"add_n", [Commutative, NoSideEffect, SameOperandsAndResu
}];
let arguments = (ins
TFL_VariadicTensorOf<[F32, I32, QI16, QUI16]>:$inputs
TFL_VariadicTensorOf<[F32, I32]>:$inputs
);
let results = (outs
TFL_TensorOf<[F32, I32, QI16, QUI16]>:$sum
TFL_TensorOf<[F32, I32]>:$sum
);
}
@ -492,7 +494,7 @@ def TFL_AveragePool2DOp:
}];
let arguments = (
ins AnyTensor:$input,
ins TFL_TensorOf<[F32, QI8, QUI8]>:$input,
I32Attr:$filter_height,
I32Attr:$filter_width,
TFL_PaddingAttr:$padding,
@ -501,7 +503,7 @@ def TFL_AveragePool2DOp:
TFL_AFAttr:$fused_activation_function
);
let results = (outs AnyTensor:$output);
let results = (outs TFL_TensorOf<[F32, QI8, QUI8]>:$output);
let hasOptions = 1;
let customOption = "Pool2DOptions";
@ -577,7 +579,8 @@ def TFL_ConcatenationOp : TFL_Op<"concatenation",
NoSideEffect,
PredOpTrait<"values and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
SameOperandsAndResultsScale
SameOperandsAndResultsScale,
TFL_GpuTargetOp
]> {
let summary = "Concatenation operator";
@ -719,7 +722,18 @@ def TFL_CosOp: TFL_Op<"cos", [
def TFL_DepthwiseConv2DOp :
TFL_ConvOp<"depthwise_conv_2d", "Depthwise-separable convolution", 3> {
let arguments = !con(TFL_Conv2DOp.arguments, (ins I32Attr:$depth_multiplier));
let arguments = (
ins TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$input,
TFL_TensorOf<[F32, QI8, QUI8]>:$filter,
TFL_TensorOfOrNone<[F32, I32, I64]>:$bias,
I32Attr:$dilation_h_factor,
I32Attr:$dilation_w_factor,
TFL_AFAttr:$fused_activation_function,
TFL_PaddingAttr:$padding,
I32Attr:$stride_h,
I32Attr:$stride_w,
I32Attr:$depth_multiplier
);
let extraClassDeclaration = [{
// ChannelDimIndexInterface:
@ -741,7 +755,8 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
TFL_ChannelDimIndexInterface,
AffineOpCoefficient<-1, 1>,
TFL_SparseOp]> {
TFL_SparseOp,
TFL_GpuTargetOp]> {
let summary = "Fully connected op";
let arguments = (ins
@ -1091,6 +1106,8 @@ def TFL_NotEqualOp : TFL_Op<"not_equal", [
}
def TFL_DivOp : TFL_Op<"div", [
// TODO(fengliuai): NoQuantizableResult is only correct for int8
// quantization. Update to handle Uint8 quantization.
ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
let summary = "Division operator";
@ -1099,11 +1116,11 @@ def TFL_DivOp : TFL_Op<"div", [
}];
let arguments = (
ins AnyTensor:$lhs,
AnyTensor:$rhs,
ins TFL_TensorOf<[F32, I32, QUI8]>:$lhs,
TFL_TensorOf<[F32, I32, TFL_Uint8]>:$rhs,
TFL_AFAttr:$fused_activation_function);
let results = (outs AnyTensor:$output);
let results = (outs TFL_TensorOf<[F32, I32, TFL_Uint8]>:$output);
let builders = [TFL_FusedBroadcastableBinaryBuilder];
@ -1126,7 +1143,7 @@ def TFL_EluOp: TFL_Op<"elu", [NoSideEffect, SameOperandsAndResultType]> {
let arguments = (ins TFL_FpTensor:$x);
let results = (outs AnyTensor:$y);
let results = (outs TFL_FpTensor:$y);
let hasOptions = 0;
}
@ -1487,16 +1504,17 @@ def TFL_LogisticOp: TFL_Op<"logistic", [
// zero_point = 0
// scale = 1. / (max_value + 1)
FixedResultScale<Int8UniformQuantizedType<-128, 390625, -8>>,
FixedResultScale<UInt8UniformQuantizedType<0, 390625, -8>>]> {
FixedResultScale<UInt8UniformQuantizedType<0, 390625, -8>>,
TFL_GpuTargetOp]> {
let summary = "Logistic operator";
let description = [{
Computes element-wise Sigmoid of input
}];
let arguments = (ins TFL_TensorOf<[AnyFloat, QI8, QUI8, QI16, QUI16]>:$x);
let arguments = (ins TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$x);
let results = (outs TFL_TensorOf<[AnyFloat, QI8, QUI8, QI16, QUI16]>:$y);
let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$y);
}
def TFL_LogOp: TFL_Op<"log", [
@ -1639,19 +1657,23 @@ def TFL_MaxUnpooling2DOp :
}
def TFL_MaximumOp : TFL_Op<"maximum", [
ResultsBroadcastableShape, NoSideEffect, Commutative, SameOperandsAndResultsScale]> {
ResultsBroadcastableShape,
NoSideEffect,
Commutative,
SameOperandsAndResultsScale,
TFL_GpuTargetOp]> {
let summary = "Max operator";
let description = [{
Element-wise max operation.
}];
let arguments = (
ins TFL_TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$lhs,
TFL_TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$rhs
ins TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$lhs,
TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$rhs
);
let results = (outs
TFL_TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$max
TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$max
);
let builders = [TFL_BroadcastableBinaryBuilder];
@ -1837,19 +1859,23 @@ def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [NoSideEffect]> {
}
def TFL_MinimumOp : TFL_Op<"minimum", [
ResultsBroadcastableShape, NoSideEffect, Commutative, SameOperandsAndResultsScale]> {
ResultsBroadcastableShape,
NoSideEffect,
Commutative,
SameOperandsAndResultsScale,
TFL_GpuTargetOp]> {
let summary = "Min operator";
let description = [{
Element-wise min operation.
}];
let arguments = (
ins TFL_TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$lhs,
TFL_TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$rhs
ins TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$lhs,
TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$rhs
);
let results = (outs
TFL_TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$min
TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$min
);
let builders = [TFL_BroadcastableBinaryBuilder];
@ -1857,7 +1883,10 @@ def TFL_MinimumOp : TFL_Op<"minimum", [
let hasOptions = 0;
}
def TFL_MulOp : TFL_Op<"mul", [ResultsBroadcastableShape, NoSideEffect, Commutative]> {
def TFL_MulOp : TFL_Op<"mul", [ResultsBroadcastableShape,
NoSideEffect,
Commutative,
TFL_GpuTargetOp]> {
let summary = "Multiplication operator";
let description = [{
@ -2090,7 +2119,8 @@ def TFL_RankOp: TFL_Op<"rank", [NoSideEffect]> {
def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultsScale]> {
SameOperandsAndResultsScale,
TFL_GpuTargetOp]> {
let summary = "Relu operator";
let description = [{
@ -2105,7 +2135,8 @@ def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect,
def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultsScale]> {
SameOperandsAndResultsScale,
TFL_GpuTargetOp]> {
let summary = "Relu6 operator";
let description = [{
@ -2134,7 +2165,7 @@ def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [NoSideEffect,
}
def TFL_ReshapeOp: TFL_Op<"reshape", [
NoSideEffect, SameOperandsAndResultsScale]> {
NoSideEffect, SameOperandsAndResultsScale, TFL_GpuTargetOp]> {
let summary = "Reshape operator";
let description = [{
@ -2351,7 +2382,8 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [
// zero_point = 0
// scale = 1. / (max_value + 1)
FixedResultScale<Int8UniformQuantizedType<-128, 390625, -8>>,
FixedResultScale<UInt8UniformQuantizedType<0, 390625, -8>>]> {
FixedResultScale<UInt8UniformQuantizedType<0, 390625, -8>>,
TFL_GpuTargetOp]> {
let summary = "Softmax operator";
let description = [{
@ -2882,7 +2914,7 @@ def TFL_CastOp : TFL_Op<"cast", [
TFL_TensorOf<[F32, I1, I32, I64, TFL_Quint8, TFL_Uint8, Complex<F<32>>]>:$input
);
let results = (outs TFL_TensorOf<[F32, I1, I32, I64, Complex<F<32>>]>:$output);
let results = (outs TFL_TensorOf<[F32, I1, I32, I64, TFL_Quint8, TFL_Uint8, Complex<F<32>>]>:$output);
// TFLite's cast op does not utilize CastOptions, instead derives types
// from the TfLiteTensors.
@ -2891,7 +2923,7 @@ def TFL_CastOp : TFL_Op<"cast", [
def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [
NoSideEffect, TFL_OperandHasRank<1, 2>]> {
NoSideEffect, TFL_OperandHasRank<1, 2>, TFL_GpuTargetOp]> {
let summary = "MirrorPad Operator. Pads a tensor with mirrored values.";
let description = [{
@ -3400,7 +3432,7 @@ def TFL_BidirectionalSequenceLSTMOp :
let summary = "Bidirectional sequence lstm operator";
let description = [{
Bidirectional lstm is essentiallay two lstms, one running forward & the
Bidirectional lstm is essentially two lstms, one running forward & the
other running backward. And the output is the concatenation of the two
lstms.
}];

View File

@ -111,3 +111,31 @@ tf_native_cc_binary(
"@llvm-project//mlir:TableGen",
],
)
cc_library(
name = "device_target",
srcs = ["device_target.cc"],
hdrs = ["device_target.h"],
deps = [
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:Support",
],
)
cc_library(
name = "quantization_context",
srcs = ["quantization_context.cc"],
hdrs = ["quantization_context.h"],
deps = [
":device_target",
":quantization_lib",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
)

View File

@ -0,0 +1,82 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/quantization/device_target.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
namespace mlir {
namespace quant {
constexpr int k8Bits = 8;
constexpr unsigned kSigned = quant::QuantizationFlags::Signed;
DeviceTarget::DeviceTarget(MLIRContext* ctx) : ctx_(ctx) {
f32_ = FloatType::getF32(ctx_);
i8_ = IntegerType::get(k8Bits, ctx_);
i8_min_ = QuantizedType::getDefaultMinimumForInteger(kSigned, k8Bits);
i8_max_ = QuantizedType::getDefaultMaximumForInteger(kSigned, k8Bits);
any_ = AnyQuantizedType();
qi8_ = AnyQuantizedType::get(kSigned, i8_, f32_, i8_min_, i8_max_);
qi8n_ = AnyQuantizedType::get(kSigned, i8_, f32_, i8_min_ + 1, i8_max_);
assert(qi8n_ == qi8n_);
}
Optional<KernelSpec> DeviceTarget::Get(QuantizeRegionOp op) const {
auto kernel_specs_it = specs_.find(op.logical_kernel());
if (kernel_specs_it == specs_.end()) return llvm::None;
KernelSpecs::Signature signature;
signature.reserve(op.input_specs().size() + op.output_specs().size());
AppendToSignature(op.input_specs(), &signature);
AppendToSignature(op.output_specs(), &signature);
return kernel_specs_it->getValue().Find(signature);
}
LogicalResult DeviceTarget::RegisterKernel(
llvm::StringRef kernel, const KernelSpecs::Signature& signature,
const ScaleFn& fn) {
return specs_[kernel].Add(signature, {ScaleConstraintType::CustomScale, fn});
}
LogicalResult DeviceTarget::RegisterKernel(
llvm::StringRef kernel, const KernelSpecs::Signature& signature,
const ScaleConstraintType constraint) {
return specs_[kernel].Add(signature, {constraint, {}});
}
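// Example (illustrative sketch): a device-specific subclass might register a
// same-scale constraint for a hypothetical "generic.add" kernel whose
// signature is three signed 8-bit quantized types:
//   KernelSpecs::Signature signature{qi8_, qi8_, qi8_};
//   (void)RegisterKernel("generic.add", signature,
//                        ScaleConstraintType::OutputInputSameScale);
// Both the kernel name and the signature here are placeholders.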
void DeviceTarget::AppendToSignature(ArrayAttr specs_attr,
KernelSpecs::Signature* signature) const {
for (auto attr : specs_attr) {
Type spec = attr.cast<TypeAttr>().getValue();
if (auto quant = spec.dyn_cast<UniformQuantizedType>()) {
signature->push_back(AnyQuantizedType::get(
quant.getFlags(), quant.getStorageType(), quant.getExpressedType(),
quant.getStorageTypeMin(), quant.getStorageTypeMax()));
} else if (auto any = spec.dyn_cast<AnyQuantizedType>()) {
signature->push_back(any);
} else { // float
signature->push_back({});
}
}
}
} // namespace quant
} // namespace mlir

View File

@ -0,0 +1,147 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_DEVICE_TARGET_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_DEVICE_TARGET_H_
#include <functional>
#include <ostream>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
namespace mlir {
namespace quant {
class QuantizeContext;
using AdjacentOperations = llvm::SmallVectorImpl<Operation*>;
using ScaleFn = std::function<LogicalResult(QuantizeContext*, Operation*,
AdjacentOperations*, bool*)>;
enum class ScaleConstraintType {
OutputInputSameScale,
OutputInputFreeScale,
CustomScale,
};
// Each kernel signature has its own specification for scales.
struct KernelSpec {
// Scale constraint
ScaleConstraintType type;
// Custom function to derive the scales. Only available when the scale
// constraint is `CustomScale`.
ScaleFn scale_fn;
};
class KernelSpecs {
public:
using Signature = llvm::SmallVector<quant::AnyQuantizedType, 4>;
// Returns the kernel specification for the kernel signature.
Optional<KernelSpec> Find(const Signature& signature) const {
auto spec_it = all_signatures_.find(signature);
if (spec_it != all_signatures_.end()) {
return spec_it->second;
} else {
return llvm::None;
}
}
// Adds the kernel signature with the kernel specification.
LogicalResult Add(const Signature& signature, const KernelSpec& spec) {
if (all_signatures_.insert({signature, spec}).second) return success();
return failure();
}
private:
// The signature is pattern match based.
struct SignatureInfo : public llvm::DenseMapInfo<Signature> {
static inline Signature getEmptyKey() { return {}; }
static inline Signature getTombstoneKey() { return {nullptr}; }
static unsigned getHashValue(Signature val) {
return llvm::hash_combine_range(val.begin(), val.end());
}
static bool isEqual(Signature LHS, Signature RHS) {
if (RHS == getEmptyKey()) return LHS == getEmptyKey();
if (RHS == getTombstoneKey()) return LHS == getTombstoneKey();
if (LHS.size() != RHS.size()) return false;
for (auto arg : llvm::zip(LHS, RHS)) {
if (std::get<0>(arg) != std::get<1>(arg)) return false;
}
return true;
}
};
// Maps the signature to the kernel spec. Note that the matching is
// pattern match based.
llvm::DenseMap<Signature, KernelSpec, SignatureInfo> all_signatures_;
};
class DeviceTarget {
public:
explicit DeviceTarget(MLIRContext* ctx);
// Retrieves the kernel spec for the quant region op.
Optional<KernelSpec> Get(quant::QuantizeRegionOp op) const;
protected:
// Adds the kernel spec with the custom scale function for the kernel.
LogicalResult RegisterKernel(llvm::StringRef kernel,
const KernelSpecs::Signature& signature,
const ScaleFn& fn);
// Adds the kernel spec with the scale constraint type for the kernel.
LogicalResult RegisterKernel(llvm::StringRef kernel,
const KernelSpecs::Signature& signature,
const ScaleConstraintType constraint);
// Converts the specification to a signature:
// - UniformQuantizedType -> AnyQuantizedType
// - AnyQuantizedType (int) -> AnyQuantizedType
// - Float -> {}
void AppendToSignature(ArrayAttr specs_attr,
KernelSpecs::Signature* signature) const;
// A set of parameters is required to build the signatures.
FloatType f32_;
IntegerType i8_;
int64_t i8_min_, i8_max_;
AnyQuantizedType any_, qi8_, qi8n_;
private:
// Maps the kernel names to all the available kernels.
llvm::StringMap<KernelSpecs> specs_;
// Points to the global MLIRContext.
MLIRContext* ctx_;
};
} // namespace quant
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_DEVICE_TARGET_H_

View File

@ -34,11 +34,11 @@ limitations under the License.
namespace mlir {
namespace lite {
// TODO(fengliuai): check the result for `allow_float` flag.
// TODO(fengliuai): check the result for `fully_quantize` flag.
TfLiteStatus QuantizeModel(
const tflite::ModelT& input_model, const tflite::TensorType& input_type,
const tflite::TensorType& output_type,
const std::unordered_set<std::string>& operator_names, bool allow_float,
const std::unordered_set<std::string>& operator_names, bool fully_quantize,
flatbuffers::FlatBufferBuilder* builder,
tflite::ErrorReporter* error_reporter) {
// TODO(b/142502494): remove this restriction by improving the `emit_adaptor`

View File

@ -27,11 +27,11 @@ namespace lite {
// Quantize the `input_model` and write the result to a flatbuffer `builder`.
// The `input_type` and `output_type` can be float32/qint8/int8.
// Return partially quantized model if `allow_float` is true.
// Return partially quantized model if `fully_quantize` is false.
TfLiteStatus QuantizeModel(
const tflite::ModelT& input_model, const tflite::TensorType& input_type,
const tflite::TensorType& output_type,
const std::unordered_set<std::string>& operator_names, bool allow_float,
const std::unordered_set<std::string>& operator_names, bool fully_quantize,
flatbuffers::FlatBufferBuilder* builder,
tflite::ErrorReporter* error_reporter);
} // namespace lite

View File

@ -47,7 +47,7 @@ TfLiteStatus QuantizeAnnotatedModel(llvm::StringRef buffer,
tflite::StderrReporter error_reporter;
return mlir::lite::QuantizeModel(
*model, tflite::TensorType_INT8, tflite::TensorType_INT8, {},
/*allow_float=*/false, builder, &error_reporter);
/*fully_quantize=*/true, builder, &error_reporter);
}
} // namespace

View File

@ -0,0 +1,239 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/quantization/quantization_context.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/device_target.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#define DEBUG_TYPE "quantization-context"
namespace mlir {
namespace quant {
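// Seeds one state per operand and result of every quant.region op in the
// function; values shared across ops are cached so they map to a single state.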
QuantizeContext::QuantizeContext(FuncOp func, const DeviceTarget &spec)
: func_(func), target_spec_(spec) {
llvm::DenseMap<Value, int> value_to_state;
func.walk([&](quant::QuantizeRegionOp op) {
for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
states_manager_.InitializeOperandState(op, i, &value_to_state);
}
for (int res = 0, e = op.getNumResults(); res != e; ++res) {
states_manager_.InitializeResultState(op, res, &value_to_state);
}
});
}
std::vector<quant::QuantizeRegionOp> QuantizeContext::GetAllOps() {
  std::vector<quant::QuantizeRegionOp> all_ops;
  func_.walk([&](quant::QuantizeRegionOp op) { all_ops.push_back(op); });
  return all_ops;
}
LogicalResult QuantizeContext::Handle(
quant::QuantizeRegionOp op, llvm::SmallVectorImpl<Operation *> *new_items,
bool *changed) {
auto spec = target_spec_.Get(op);
if (!spec.hasValue()) {
op.emitWarning(
"Couldn't find kernel from the registeration for quantization.");
return success();
}
switch (spec->type) {
case ScaleConstraintType::OutputInputFreeScale: {
// no propagation.
*changed = false;
break;
}
case ScaleConstraintType::CustomScale: {
if (failed(spec->scale_fn(this, op, new_items, changed))) {
return failure();
}
break;
}
default: {
llvm_unreachable("no implementation.");
return failure();
}
}
return success();
}
LogicalResult QuantizeContext::Finalize() {
MLIRContext *context = func_.getContext();
func_.walk([&](quant::QuantizeRegionOp op) {
llvm::SmallVector<Attribute, 4> input_specs;
auto original_input_specs = op.input_specs().getValue();
for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
auto &state = states_manager_.GetOperandQuantState(op, i);
auto &requantize = states_manager_.GetOperandRequantizeState(op, i);
if (state.IsEmpty() && requantize.pos == RequantizeState::NO_REQUANTIZE) {
input_specs.push_back(original_input_specs[i]);
} else if (requantize.pos == RequantizeState::ON_OUTPUT) {
input_specs.push_back(TypeAttr::get(requantize.params));
} else {
input_specs.push_back(TypeAttr::get(state.params));
}
}
op.setAttr("input_specs", ArrayAttr::get(input_specs, context));
llvm::SmallVector<Attribute, 4> output_specs;
auto original_output_specs = op.output_specs().getValue();
for (int res = 0, e = op.getNumResults(); res != e; ++res) {
auto &state = states_manager_.GetResultQuantState(op, res);
auto &requantize = states_manager_.GetResultRequantizeState(op, res);
if (state.IsEmpty() && requantize.pos == RequantizeState::NO_REQUANTIZE) {
output_specs.push_back(original_output_specs[res]);
} else if (requantize.pos == RequantizeState::ON_INPUT) {
output_specs.push_back(TypeAttr::get(requantize.params));
} else {
output_specs.push_back(TypeAttr::get(state.params));
}
}
op.setAttr("output_specs", ArrayAttr::get(output_specs, context));
});
return success();
}
void QuantizeContext::DumpStates(QuantizeRegionOp current_op) {
if (current_op) {
llvm::errs() << "\n\n\n" << current_op.logical_kernel() << "\n";
}
func_.walk([&](QuantizeRegionOp op) {
if (current_op == op) llvm::errs() << "===>>>";
llvm::errs() << op.logical_kernel() << " : (";
for (auto i = 0; i < op.getNumOperands(); ++i) {
if (auto params = GetOperandParams(op, i))
params.print(llvm::errs());
else
llvm::errs() << "_";
llvm::errs() << ",";
}
llvm::errs() << ") -> (";
for (auto i = 0; i < op.getNumResults(); ++i) {
if (auto params = GetResultParams(op, i))
params.print(llvm::errs());
else
llvm::errs() << "_";
llvm::errs() << ",";
}
llvm::errs() << ")\n";
});
}
int QuantizeContext::StatesManager::InitializeState(quant::QuantizeRegionOp op,
int index, bool as_result) {
Attribute params_attr;
if (as_result) {
params_attr = op.output_specs()[index];
} else {
params_attr = op.input_specs()[index];
}
QuantParams params =
params_attr.cast<TypeAttr>().getValue().dyn_cast<QuantParams>();
bool immutable = !EmptyParams(params);
int next_state_index = states_.size();
states_.push_back({params, immutable});
if (as_result) {
result_states_.insert({{op, index}, next_state_index});
} else {
operand_states_.insert({{op, index}, next_state_index});
}
return next_state_index;
}
void QuantizeContext::StatesManager::InitializeOperandState(
quant::QuantizeRegionOp op, int index, llvm::DenseMap<Value, int> *cache) {
Value in = op.getOperand(index);
auto cached = cache->insert({in, 0});
if (!cached.second) {
operand_states_.insert({{op, index}, cached.first->second});
return;
}
cached.first->second = InitializeState(op, index, /*as_result=*/false);
}
void QuantizeContext::StatesManager::InitializeResultState(
quant::QuantizeRegionOp op, int index, llvm::DenseMap<Value, int> *cache) {
auto res = op.getResult(index);
auto cached = cache->insert({res, 0});
if (!cached.second) {
result_states_.insert({{op, index}, cached.first->second});
return;
}
cached.first->second = InitializeState(op, index, /*as_result=*/true);
}
bool QuantizeContext::StatesManager::SetConstantResultParams(Operation *op) {
llvm_unreachable("no implementation.");
return false;
}
bool QuantizeContext::StatesManager::SetResultParams(Operation *op,
int res_index,
QuantParams params) {
auto &state = GetResultQuantState(op, res_index);
if (state.params == params) {
return false;
}
if (!state.IsEmpty()) {
auto &rescale = GetResultRequantizeState(op, res_index);
rescale.params = params;
rescale.pos = RequantizeState::ON_INPUT;
return false;
}
state.params = params;
return true;
}
bool QuantizeContext::StatesManager::SetOperandParams(Operation *op, int index,
QuantParams params) {
auto &state = GetOperandQuantState(op, index);
if (state.params == params) {
return false;
}
if (!state.IsEmpty()) {
auto &rescale = GetOperandRequantizeState(op, index);
rescale.params = params;
rescale.pos = RequantizeState::ON_OUTPUT;
return false;
}
state.params = params;
return true;
}
} // namespace quant
} // namespace mlir
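A rough sketch of the requantize bookkeeping above, assuming two different parameters p1 and p2 are proposed for the same result during propagation; ProposeTwice, p1 and p2 are hypothetical, everything else is the QuantizeContext API from this change:

#include "tensorflow/compiler/mlir/lite/quantization/quantization_context.h"

// Sketch only: the first proposal fills the empty state, the second records a
// requantize instead of overwriting it.
void ProposeTwice(mlir::quant::QuantizeContext& ctx,
                  mlir::quant::QuantizeRegionOp op,
                  mlir::quant::QuantizedType p1,
                  mlir::quant::QuantizedType p2) {
  // The empty state takes p1 and the call returns true, so the caller would
  // re-queue the users of this result.
  bool changed = ctx.SetResultParams(op.getOperation(), /*index=*/0, p1);
  // The state keeps p1; a RequantizeState with pos == ON_INPUT and
  // params == p2 is recorded, and the call returns false.
  changed = ctx.SetResultParams(op.getOperation(), /*index=*/0, p2);
  (void)changed;
  // Finalize() would then write p2 (the requantize params) into this op's
  // output_specs entry rather than p1.
}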

View File

@ -0,0 +1,217 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONTEXT_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONTEXT_H_
#include "llvm/ADT/DenseMap.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/device_target.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
namespace mlir {
namespace quant {
static bool EmptyParams(QuantParams p) { return p == quant::QuantizedType(); }
// The state for each op result during the quantization parameters propagation.
struct QuantState {
// Quantization parameters propagated to an op result.
QuantParams params;
// A flag indicating that this state (the params) shouldn't be changed after it
// is initialized. This flag will be set to true if the quantization parameters
// are from quantization-aware training.
const bool immutable;
bool IsEmpty() { return EmptyParams(params); }
};
// The state for rescaling the propagated quantization parameters. This can be
// on the input side to satisfy the constraint of the previous operation, or on
// the output side to satisfy the constraint of the next operation.
struct RequantizeState {
// Sometimes, we have to "requantize" the quantization result to satisfy all
// the constraints. The "requantize" can happen either on the input or output
// of the quantization result.
enum RequantizePosition {
NO_REQUANTIZE,
ON_INPUT,
ON_OUTPUT
} pos = NO_REQUANTIZE;
// Quantization parameters will be used to add the requantize ops.
QuantParams params;
};
// This class manages all the intermediate quantization states.
class QuantizeContext {
public:
QuantizeContext(FuncOp func, const DeviceTarget &spec);
// Returns all the quant region ops.
std::vector<quant::QuantizeRegionOp> GetAllOps();
// For each quant region op, propagates its quantization parameters according
// to the kernel specification and also returns the adjacent quant region ops
// which get the new quantization parameters propagated.
LogicalResult Handle(quant::QuantizeRegionOp op,
llvm::SmallVectorImpl<Operation *> *new_items,
bool *changed);
// Updates the port quantization specifications of all the quant region ops
// with the propagation results.
LogicalResult Finalize();
// Dumps the states stored in the state manager.
void DumpStates(QuantizeRegionOp current_op = {});
// Updates the quantization parameter for a certain result of the op. By this
// method, the quantization parameter is propagated to all the users of the
// result as well.
bool SetResultParams(Operation *op, int index, QuantParams params) {
return states_manager_.SetResultParams(op, index, params);
}
// Updates the quantization parameter for a certain operand of the op. By this
// method, the quantization parameter is propagated to the defining op of
// operand as well.
bool SetOperandParams(Operation *op, int index, QuantParams params) {
return states_manager_.SetOperandParams(op, index, params);
}
// Returns the quantization parameter of a certain result of the op.
QuantParams GetResultParams(Operation *op, int index) {
return states_manager_.GetResultParams(op, index);
}
// Returns the quantization parameter of a certain operand of the op.
QuantParams GetOperandParams(Operation *op, int index) {
return states_manager_.GetOperandParams(op, index);
}
private:
class StatesManager {
public:
// Sets the quantization parameters of the constant result according to its
// content.
//
// Always returns true.
bool SetConstantResultParams(Operation *op);
// Sets the quantization parameters of the result to a fixed value. If any
// quantization parameters have been propagated, a `requantize` will happen
// on the input of propagated quantization.
//
// Returns true, if the users of the result needs to be added to the
// worklist.
bool SetResultParams(Operation *op, int index, QuantParams params);
// Sets the quantization parameters of the operand to a fixed value. If any
// quantization parameters have been propagated, a `requantize` will happen
// on the output of propagated quantization.
//
// Returns true, if the defining op of the operand needs to be added to the
// worklist.
bool SetOperandParams(Operation *op, int index, QuantParams params);
// Returns the quantization parameters of the index-th result of the op.
QuantParams GetResultParams(Operation *op, int index) {
return states_[result_states_[{op, index}]].params;
}
// Returns the quantization parameters of the index-th operand of the op.
QuantParams GetOperandParams(Operation *op, int index) {
return states_[operand_states_[{op, index}]].params;
}
private:
friend class QuantizeContext;
// Uses the spec of the index-th result (if `as_result` is true) or of the
// index-th operand (otherwise) to set its initial state. The state is
// immutable if the spec is a quantized type. Returns the index of this new
// state in the state vector.
int InitializeState(quant::QuantizeRegionOp op, int index, bool as_result);
// Sets the state of the index-th operand of the op. If this operand is
// cached, uses the cached result without creating a new entry in the state
// vector. Otherwise, allocates a new entry in the state vector.
void InitializeOperandState(quant::QuantizeRegionOp op, int index,
llvm::DenseMap<Value, int> *cache);
// Sets the state of the index-th result of the op. If this result is
// cached, uses the cached result without creating a new entry in the state
// vector. Otherwise, allocates a new entry in the state vector.
void InitializeResultState(quant::QuantizeRegionOp op, int index,
llvm::DenseMap<Value, int> *cache);
// Returns the state of the index-th operand of the op.
QuantState &GetOperandQuantState(Operation *op, int index) {
return states_[operand_states_[{op, index}]];
}
// Returns the state of the index-th result of the op.
QuantState &GetResultQuantState(Operation *op, int index) {
return states_[result_states_[{op, index}]];
}
// Returns the state of the index-th operand of the op.
RequantizeState &GetOperandRequantizeState(Operation *op, int index) {
return rescale_states_[operand_states_[{op, index}]];
}
// Returns the state of the index-th result of the op.
RequantizeState &GetResultRequantizeState(Operation *op, int index) {
return rescale_states_[result_states_[{op, index}]];
}
private:
// This is used to identify an operand or result of an op. The second
// element of this pair is the index of the operand or result.
using OpValue = std::pair<mlir::Operation *, int>;
// The vector contains all the quantization parameters propagated from the
// defining operations of the value, or from quantization-aware training.
std::vector<QuantState> states_;
// The map contains all the quantization parameters which are required to
// satisfy the same operands-and-results constraint. The keys of this map
// are the values from `operand_states_` and `result_states_`.
std::unordered_map<int, RequantizeState> rescale_states_;
// Maps from an op's operands and results to indices into the propagation
// state vector.
llvm::DenseMap<OpValue, int> operand_states_;
llvm::DenseMap<OpValue, int> result_states_;
};
FuncOp func_;
DeviceTarget target_spec_;
StatesManager states_manager_;
};
} // namespace quant
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONTEXT_H_

View File

@ -33,7 +33,9 @@ cc_library(
"passes.h",
],
deps = [
":cpu_device_target",
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
"//tensorflow/compiler/mlir/lite/quantization:quantization_context",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/xla:hlo",
"//tensorflow/compiler/xla/client/lib:quantize",
@ -49,6 +51,23 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "cpu_device_target",
srcs = [
"cpu_device_target.cc",
],
hdrs = [
"cpu_device_target.h",
],
deps = [
"//tensorflow/compiler/mlir/lite/quantization:device_target",
"//tensorflow/compiler/mlir/lite/quantization:quantization_context",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:Support",
],
)
cc_library(
name = "quantize",
srcs = [

View File

@ -0,0 +1,59 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_context.h"
namespace mlir {
namespace xla_hlo {
namespace ph = std::placeholders;
CpuDeviceTarget::CpuDeviceTarget(MLIRContext* ctx) : DeviceTarget(ctx) {
RegisterKernel("generic.concat", {qi8_, qi8_, qi8_},
quant::ScaleConstraintType::OutputInputSameScale);
RegisterKernel("generic.mul", {qi8_, qi8_, qi8_},
quant::ScaleConstraintType::OutputInputFreeScale);
RegisterKernel("generic.mul_add", {qi8_, qi8n_, any_, qi8_},
std::bind(&CpuDeviceTarget::HandleMultiplyAccumulateScale,
this, ph::_1, ph::_2, ph::_3, ph::_4));
RegisterKernel("generic.matmul_add", {qi8_, qi8n_, any_, qi8_},
std::bind(&CpuDeviceTarget::HandleMultiplyAccumulateScale,
this, ph::_1, ph::_2, ph::_3, ph::_4));
}
LogicalResult CpuDeviceTarget::HandleMultiplyAccumulateScale(
quant::QuantizeContext* ctx, Operation* op,
quant::AdjacentOperations* new_items, bool* changed) {
auto bias_params = ctx->GetOperandParams(op, 2);
if (!EmptyParams(bias_params)) {
return success();
}
std::vector<quant::QuantParams> op_types{ctx->GetOperandParams(op, 0),
ctx->GetOperandParams(op, 1)};
auto bias_scale = GetUniformQuantizedTypeForBias(op_types);
if (bias_scale && ctx->SetOperandParams(op, 2, bias_scale)) {
*changed = true;
new_items->push_back(op->getOperand(2).getDefiningOp());
}
return success();
}
} // namespace xla_hlo
} // namespace mlir
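A brief note on the scale derivation above: GetUniformQuantizedTypeForBias is defined outside this change, but under the usual convention that a multiply-accumulate bias is stored as i32 with a scale equal to the product of the two operand scales, the arithmetic is as sketched below. The numbers line up with the propagate test later in this diff, which expects !quant.uniform<i32:f32, 1.0> for the bias when both operand scales are 1.0.

// Hypothetical operand params for a generic.mul_add / generic.matmul_add op:
//   operand 0 (input):  i8, scale s_in = 1.0
//   operand 1 (weight): i8 narrow range, scale s_w = 1.0
// Assumed bias derivation:
//   s_bias = s_in * s_w = 1.0, storage type i32, zero point 0
// HandleMultiplyAccumulateScale then calls SetOperandParams(op, 2, ...) with
// that type, sets *changed = true, and queues the op defining the bias so the
// propagation revisits it.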

View File

@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_CPU_DEVICE_TARGET_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_CPU_DEVICE_TARGET_H_
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/device_target.h"
namespace mlir {
namespace xla_hlo {
// Target specs for CPU kernels.
class CpuDeviceTarget : public quant::DeviceTarget {
public:
explicit CpuDeviceTarget(MLIRContext* ctx);
private:
LogicalResult HandleMultiplyAccumulateScale(
quant::QuantizeContext* ctx, Operation* op,
quant::AdjacentOperations* new_items, bool* changed);
};
} // namespace xla_hlo
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_CPU_DEVICE_TARGET_H_

View File

@ -26,7 +26,9 @@ limitations under the License.
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_context.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.h"
// NOLINTNEXTLINE
static llvm::cl::opt<bool> disable_per_channel(
@ -59,9 +61,36 @@ struct PropagateQuantPass : public FunctionPass<PropagateQuantPass> {
void PropagateQuantPass::runOnFunction() {
FuncOp func = getFunction();
// TODO(fengliuai): deprecate this old code generation path.
// XLA only support uint8/uint16 quantization for now.
ApplyQuantizationParamsPropagation(func, /*is_signed*/ false,
disable_per_channel, GetOpQuantSpec);
CpuDeviceTarget spec(&getContext());
quant::QuantizeContext ctx(func, spec);
std::vector<quant::QuantizeRegionOp> work_list(ctx.GetAllOps());
bool changed = false;
while (!work_list.empty()) {
quant::QuantizeRegionOp op = work_list.back();
work_list.pop_back();
llvm::SmallVector<Operation *, 4> new_items;
if (failed(ctx.Handle(op, &new_items, &changed))) {
// The IR is still valid, thus we shouldn't fail.
signalPassFailure();
}
for (auto item : new_items) {
if (auto reg = llvm::dyn_cast_or_null<quant::QuantizeRegionOp>(item))
work_list.push_back(reg);
}
}
if (!changed) return;
if (failed(ctx.Finalize())) {
signalPassFailure();
}
}
} // namespace

View File

@ -0,0 +1,54 @@
// RUN: tf-opt -xla-hlo-propagate-quant %s | FileCheck %s --dump-input-on-failure
// -----
// CHECK-LABEL: @mul_add_source_no_params
func @mul_add_source_no_params(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
%region = "quant.region"(%arg0, %arg1, %arg2) ( {
^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors
%mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
%add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
"quant.return"(%add) : (tensor<4xf32>) -> ()
}) {input_specs = [f32, f32, f32], logical_kernel = "generic.mul_add", output_specs = [f32]} :
(tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %region : tensor<4xf32>
// CHECK: input_specs = [f32, f32, f32]
// CHECK-SAME: output_specs = [f32]
}
// -----
// CHECK-LABEL: @mul_add_annotated_no_narrow_range
func @mul_add_annotated_no_narrow_range(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
%region = "quant.region"(%arg0, %arg1, %arg2) ( {
^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors
%mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
%add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
"quant.return"(%add) : (tensor<4xf32>) -> ()
}) {input_specs = [!quant.uniform<i8:f32, 1.0:-128>, !quant.uniform<i8:f32, 1.0:-128>, f32],
logical_kernel = "generic.mul_add", output_specs = [!quant.uniform<i8:f32, 1.0:-128>]} :
(tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %region : tensor<4xf32>
// CHECK: input_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>, !quant.uniform<i8:f32, 1.000000e+00:-128>, f32]
// CHECK-SAME: output_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>]
}
// -----
// CHECK-LABEL: @mul_add_annotated
func @mul_add_annotated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
%region = "quant.region"(%arg0, %arg1, %arg2) ( {
^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors
%mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
%add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
"quant.return"(%add) : (tensor<4xf32>) -> ()
}) {input_specs = [!quant.uniform<i8:f32, 1.0:-128>, !quant.uniform<i8<-127:127>:f32, 1.0:-128>, f32],
logical_kernel = "generic.mul_add", output_specs = [!quant.uniform<i8:f32, 1.0:-128>]} :
(tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %region : tensor<4xf32>
// CHECK: input_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>, !quant.uniform<i8<-127:127>:f32, 1.000000e+00:-128>, !quant.uniform<i32:f32, 1.000000e+00>]
// CHECK-SAME: output_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>]
}

View File

@ -5,6 +5,11 @@ package(licenses = ["notice"])
glob_lit_tests(
data = [":test_utilities"],
driver = "@llvm-project//mlir:run_lit.sh",
tags_override = {
"legalize-tf.mlir": ["no_rocm"],
"optimize.mlir": ["no_rocm"],
"prepare-tf.mlir": ["no_rocm"],
},
test_file_exts = ["mlir"],
)

View File

@ -1,6 +1,6 @@
# RUN: not tf_tfl_translate -mlir-pretty-debuginfo -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes=1,224,224,3 -tf-output-arrays=MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm/FusedBatchNorm -tf-debug-info=%s.debug %s -o - 2>&1 | FileCheck %s
# CHECK: fake/user/code/file_C.py: error: 'tf.Conv2D' op attribute 'data_format' failed to satisfy constraint: 'NHWC' or 'NCHW' convnet data format
# CHECK: fake/user/code/file_C.py:27:0: error: 'tf.Conv2D' op attribute 'data_format' failed to satisfy constraint: 'NHWC' or 'NCHW' convnet data format
node {
name: "input"

View File

@ -1,9 +1,9 @@
# RUN: not tf_tfl_translate -mlir-pretty-debuginfo -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes=1,224,224,3 -tf-output-arrays=MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm/FusedBatchNorm -tf-debug-info=%s.debug %s -o - 2>&1 | FileCheck %s
# CHECK: fake/user/code/file_C.py: error: 'tf.Conv2D' op attribute 'data_format' failed to satisfy constraint: 'NHWC' or 'NCHW' convnet data format
# CHECK: fake/user/code/file_D.py: note: called from
# CHECK: fake/user/code/file_E.py: note: called from
# CHECK: fake/user/code/file_F.py: note: called from
# CHECK: fake/user/code/file_C.py:27:0: error: 'tf.Conv2D' op attribute 'data_format' failed to satisfy constraint: 'NHWC' or 'NCHW' convnet data format
# CHECK: fake/user/code/file_D.py:28:0: note: called from
# CHECK: fake/user/code/file_E.py:29:0: note: called from
# CHECK: fake/user/code/file_F.py:30:0: note: called from
node {
name: "input"

View File

@ -8,6 +8,12 @@ glob_lit_tests(
":test_utilities",
],
driver = "@llvm-project//mlir:run_lit.sh",
tags_override = {
"add.pbtxt": ["no_rocm"],
"conv_2d.pbtxt": ["no_rocm"],
"fake_quant_per_channel.pbtxt": ["no_rocm"],
"ophint_lstm.pbtxt": ["no_rocm"],
},
test_file_exts = [
"pbtxt",
],

View File

@ -50,8 +50,8 @@ func @while_cond_10_frozen0(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>, %arg2: t
// INLINE: ^bb0([[ARGS]]):
// INLINE: %cst_2 = constant
// INLINE: yield
// INLINE: while_body
// INLINE: while_cond
// INLINE-NOT: while_body
// INLINE-NOT: while_cond
// CANON-LABEL: func @while_main
// CANON-SAME: ([[VAL_0:%.*]]: tensor<?x256x256xf32>)

View File

@ -192,11 +192,11 @@ func @argmin(%arg0: tensor<3xi32>, %arg1: tensor<i32>) -> tensor<i32> {
// CHECK: "tfl.arg_min"(%arg0, %arg1) : (tensor<3xi32>, tensor<i32>) -> tensor<i32>
}
func @sigmoid(%arg0: tensor<?x88xf16>) -> tensor<?x88xf16> {
%0 = "tf.Sigmoid"(%arg0) : (tensor<?x88xf16>) -> tensor<?x88xf16>
return %0 : tensor<?x88xf16>
func @sigmoid(%arg0: tensor<?x88xf32>) -> tensor<?x88xf32> {
%0 = "tf.Sigmoid"(%arg0) : (tensor<?x88xf32>) -> tensor<?x88xf32>
return %0 : tensor<?x88xf32>
// CHECK-LABEL: sigmoid
// CHECK: "tfl.logistic"(%arg0) : (tensor<?x88xf16>) -> tensor<?x88xf16>
// CHECK: "tfl.logistic"(%arg0) : (tensor<?x88xf32>) -> tensor<?x88xf32>
}
func @sqrt(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
@ -1316,16 +1316,6 @@ func @assert_remove(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi1>
// CHECK: return
}
func @reciprocal_f16(%arg0: tensor<8xf16>) -> tensor<8xf16> {
%0 = "tf.Reciprocal"(%arg0) : (tensor<8xf16>) -> tensor<8xf16>
return %0: tensor<8xf16>
// CHECK-LABEL: reciprocal_f16
// CHECK: %cst = constant dense<1.000000e+00> : tensor<f16>
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<f16>, tensor<8xf16>) -> tensor<8xf16>
// CHECK: return
}
func @reciprocal_f32(%arg0: tensor<8xf32>) -> tensor<8xf32> {
%0 = "tf.Reciprocal"(%arg0) : (tensor<8xf32>) -> tensor<8xf32>
return %0: tensor<8xf32>
@ -1336,16 +1326,6 @@ func @reciprocal_f32(%arg0: tensor<8xf32>) -> tensor<8xf32> {
// CHECK: return
}
func @reciprocal_complex_f32(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>> {
%0 = "tf.Reciprocal"(%arg0) : (tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>>
return %0: tensor<8xcomplex<f32>>
// CHECK-LABEL: reciprocal_complex_f32
// CHECK: %cst = constant opaque<"tf", "0x746674656E736F722464747970653A2044545F434F4D504C455836342074656E736F725F7368617065207B2064696D207B2073697A653A2031207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C3230303F5C3030305C3030305C3030305C30303022"> : tensor<complex<f32>>
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<complex<f32>>, tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>>
// CHECK: return
}
func @reciprocal_i32(%arg0: tensor<8xi32>) -> tensor<8xi32> {
%0 = "tf.Reciprocal"(%arg0) : (tensor<8xi32>) -> tensor<8xi32>
return %0: tensor<8xi32>
@ -1356,16 +1336,6 @@ func @reciprocal_i32(%arg0: tensor<8xi32>) -> tensor<8xi32> {
// CHECK: return
}
func @reciprocal_i64(%arg0: tensor<8xi64>) -> tensor<8xi64> {
%0 = "tf.Reciprocal"(%arg0) : (tensor<8xi64>) -> tensor<8xi64>
return %0: tensor<8xi64>
// CHECK-LABEL: reciprocal_i64
// CHECK: %cst = constant dense<1> : tensor<i64>
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<i64>, tensor<8xi64>) -> tensor<8xi64>
// CHECK: return
}
func @random_uniform() -> tensor<2x5xf32> {
%0 = "tf.Const"() { value = dense<[2, 5]> : tensor<2xi32> } : () -> tensor<2xi32>
%1 = "tf.RandomUniform"(%0) { seed = 1, seed2 = 0} : (tensor<2xi32>) -> tensor<2x5xf32>

View File

@ -14,7 +14,6 @@
// CHECK-NEXT: Tensor 2 val kTfLiteFloat32 kTfLiteMmapRo 4 bytes
// CHECK-NEXT: Tensor 3 tfl.while kTfLiteInt32 kTfLiteArenaRw 4 bytes
// CHECK-NEXT: Tensor 4 result kTfLiteFloat32 kTfLiteArenaRw 4 bytes
// CHECK-NEXT: Tensor 5 tfl.while:2 kTfLiteInt32 kTfLiteArenaRw 4 bytes
// Verify while was not folded away:
// ------------------------------------

View File

@ -16,7 +16,7 @@ func @testCos(tensor<? x f32>) -> tensor<? x f32> {
// test invalid Cos input
func @testCosWithWrongInputType(tensor<?xi32>) -> tensor<?xi32> {
^bb0(%arg0: tensor<?xi32>):
// expected-error @+1 {{tfl.cos' op operand #0 must be tensor of floating-point values}}
// expected-error @+1 {{tfl.cos' op operand #0 must be tensor of 32-bit float values}}
%0 = "tfl.cos"(%arg0): (tensor<?xi32>) -> tensor<?xi32>
return %0#0 : tensor<?xi32>
}
@ -103,7 +103,7 @@ func @testAddN(tensor<? x f32>, tensor<? x f32>, tensor<? x f32>) -> tensor<? x
// test invalid AddN
func @testAddNWrongOperandResultType(tensor<? x f16>, tensor<? x f16>, tensor<? x f16>) -> tensor<? x f16> {
^bb0(%arg0: tensor<? x f16>, %arg1: tensor<? x f16>, %arg2: tensor<? x f16>):
// expected-error @+1 {{'tfl.add_n' op operand #0 must be tensor of 32-bit float or 32-bit signless integer or QI16 type or QUI16 type values}}
// expected-error @+1 {{'tfl.add_n' op operand #0 must be tensor of 32-bit float or 32-bit signless integer}}
%0 = "tfl.add_n"(%arg0, %arg1, %arg2): (tensor<? x f16>, tensor<? x f16>, tensor<? x f16>) -> tensor<? x f16>
return %0 : tensor<? x f16>
}
@ -147,7 +147,7 @@ func @testSin(tensor<? x f32>) -> tensor<? x f32> {
// test invalid Sin input
func @testSinWithWrongInputType(tensor<?xi32>) -> tensor<?xi32> {
^bb0(%arg0: tensor<?xi32>):
// expected-error @+1 {{tfl.sin' op operand #0 must be tensor of floating-point values}}
// expected-error @+1 {{tfl.sin' op operand #0 must be tensor of 32-bit float values}}
%0 = "tfl.sin"(%arg0): (tensor<?xi32>) -> tensor<?xi32>
return %0#0 : tensor<?xi32>
}
@ -157,7 +157,7 @@ func @testSinWithWrongInputType(tensor<?xi32>) -> tensor<?xi32> {
// test invalid Sqrt input
func @testSqrtWithWrongInputType(tensor<? x i32>) -> tensor<? x i32> {
^bb0(%arg0: tensor<? x i32>):
// expected-error @+1 {{tfl.sqrt' op operand #0 must be tensor of floating-point values}}
// expected-error @+1 {{tfl.sqrt' op operand #0 must be tensor of 32-bit float values}}
%0 = "tfl.sqrt"(%arg0): (tensor<? x i32>) -> tensor<? x i32>
return %0#0 : tensor<? x i32>
}
@ -167,7 +167,7 @@ func @testSqrtWithWrongInputType(tensor<? x i32>) -> tensor<? x i32> {
// test invalid Square input
func @testSquareWithWrongInputType(tensor<? x i32>) -> tensor<? x i32> {
^bb0(%arg0: tensor<? x i32>):
// expected-error @+1 {{tfl.square' op operand #0 must be tensor of floating-point values}}
// expected-error @+1 {{tfl.square' op operand #0 must be tensor of 32-bit float values}}
%0 = "tfl.square"(%arg0): (tensor<? x i32>) -> tensor<? x i32>
return %0#0 : tensor<? x i32>
}
@ -425,7 +425,7 @@ func @testTileF32(%arg0: tensor<4 x 1 x f32>, %arg1: tensor<4 x i32>) -> tensor<
// -----
func @testEluI32(%arg0: tensor<? x i32>) -> tensor<? x i32> {
// expected-error @+1 {{operand #0 must be tensor of floating-point values}}
// expected-error @+1 {{op operand #0 must be tensor of 32-bit float values}}
%0 = "tfl.elu"(%arg0): (tensor<? x i32>) -> tensor<? x i32>
return %0#0 : tensor<? x i32>
}
@ -531,11 +531,11 @@ func @testMaxUnpooling2D(%arg0: tensor<1x8x8x128xf32>, %arg1: tensor<1x8x8x128xf
// -----
// CHECK-LABEL: testLogistic
func @testLogistic(tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16> {
^bb0(%arg0: tensor<1x2x3x4x5xbf16>):
func @testLogistic(tensor<1x2x3x4x5xf32>) -> tensor<1x2x3x4x5xf32> {
^bb0(%arg0: tensor<1x2x3x4x5xf32>):
// CHECK: "tfl.logistic"(%arg0)
%0 = "tfl.logistic"(%arg0): (tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16>
return %0 : tensor<1x2x3x4x5xbf16>
%0 = "tfl.logistic"(%arg0): (tensor<1x2x3x4x5xf32>) -> tensor<1x2x3x4x5xf32>
return %0 : tensor<1x2x3x4x5xf32>
}
// -----
@ -543,7 +543,7 @@ func @testLogistic(tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16> {
// test invalid Logistic input
func @testLogisticWithWrongInputType(tensor<?xi32>) -> tensor<?xi32> {
^bb0(%arg0: tensor<?xi32>):
// expected-error @+1 {{tfl.logistic' op operand #0 must be tensor of floating-point or QI8 type or QUI8 type or QI16 type or QUI16 type values}}
// expected-error @+1 {{tfl.logistic' op operand #0 must be tensor of 32-bit float or QI8 type or QUI8 type or QI16 type or QUI16 type values}}
%0 = "tfl.logistic"(%arg0): (tensor<?xi32>) -> tensor<?xi32>
return %0#0 : tensor<?xi32>
}

View File

@ -400,8 +400,8 @@ func @FuseFullyConnectedReshapeAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor<
// FOLD: return %[[fc]]
}
// CHECK-LABEL: @NotReorderReshapeAddIfNotBroadcastable
func @NotReorderReshapeAddIfNotBroadcastable(%arg0: tensor<40x10x4xf32>) -> tensor<40x40xf32> {
// CHECK-LABEL: @NotReorderReshapeAddIfNotBroadcastableAfter
func @NotReorderReshapeAddIfNotBroadcastableAfter(%arg0: tensor<40x10x4xf32>) -> tensor<40x40xf32> {
%cst = constant dense<2.0> : tensor<40xf32>
%shape = constant dense<[40, 40]> : tensor<2xi32>
%1 = "tfl.reshape"(%arg0, %shape) : (tensor<40x10x4xf32>, tensor<2xi32>) -> tensor<40x40xf32>
@ -413,6 +413,19 @@ func @NotReorderReshapeAddIfNotBroadcastable(%arg0: tensor<40x10x4xf32>) -> tens
// CHECK: return %[[rs2]]
}
// CHECK-LABEL: @NotReorderReshapeAddIfNotTailingDimAfter
func @NotReorderReshapeAddIfNotTailingDimAfter(%arg0: tensor<1x30x1x96xf32>) -> tensor<1x30x96xf32> {
%cst = constant dense<2.0> : tensor<1x30x96xf32>
%shape = constant dense<[1, 30, 96]> : tensor<3xi32>
%1 = "tfl.reshape"(%arg0, %shape) : (tensor<1x30x1x96xf32>, tensor<3xi32>) -> tensor<1x30x96xf32>
%2 = "tfl.add"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<1x30x96xf32>, tensor<1x30x96xf32>) -> tensor<1x30x96xf32>
return %2 : tensor<1x30x96xf32>
// CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0
// CHECK: %[[rs2:.*]] = tfl.add %[[rs1]]
// CHECK: return %[[rs2]]
}
// CHECK-LABEL: @NotReorderReshapeAddIfNotTailingDim
func @NotReorderReshapeAddIfNotTailingDim(%arg0: tensor<40x40x1xf32>) -> tensor<40x40xf32> {
%cst = constant dense<2.0> : tensor<1x40xf32>

View File

@ -15,14 +15,14 @@ func @while() -> tensor<1xf32>
%0:2 = "tfl.while"(%cst0, %cst1) ( {
^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>):
// CHECK: call @WhileOp_cond
// CHECK-SAME: (tensor<*xi32>, tensor<*xf32>, tensor<i32>)
// CHECK-SAME: (tensor<*xi32>, tensor<*xf32>)
%cst_0 = constant dense<0> : tensor<i32>
%1 = "tfl.greater"(%arg2, %cst_0) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
"tfl.yield"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>):
// CHECK: call @WhileOp_body
// CHECK-SAME: (tensor<*xi32>, tensor<*xf32>, tensor<i32>)
// CHECK-SAME: (tensor<*xi32>, tensor<*xf32>)
%1 = "tfl.sub"(%arg2, %cst) {fused_activation_function = "NONE"} :
(tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%2 = tfl.add %arg3, %arg3 {fused_activation_function = "NONE"} : tensor<*xf32>
@ -40,8 +40,7 @@ func @while() -> tensor<1xf32>
// CHECK-LABEL: func @while2
// Verify that while body/cond with implicitly captured values results in changing while operands/results.
func @while2() -> tensor<1xf32> attributes {tf.entry_function = {outputs = "result"}} {
%cst = constant dense<1> : tensor<i32>
func @while2(%cst : tensor<i32>) -> tensor<1xf32> attributes {tf.entry_function = {outputs = "result"}} {
%cst_0 = constant dense<5> : tensor<i32>
%cst_1 = constant dense<3.000000e+00> : tensor<1xf32>
// Verifies 3 operands post outlining.
@ -148,22 +147,21 @@ func @rnn(%arg0: tensor<4x4x3xf32> {tf.device = "/device:CPU:0"}) -> tensor<4x?x
// CHECK: tfl.while
// CHECK: tfl.yield
// CHECK-SAME: (tensor<i1>) -> ()
// CHECK: [[VAL_41:%.*]]:18 =
// CHECK: [[VAL_30:%.*]]:7 =
// CHECK: call @tfl.while_body
// CHECK: tfl.yield
// CHECK-SAME: (tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<i32>, tensor<i32>, tensor<4x4x3xf32>, tensor<8x5xf32>, tensor<8xf32>, tensor<f32>, tensor<1xi32>, tensor<i32>, tensor<1xi32>, tensor<1xi32>, tensor<i32>, tensor<1xi32>) -> ()
// CHECK-SAME: (tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<4x4x3xf32>) -> ()
// CHECK-LABEL: func @tfl.while_cond(
// CHECK-SAME: [[VAL_56:%.*]]: tensor<i32>, [[VAL_57:%.*]]: tensor<i32>, [[VAL_58:%.*]]: tensor<*xf32>, [[VAL_59:%.*]]: tensor<4x2xf32>, [[VAL_60:%.*]]: tensor<4x2xf32>, [[VAL_61:%.*]]: tensor<*xf32>, [[VAL_62:%.*]]: tensor<i32>, [[VAL_63:%.*]]: tensor<i32>, [[VAL_64:%.*]]: tensor<4x4x3xf32>, [[VAL_65:%.*]]: tensor<8x5xf32>, [[VAL_66:%.*]]: tensor<8xf32>, [[VAL_67:%.*]]: tensor<f32>, [[VAL_68:%.*]]: tensor<1xi32>, [[VAL_69:%.*]]: tensor<i32>, [[VAL_70:%.*]]: tensor<1xi32>, [[VAL_71:%.*]]: tensor<1xi32>, [[VAL_72:%.*]]: tensor<i32>, [[VAL_73:%.*]]: tensor<1xi32>) -> tensor<i1> attributes {sym_visibility = "private"} {
// CHECK-SAME: [[VAL_35:%.*]]: tensor<i32>, [[VAL_36:%.*]]: tensor<i32>, [[VAL_37:%.*]]: tensor<*xf32>, [[VAL_38:%.*]]: tensor<4x2xf32>, [[VAL_39:%.*]]: tensor<4x2xf32>, [[VAL_40:%.*]]: tensor<*xf32>, [[VAL_41:%.*]]: tensor<4x4x3xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
// CHECK: return
// CHECK-SAME: tensor<i1>
// CHECK: }
// CHECK-LABEL: func @tfl.while_body(
// CHECK-SAME: [[VAL_77:%.*]]: tensor<i32>, [[VAL_78:%.*]]: tensor<i32>, [[VAL_79:%.*]]: tensor<*xf32>, [[VAL_80:%.*]]: tensor<4x2xf32>, [[VAL_81:%.*]]: tensor<4x2xf32>, [[VAL_82:%.*]]: tensor<*xf32>, [[VAL_83:%.*]]: tensor<i32>, [[VAL_84:%.*]]: tensor<i32>, [[VAL_85:%.*]]: tensor<4x4x3xf32>, [[VAL_86:%.*]]: tensor<8x5xf32>, [[VAL_87:%.*]]: tensor<8xf32>, [[VAL_88:%.*]]: tensor<f32>, [[VAL_89:%.*]]: tensor<1xi32>, [[VAL_90:%.*]]: tensor<i32>, [[VAL_91:%.*]]: tensor<1xi32>, [[VAL_92:%.*]]: tensor<1xi32>, [[VAL_93:%.*]]: tensor<i32>, [[VAL_94:%.*]]: tensor<1xi32>) -> (tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<i32>, tensor<i32>, tensor<4x4x3xf32>, tensor<8x5xf32>, tensor<8xf32>, tensor<f32>, tensor<1xi32>, tensor<i32>, tensor<1xi32>, tensor<1xi32>, tensor<i32>, tensor<1xi32>) attributes {sym_visibility = "private"} {
// CHECK: [[VAL_123:%.*]] = "tfl.cast"
// CHECK-SAME: [[VAL_46:%.*]]: tensor<i32>, [[VAL_47:%.*]]: tensor<i32>, [[VAL_48:%.*]]: tensor<*xf32>, [[VAL_49:%.*]]: tensor<4x2xf32>, [[VAL_50:%.*]]: tensor<4x2xf32>, [[VAL_51:%.*]]: tensor<*xf32>, [[VAL_52:%.*]]: tensor<4x4x3xf32>) -> (tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<4x4x3xf32>) attributes {sym_visibility = "private"} {
// CHECK: [[VAL_91:%.*]] = "tfl.cast"
// CHECK: return
// CHECK-SAME: [[VAL_123]], [[VAL_83]], [[VAL_84]], [[VAL_85]], [[VAL_86]], [[VAL_87]], [[VAL_88]], [[VAL_89]], [[VAL_90]], [[VAL_91]], [[VAL_92]], [[VAL_93]], [[VAL_94]] : tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<i32>, tensor<i32>, tensor<4x4x3xf32>, tensor<8x5xf32>, tensor<8xf32>, tensor<f32>, tensor<1xi32>, tensor<i32>, tensor<1xi32>, tensor<1xi32>, tensor<i32>, tensor<1xi32>
// CHECK-SAME: [[VAL_91]], [[VAL_52]] : tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<4x4x3xf32>
// CHECK: }
// CHECK: }

View File

@ -79,10 +79,12 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
pass_config.quant_specs.serialized_quant_stats));
}
if (pass_config.lower_tensor_list_ops) {
// TODO(haoliang): Add this pass by default.
pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass());
}
// The conversion pipeline has to follow the following order:
// 1) Try to convert ophint nodes if present first like ophint lstm.
// 2) Saved model related optimization like decompose resource ops
// 3) Convert composite functions like lstm/rnns, along with proper function
// inlining & dce.
// 4) Lower static tensor list pass.
// The ophint extractions happen before lots of other passes:
// The assumption of ophint-extraction is each ophinted region is a black-box
@ -105,29 +107,40 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFDevice::CreateDecomposeResourceOpsPass());
// This pass does resource analysis of saved model global tensors and marks
// those deemed read-only as immutable.
pass_manager->addPass(
mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());
// Note:
// We need to fuse composite ops before the LowerStaticTensorList pass.
// The TensorFlow tensor list is not supported by that pass right now.
// Enable fusing composite ops that can be lowered to built-in TFLite ops.
if (pass_config.emit_builtin_tflite_ops) {
pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());
}
// This pass marks non-exported functions as symbol visibility 'private'
// those deemed read-only as immutable.
pass_manager->addPass(
mlir::tf_saved_model::
CreateMarkFunctionVisibilityUsingSavedModelLinkagePass());
// Enable fusing composite ops that can be lowered to built-in TFLite ops.
if (pass_config.emit_builtin_tflite_ops) {
pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());
pass_manager->addPass(mlir::createInlinerPass());
pass_manager->addPass(mlir::createSymbolDCEPass());
if (pass_config.lower_tensor_list_ops) {
// TODO(haoliang): Add this pass by default.
pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass());
}
// This pass does resource analysis of saved model global tensors and marks
// those deemed read-only as immutable.
pass_manager->addPass(
mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());
// Legalize while early to allow further constant folding.
// TODO(jpienaar): This may not actually matter as we do canonicalization
// after the legalize below; for now it needs to be below the above passes
// that work on the TF dialect and before the inliner so that the function
// calls in body and cond are inlined for optimization.
if (pass_config.legalize_tf_while) {
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreateLegalizeTFWhilePass());
pass_manager->addPass(mlir::TFL::CreateLegalizeTFWhilePass());
}
// Add function inlining pass. Both TF and TFLite dialects are opted into

View File

@ -31,44 +31,52 @@ namespace {
// Legalize TF While to TFL While with calls to the original functions from the
// cond and body regions.
struct LegalizeWhile : public FunctionPass<LegalizeWhile> {
void runOnFunction() override {
auto func = getFunction();
// Convert all TF WhileOps inside the function body to TFL While ops.
func.getBody().walk([](TF::WhileOp while_op) {
Operation* op = while_op.getOperation();
// Create new TFL While op that will be used to replace TF While op.
auto new_op = OpBuilder(op).create<TFL::WhileOp>(
op->getLoc(), op->getResultTypes(), op->getOperands(),
while_op.is_stateless());
// Insert call to the given function into the 'region'.
auto create_region_with_call = [&while_op](FlatSymbolRefAttr symbol,
Region& region) {
OpBuilder builder(region);
auto block = builder.createBlock(&region);
SmallVector<Value, 4> new_operands;
auto func = while_op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
symbol.getValue());
for (Type t : func.getType().getInputs())
new_operands.push_back(block->addArgument(t));
auto call =
builder.create<CallOp>(while_op.getLoc(), symbol,
func.getType().getResults(), new_operands);
builder.create<YieldOp>(while_op.getLoc(), call.getResults());
};
create_region_with_call(while_op.condAttr(), new_op.cond());
create_region_with_call(while_op.bodyAttr(), new_op.body());
struct LegalizeWhile : public ModulePass<LegalizeWhile> {
void RunOnFunction(FuncOp func);
op->replaceAllUsesWith(new_op.getResults());
op->erase();
});
void runOnModule() override {
for (auto op : getModule().getOps<FuncOp>()) RunOnFunction(op);
}
};
} // namespace
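// Converts a single TF::WhileOp into a TFL::WhileOp whose cond/body regions
// simply call the original cond/body functions; those functions are marked
// private so they can be DCE'd if they end up unused.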
void RunOnWhile(TF::WhileOp while_op) {
Operation* op = while_op.getOperation();
// Create new TFL While op that will be used to replace TF While op.
auto new_op = OpBuilder(op).create<TFL::WhileOp>(
op->getLoc(), op->getResultTypes(), op->getOperands(),
while_op.is_stateless());
// Insert call to the given function into the 'region'.
auto create_region_with_call = [&while_op](FlatSymbolRefAttr symbol,
Region& region) {
OpBuilder builder(region);
auto block = builder.createBlock(&region);
SmallVector<Value, 4> new_operands;
auto func = while_op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
symbol.getValue());
for (Type t : func.getType().getInputs())
new_operands.push_back(block->addArgument(t));
auto call = builder.create<CallOp>(
while_op.getLoc(), symbol, func.getType().getResults(), new_operands);
builder.create<YieldOp>(while_op.getLoc(), call.getResults());
// Mark old function as private so that it can be DCE'd if not called.
func.setVisibility(SymbolTable::Visibility::Private);
};
create_region_with_call(while_op.condAttr(), new_op.cond());
create_region_with_call(while_op.bodyAttr(), new_op.body());
op->replaceAllUsesWith(new_op.getResults());
op->erase();
}
void LegalizeWhile::RunOnFunction(FuncOp func) {
// Convert all TF WhileOps inside the function body to TFL While ops.
func.getBody().walk([](TF::WhileOp while_op) { RunOnWhile(while_op); });
}
// Creates an instance of the TensorFlow While to TFLite While pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFWhilePass() {
std::unique_ptr<OpPassBase<ModuleOp>> CreateLegalizeTFWhilePass() {
return std::make_unique<LegalizeWhile>();
}

View File

@ -862,6 +862,11 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
target.addLegalOp<ConstantOp>();
target.addLegalOp<FuncOp>();
target.addLegalOp<ReturnOp>();
// Register fused LSTM/RNN ops as legal.
target.addLegalOp<TFL::LSTMOp>();
target.addLegalOp<TFL::UnidirectionalSequenceLSTMOp>();
target.addLegalOp<TFL::UnidirectionalSequenceRNNOp>();
target.addLegalOp<TFL::BidirectionalSequenceLSTMOp>();
OwningRewritePatternList patterns;
populateWithGenerated(context, &patterns);

View File

@ -339,11 +339,15 @@ foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in {
(ConstantOp:$rhs $a), TFL_AF_None),
(TFL_ReshapeOp (BinaryOp $input, $rhs, TFL_AF_None), $shape),
// The broadcasting of "BinaryOp" only happens in the lower
// dimensions, and the higher dimensions are same.
// dimensions, and the higher dimensions are the same, so we know the
// result and input of the "BinaryOp" in the source pattern have
// the same shape, which is defined by `shape`.
[(IsTailOfShape $rhs, $lhs),
(HasOneUse $lhs),
// the two operands of the binary op is broadcastable
(AreBroadcastableTypes $rhs, $input)]>;
// The result of the new "BinaryOp" will have the same shape as
// `input`. In other words, the shape of the `Reshape` op is not
// changed after the transformation.
(IsTailOfShape $rhs, $input)]>;
}
foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
@ -363,11 +367,15 @@ foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
(ConstantOp:$rhs $a)),
(TFL_ReshapeOp (BinaryOp $input, $rhs), $shape),
// The broadcasting of "BinaryOp" only happens in the lower
// dimensions, and the higher dimensions are same.
// dimensions, and the higher dimensions are the same, so we know the
// result and input of the "BinaryOp" in the source pattern have
// the same shape, which is defined by `shape`.
[(IsTailOfShape $rhs, $lhs),
(HasOneUse $lhs),
// the two operands of the binary op is broadcastable
(AreBroadcastableTypes $rhs, $input)]>;
// The result of the new "BinaryOp" will have the same shape as
// `input`. In other words, the shape of the `Reshape` op is not
// changed after the transformation.
(IsTailOfShape $rhs, $input)]>;
}
// Returns shape of a ranked tensor.
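For a concrete instance of the tail-of-shape reasoning added above, take the NotReorderReshapeAddIfNotTailingDimAfter test earlier in this diff (shapes only, no new pattern logic):

// $input : tensor<1x30x1x96xf32>  (operand of the tfl.reshape)
// $lhs   : tensor<1x30x96xf32>    (result of the tfl.reshape)
// $rhs   : tensor<1x30x96xf32>    (the constant being added)
//
// IsTailOfShape($rhs, $lhs)   -> true   ([1,30,96] is a tail of [1,30,96])
// IsTailOfShape($rhs, $input) -> false  ([1,30,96] is not a tail of [1,30,1,96])
//
// The second check fails, so the rewrite does not fire and the reshape stays in
// front of the add, which is exactly what the test expects.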

View File

@ -86,7 +86,7 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateDefaultQuantParamsPass(
std::unique_ptr<OpPassBase<FuncOp>> CreateDenseToSparsePass();
// Creates function pass to legalize TF While to TFL While.
std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFWhilePass();
std::unique_ptr<OpPassBase<ModuleOp>> CreateLegalizeTFWhilePass();
// Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
std::unique_ptr<OpPassBase<ModuleOp>> CreateWhileOutlinePass();

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "mlir/IR/Identifier.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/SymbolTable.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
@ -89,24 +90,20 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
llvm::SmallVector<Region*, 2> regions{&while_op.cond(), &while_op.body()};
for (auto it : llvm::enumerate(regions)) {
llvm::SetVector<Value> region_extern_values;
Value const_none = nullptr;
getUsedValuesDefinedAbove(*it.value(), region_extern_values);
// Sink down none type constants into the functions.
// Sink down constants into the functions.
for (auto extern_value : region_extern_values) {
if (!extern_value.getType().isa<NoneType>()) {
if (!matchPattern(extern_value, m_Constant())) {
extern_values.insert(extern_value);
continue;
}
if (!const_none) {
// Add constant at start of region.
auto const_builder =
OpBuilder(&it.value()->front(), it.value()->front().begin());
const_none = const_builder.create<ConstantOp>(
while_op.getLoc(), extern_value.getType(),
const_builder.getUnitAttr());
}
replaceAllUsesInRegionWith(extern_value, const_none, *it.value());
// Add constant at start of region.
auto const_builder =
OpBuilder(&it.value()->front(), it.value()->front().begin());
auto const_value = const_builder.clone(*extern_value.getDefiningOp());
replaceAllUsesInRegionWith(extern_value, const_value->getResult(0),
*it.value());
}
}

View File

@ -21,6 +21,10 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_os_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
@ -86,6 +90,17 @@ MlirOptimizationPassRegistry& MlirOptimizationPassRegistry::Global() {
return *global;
}
static void RegisterDialects() {
static bool init_once = []() {
mlir::registerDialect<mlir::StandardOpsDialect>();
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
return true;
}();
(void)init_once;
}
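RegisterDialects leans on the C++ guarantee that a function-local static is initialized exactly once, even when the function is called from several threads. A stand-alone sketch of the idiom (the printed message is a placeholder for the registerDialect calls):

#include <iostream>

void RegisterOnce() {
  // The lambda runs only on the first call; later calls see the cached value.
  static bool init_once = []() {
    std::cout << "registering dialects\n";  // stand-in for mlir::registerDialect<...>()
    return true;
  }();
  (void)init_once;  // silence unused-variable warnings, as in the pass above
}

int main() {
  RegisterOnce();  // prints once
  RegisterOnce();  // no output: initialization already happened
  return 0;
}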
Status MlirFunctionOptimizationPass::Run(
const DeviceSet& device_set, const ConfigProto& config_proto,
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
@ -107,6 +122,7 @@ Status MlirFunctionOptimizationPass::Run(
<< "(registered " << registry_->passes().size() << " passes)";
GraphDebugInfo debug_info;
RegisterDialects();
mlir::MLIRContext context;
GraphImportConfig import_config;
import_config.graph_as_function = true;
@ -178,9 +194,12 @@ Status MlirV1CompatGraphOptimizationPass::Run(
<< "(registered" << registry_->passes().size() << " passes)";
GraphDebugInfo debug_info;
RegisterDialects();
mlir::MLIRContext context;
GraphImportConfig import_config;
import_config.upgrade_legacy = true;
// TODO(b/150959075): Running functionalization before TPU cluster formation
// is not semantics preserving and should be disabled for now.
import_config.upgrade_legacy = false;
TF_ASSIGN_OR_RETURN(
auto module_ref,
ConvertGraphToMlir(**options.graph, debug_info, *options.flib_def,

View File

@ -10,6 +10,7 @@ package_group(
name = "friends",
includes = ["//third_party/mlir:subpackages"],
packages = [
"//learning/brain/experimental/tfrt/...",
"//learning/pathways/data_parallel/tf2xla/...",
"//tensorflow/compiler/...",
"//tensorflow/lite/experimental/tf_runtime/...",
@ -986,7 +987,6 @@ cc_library(
tf_cc_test(
name = "error_util_test",
srcs = ["utils/error_util_test.cc"],
tags = ["no_rocm"],
deps = [
":error_util",
"//tensorflow/compiler/xla:test",
@ -1065,6 +1065,7 @@ COMPILE_MLIR_UTIL_DEPS = [
":tensorflow_dialect_registration",
":tensorflow_passes",
":translate_utils",
"@com_google_absl//absl/types:optional",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
@ -1083,6 +1084,8 @@ COMPILE_MLIR_UTIL_DEPS = [
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:logging",
"//tensorflow/stream_executor/lib",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/service:hlo",
]
# Prefer to link 'compile_mlir_util' library that also links necessary

View File

@ -1143,6 +1143,29 @@ that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Clips tensor values to a specified min and max.";
let description = [{
Given a tensor `t`, this operation returns a tensor of the same type and
shape as `t` with its values clipped to `clip_value_min` and `clip_value_max`.
Any values less than `clip_value_min` are set to `clip_value_min`. Any values
greater than `clip_value_max` are set to `clip_value_max`.
}];
let arguments = (ins
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$t,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$clip_value_min,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$clip_value_max
);
let results = (outs
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
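The element-wise behavior described in the summary can be stated as a scalar reference in plain C++ (illustrative only, assuming clip_value_min <= clip_value_max):

#include <algorithm>
#include <cassert>

// Values below the minimum are raised to it, values above the maximum are
// lowered to it, and everything in between passes through unchanged.
float ClipByValue(float t, float clip_value_min, float clip_value_max) {
  return std::min(std::max(t, clip_value_min), clip_value_max);
}

int main() {
  assert(ClipByValue(-2.0f, -1.0f, 1.0f) == -1.0f);
  assert(ClipByValue(0.5f, -1.0f, 1.0f) == 0.5f);
  assert(ClipByValue(3.0f, -1.0f, 1.0f) == 1.0f);
  return 0;
}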
def TF_ComplexOp : TF_Op<"Complex", [NoSideEffect, ResultsBroadcastableShape]> {
let summary = "Converts two real numbers to a complex number.";
@ -2592,7 +2615,7 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors.
TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<3>;
}
def TF_FusedBatchNormGradV3Op : TF_Op<"FusedBatchNormGradV3", [NoSideEffect]> {
def TF_FusedBatchNormGradV3Op : TF_Op<"FusedBatchNormGradV3", [NoSideEffect, TF_LayoutSensitiveInterface]> {
let summary = "Gradient for batch normalization.";
let description = [{
@ -2623,9 +2646,17 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<3>;
let extraClassDeclaration = [{
// TF_LayoutSensitiveInterface:
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0, 1}; }
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
StringRef GetOptimalLayout(const RuntimeDevices& devices);
LogicalResult UpdateDataFormat(StringRef data_format);
}];
}
def TF_FusedBatchNormV3Op : TF_Op<"FusedBatchNormV3", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
def TF_FusedBatchNormV3Op : TF_Op<"FusedBatchNormV3", [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> {
let summary = "Batch normalization.";
let description = [{
@ -2662,6 +2693,10 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors.
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
// TF_LayoutSensitiveInterface:
StringRef GetOptimalLayout(const RuntimeDevices& devices);
LogicalResult UpdateDataFormat(StringRef data_format);
}];
}

View File

@ -1520,6 +1520,34 @@ static LogicalResult Verify(FillOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// FusedBatchNormGradOp
//===----------------------------------------------------------------------===//
// TODO(b/150954845): Add benchmarks to verify that layout preference didn't
// change in the latest GPU generations.
LogicalResult FusedBatchNormGradV3Op::UpdateDataFormat(StringRef data_format) {
return ::mlir::TF::UpdateDataFormat(data_format, this);
}
StringRef FusedBatchNormGradV3Op::GetOptimalLayout(
const RuntimeDevices &devices) {
// Keep the current data format if no GPUs are available or if explicit placement
// does not allow using the GPU for this operation.
if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
return data_format();
// For the f16 data type on devices with Tensor Core support, the NHWC data
// format is up to ~2x faster.
auto x_ty = x().getType().cast<TensorType>();
const bool is_f16 = x_ty.getElementType().isF16();
if (is_f16 && CanUseTensorCores(devices)) return "NHWC";
// For all other data types prefer NCHW.
return "NCHW";
}
//===----------------------------------------------------------------------===//
// FusedBatchNormOp
//===----------------------------------------------------------------------===//
@ -1547,9 +1575,36 @@ static LogicalResult Verify(FusedBatchNormOp op) {
LogicalResult FusedBatchNormV3Op::FoldOperandsPermutation(
ArrayRef<int64_t> permutation) {
// FusedBatchNorm in training mode is a layout-sensitive operation, and should
// already have an optimal data format assigned.
if (is_training()) return failure();
return ::mlir::TF::FoldOperandsPermutation(permutation, this);
}
LogicalResult FusedBatchNormV3Op::UpdateDataFormat(StringRef data_format) {
return ::mlir::TF::UpdateDataFormat(data_format, this);
}
StringRef FusedBatchNormV3Op::GetOptimalLayout(const RuntimeDevices &devices) {
// In inference mode FusedBatchNorm is not sensitive to data layout.
if (!is_training()) return data_format();
// Keep the current data format if no GPUs are available or if explicit placement
// does not allow using the GPU for this operation.
if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
return data_format();
// For the f16 data type on devices with Tensor Core support, the NHWC data
// format is up to ~2x faster.
auto x_ty = x().getType().cast<TensorType>();
const bool is_f16 = x_ty.getElementType().isF16();
if (is_f16 && CanUseTensorCores(devices)) return "NHWC";
// For all other data types prefer NCHW.
return "NCHW";
}
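Both GetOptimalLayout implementations above encode the same preference; a plain-C++ sketch of the decision (the boolean predicates are placeholders for the real device/capability helpers):

#include <iostream>
#include <string>

// NHWC only when the op can run on a GPU with Tensor Cores and the input is f16
// (where NHWC is noted as up to ~2x faster); otherwise prefer NCHW, and keep the
// current format when no GPU is usable at all.
std::string OptimalBatchNormLayout(bool can_use_gpu, bool has_tensor_cores,
                                   bool is_f16, const std::string& current) {
  if (!can_use_gpu) return current;
  if (is_f16 && has_tensor_cores) return "NHWC";
  return "NCHW";
}

int main() {
  std::cout << OptimalBatchNormLayout(true, true, true, "NHWC") << "\n";   // NHWC
  std::cout << OptimalBatchNormLayout(true, false, true, "NHWC") << "\n";  // NCHW
  std::cout << OptimalBatchNormLayout(false, true, true, "NHWC") << "\n";  // NHWC (kept)
  return 0;
}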
//===----------------------------------------------------------------------===//
// GatherV2Op
//===----------------------------------------------------------------------===//

View File

@ -5,6 +5,10 @@ package(licenses = ["notice"])
glob_lit_tests(
data = [":test_utilities"],
driver = "@llvm-project//mlir:run_lit.sh",
tags_override = {
"optimize.mlir": ["no_rocm"],
"tf_optimize.mlir": ["no_rocm"],
},
test_file_exts = ["mlir"],
)

View File

@ -7,6 +7,37 @@ func @einsum_basic(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor
// CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
}
func @einsum_broadcast(%arg0: tensor<3x4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<3x4x6xf32> {
%0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ijk,km->ijm"}: (tensor<3x4x5xf32>, tensor<5x6xf32>) -> tensor<3x4x6xf32>
return %0 : tensor<3x4x6xf32>
// CHECK-LABEL: einsum_broadcast
// CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<3x4x5xf32>, tensor<5x6xf32>) -> tensor<3x4x6xf32>
}
func @einsum_reducesum(%arg0: tensor<2x5x7xf32>, %arg1: tensor<5x2xf32>) -> tensor<5x7xf32> {
%0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "lbh,bl->bh"}: (tensor<2x5x7xf32>, tensor<5x2xf32>) -> tensor<5x7xf32>
return %0 : tensor<5x7xf32>
// CHECK-LABEL: einsum_reducesum
// CHECK: %[[cst:.*]] = constant dense<[1, 2, 0]> : tensor<3xi32>
// CHECK: %[[cst_1:.*]] = constant dense<[5, 1, 2]> : tensor<3xi64>
// CHECK: %[[cst_2:.*]] = constant dense<2> : tensor<1xi32>
// CHECK: %[[v0:.*]] = "tf.Transpose"(%arg0, %[[cst]]) : (tensor<2x5x7xf32>, tensor<3xi32>) -> tensor<5x7x2xf32>
// CHECK: %[[v1:.*]] = "tf.Reshape"(%arg1, %[[cst_1]]) : (tensor<5x2xf32>, tensor<3xi64>) -> tensor<5x1x2xf32>
// CHECK: %[[v2:.*]] = "tf.Mul"(%[[v0]], %[[v1]]) : (tensor<5x7x2xf32>, tensor<5x1x2xf32>) -> tensor<5x7x2xf32>
// CHECK: "tf.Sum"(%[[v2]], %[[cst_2]]) {keep_dims = false} : (tensor<5x7x2xf32>, tensor<1xi32>) -> tensor<5x7xf32>
}
func @einsum_transpose_matmul(%arg0: tensor<2x5x7xf32>, %arg1: tensor<5x3x2xf32>) -> tensor<5x3x7xf32> {
%0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "lbh,bkl->bkh"}: (tensor<2x5x7xf32>, tensor<5x3x2xf32>) -> tensor<5x3x7xf32>
return %0 : tensor<5x3x7xf32>
// CHECK-LABEL: einsum_transpose_matmul
// CHECK: %[[cst:.*]] = constant dense<[1, 2, 0]> : tensor<3xi32>
// CHECK: %[[cst_0:.*]] = constant dense<[0, 2, 1]> : tensor<3xi32>
// CHECK: %[[v0:.*]] = "tf.Transpose"(%arg0, %[[cst]]) : (tensor<2x5x7xf32>, tensor<3xi32>) -> tensor<5x7x2xf32>
// CHECK: %[[v1:.*]] = "tf.Transpose"(%arg1, %[[cst_0]]) : (tensor<5x3x2xf32>, tensor<3xi32>) -> tensor<5x2x3xf32>
// CHECK: %[[v2:.*]] = "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) {adj_x = false, adj_y = false} : (tensor<5x7x2xf32>, tensor<5x2x3xf32>) -> tensor<5x7x3xf32>
// CHECK: %[[v3:.*]] = "tf.Transpose"(%[[v2]], %[[cst_0]]) : (tensor<5x7x3xf32>, tensor<3xi32>) -> tensor<5x3x7xf32>
}
func @einsum_4D(%arg0: tensor<2x5x7x3xf32>, %arg1: tensor<2x4x7x3xf32>) -> tensor<2x7x5x4xf32> {
%0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "bfnh,btnh->bnft"}: (tensor<2x5x7x3xf32>, tensor<2x4x7x3xf32>) -> tensor<2x7x5x4xf32>
return %0 : tensor<2x7x5x4xf32>

View File

@ -38,6 +38,24 @@ func @test_single_branch_direct_t() -> tensor<i32> {
return %0 : tensor<i32>
}
// CHECK-LABEL: test_single_branch_direct_arg_f
// CHECK: Switch
// CHECK: tf.AddV2
func @test_single_branch_direct_arg_f(%pred : tensor<i1>) -> tensor<i32> {
%cst_0 = constant dense<10> : tensor<i32>
%cst_1 = constant dense<1> : tensor<i32>
%0 = tf_executor.graph {
%7:3 = tf_executor.Switch %cst_0, %pred : tensor<i32>
%8:2 = tf_executor.island {
%12 = "tf.AddV2"(%7#1, %cst_1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
tf_executor.yield %12 : tensor<i32>
}
%11:3 = tf_executor.Merge %7#0, %8#0 : tensor<i32> {N = 2 : i64}
tf_executor.fetch %11#0 : tensor<i32>
}
return %0 : tensor<i32>
}
// pred ? x + 1 : x - 1
// CHECK-LABEL: ControlFlowTest.testCond_1f
// CHECK-NOT: Switch

View File

@ -1,14 +1,15 @@
# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s
# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s --dump-input-on-failure
# Verify that the data_format attribute is pulled from the default value in the
# registry when not present in the GraphDef
# CHECK: tf.Conv2D
# CHECK-SAME: data_format = "NHWC"
# Verify that we can also pull some attributes that are needed to be able to
# create a Graph in memory, like `T`.
# Verify that we don't import derived attributes as these will be added only on
# export.
# CHECK: tf.MaxPool
# CHECK-SAME: T = f32
# CHECK-NOT: T = f32
# CHECK-SAME: : (tensor<?x112x112x32xf32>) -> tensor<?x56x56x32xf32>
node {
name: "input"

View File

@ -3,6 +3,21 @@
node {
name: "unnamed"
op: "foo"
attr {
key: "_output_shapes"
value {
list {
shape {
}
}
}
}
attr {
key: "_disable_call_shape_inference"
value {
b: true
}
}
experimental_debug_info {
}
}
@ -39,7 +54,7 @@ versions {
# Verify that functions from the library are properly imported.
# CHECK-LABEL: func @main() {
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @foo0}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = true, f = @foo0}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @bar0}
# CHECK-LABEL: func @foo0() {

View File

@ -105,7 +105,6 @@ versions {
# CHECK: func @main
# CHECK: "tf.PartitionedCall"()
# CHECK-SAME: Tout = ["tfdtype$DT_UINT8"]
# CHECK-SAME: f = @[[FUNCTION:[A-Za-z0-9_]*]]
# CHECK: func @[[FUNCTION]]() -> tensor<ui8>
# CHECK: return {{.*}} : tensor<ui8>

View File

@ -0,0 +1,22 @@
# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s
node {
name: "input0"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
}
versions {
producer: 29
min_consumer: 12
}
# Verify that the Placeholder is imported and that no `_output_shapes` attribute is attached.
# CHECK-LABEL: func @main() {
# CHECK: "tf.Placeholder"
# CHECK-NOT: _output_shapes

View File

@ -1,7 +1,6 @@
# RUN: tf-mlir-translate -graphdef-to-splatted-mlir %s -o - | FileCheck %s
# CHECK: tf.Const
# CHECK-SAME: _output_shapes = ["tfshape$dim { size: 3 }"]
# CHECK-SAME: value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B2073697A653A2033207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C30303022"> : tensor<3x!tf.string>
node {

View File

@ -62,4 +62,106 @@ func @transposeConv2DBackpropInput_f16(
return %0 : tensor<1x28x28x64xf16>
}
// CHECK-LABEL: func @transposeFusedBatchNormV3_f32
func @transposeFusedBatchNormV3_f32(
%arg0: tensor<1x28x28x64xf32>,
%arg1: tensor<64xf32>
) -> tensor<1x28x28x64xf32> {
// CHECK: "tf.FusedBatchNormV3"
// CHECK-SAME: (%[[X_TRANSPOSE:[0-9]*]], %arg1, %arg1, %arg1, %arg1)
// CHECK-SAME: data_format = "NCHW"
%y, %batch_mean, %batch_var, %reserve_1, %reserve_2, %reserve_3
= "tf.FusedBatchNormV3"(%arg0, %arg1, %arg1, %arg1, %arg1)
{
data_format = "NHWC",
epsilon = 1.001 : f32,
exponential_avg_factor = 1.0 : f32,
is_training = true
}
: (tensor<1x28x28x64xf32>, tensor<64xf32>, tensor<64xf32>,
tensor<64xf32>, tensor<64xf32>)
-> (tensor<1x28x28x64xf32>, tensor<64xf32>, tensor<64xf32>,
tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
return %y : tensor<1x28x28x64xf32>
}
// CHECK-LABEL: func @transposeFusedBatchNormV3_f16
func @transposeFusedBatchNormV3_f16(
%arg0: tensor<1x28x28x64xf16>,
%arg1: tensor<64xf32>
) -> tensor<1x28x28x64xf16> {
// CHECK: "tf.FusedBatchNormV3"
// CHECK-SAME: (%[[X_TRANSPOSE:[0-9]*]], %arg1, %arg1, %arg1, %arg1)
// CHECK-SAME: data_format = "NCHW"
%y, %batch_mean, %batch_var, %reserve_1, %reserve_2, %reserve_3
= "tf.FusedBatchNormV3"(%arg0, %arg1, %arg1, %arg1, %arg1)
{
data_format = "NHWC",
epsilon = 1.001 : f32,
exponential_avg_factor = 1.0 : f32,
is_training = true
}
: (tensor<1x28x28x64xf16>, tensor<64xf32>, tensor<64xf32>,
tensor<64xf32>, tensor<64xf32>)
-> (tensor<1x28x28x64xf16>, tensor<64xf32>, tensor<64xf32>,
tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
return %y : tensor<1x28x28x64xf16>
}
// CHECK-LABEL: func @transposeFusedBatchNormGradV3_f32
func @transposeFusedBatchNormGradV3_f32(
%arg0: tensor<1x28x28x64xf32>,
%arg1: tensor<1x28x28x64xf32>,
%arg2: tensor<64xf32>
) -> tensor<1x28x28x64xf32> {
// CHECK: "tf.FusedBatchNormGradV3"
// CHECK-SAME: (%[[X_TRANSPOSE:[0-9]*]], %[[Y_TRANSPOSE:[0-9]*]],
// CHECK-SAME: data_format = "NCHW"
%x_backprop, %scale_backprop, %offset_backprop, %reserve_1, %reserve_2
= "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg2, %arg2, %arg2)
{
data_format = "NHWC",
epsilon = 1.001 : f32,
exponential_avg_factor = 1.0 : f32,
is_training = true
}
: (tensor<1x28x28x64xf32>, tensor<1x28x28x64xf32>,
tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
-> (tensor<1x28x28x64xf32>,
tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
return %x_backprop : tensor<1x28x28x64xf32>
}
// CHECK-LABEL: func @transposeFusedBatchNormGradV3_f16
func @transposeFusedBatchNormGradV3_f16(
%arg0: tensor<1x28x28x64xf16>,
%arg1: tensor<1x28x28x64xf16>,
%arg2: tensor<64xf32>
) -> tensor<1x28x28x64xf16> {
// CHECK: "tf.FusedBatchNormGradV3"
// CHECK-SAME: (%[[X_TRANSPOSE:[0-9]*]], %[[Y_TRANSPOSE:[0-9]*]],
// CHECK-SAME: data_format = "NCHW"
%x_backprop, %scale_backprop, %offset_backprop, %reserve_1, %reserve_2
= "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg2, %arg2, %arg2)
{
data_format = "NHWC",
epsilon = 1.001 : f32,
exponential_avg_factor = 1.0 : f32,
is_training = true
}
: (tensor<1x28x28x64xf16>, tensor<1x28x28x64xf16>,
tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
-> (tensor<1x28x28x64xf16>,
tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
return %x_backprop : tensor<1x28x28x64xf16>
}
}

View File

@ -143,4 +143,106 @@ func @transposeConv2DBackpropInput_f16(
return %0 : tensor<1x64x28x28xf16>
}
// CHECK-LABEL: func @transposeFusedBatchNormV3_f32
func @transposeFusedBatchNormV3_f32(
%arg0: tensor<1x28x28x64xf32>,
%arg1: tensor<64xf32>
) -> tensor<1x28x28x64xf32> {
// CHECK: "tf.FusedBatchNormV3"
// CHECK-SAME: (%[[X_TRANSPOSE:[0-9]*]], %arg1, %arg1, %arg1, %arg1)
// CHECK-SAME: data_format = "NCHW"
%y, %batch_mean, %batch_var, %reserve_1, %reserve_2, %reserve_3
= "tf.FusedBatchNormV3"(%arg0, %arg1, %arg1, %arg1, %arg1)
{
data_format = "NHWC",
epsilon = 1.001 : f32,
exponential_avg_factor = 1.0 : f32,
is_training = true
}
: (tensor<1x28x28x64xf32>, tensor<64xf32>, tensor<64xf32>,
tensor<64xf32>, tensor<64xf32>)
-> (tensor<1x28x28x64xf32>, tensor<64xf32>, tensor<64xf32>,
tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
return %y : tensor<1x28x28x64xf32>
}
// CHECK-LABEL: func @transposeFusedBatchNormV3_f16
func @transposeFusedBatchNormV3_f16(
%arg0: tensor<1x64x28x28xf16>,
%arg1: tensor<64xf32>
) -> tensor<1x64x28x28xf16> {
// CHECK: "tf.FusedBatchNormV3"
// CHECK-SAME: (%[[X_TRANSPOSE:[0-9]*]], %arg1, %arg1, %arg1, %arg1)
// CHECK-SAME: data_format = "NHWC"
%y, %batch_mean, %batch_var, %reserve_1, %reserve_2, %reserve_3
= "tf.FusedBatchNormV3"(%arg0, %arg1, %arg1, %arg1, %arg1)
{
data_format = "NCHW",
epsilon = 1.001 : f32,
exponential_avg_factor = 1.0 : f32,
is_training = true
}
: (tensor<1x64x28x28xf16>, tensor<64xf32>, tensor<64xf32>,
tensor<64xf32>, tensor<64xf32>)
-> (tensor<1x64x28x28xf16>, tensor<64xf32>, tensor<64xf32>,
tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
return %y : tensor<1x64x28x28xf16>
}
// CHECK-LABEL: func @transposeFusedBatchNormGradV3_f32
func @transposeFusedBatchNormGradV3_f32(
%arg0: tensor<1x28x28x64xf32>,
%arg1: tensor<1x28x28x64xf32>,
%arg2: tensor<64xf32>
) -> tensor<1x28x28x64xf32> {
// CHECK: "tf.FusedBatchNormGradV3"
// CHECK-SAME: (%[[X_TRANSPOSE:[0-9]*]], %[[Y_TRANSPOSE:[0-9]*]],
// CHECK-SAME: data_format = "NCHW"
%x_backprop, %scale_backprop, %offset_backprop, %reserve_1, %reserve_2
= "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg2, %arg2, %arg2)
{
data_format = "NHWC",
epsilon = 1.001 : f32,
exponential_avg_factor = 1.0 : f32,
is_training = true
}
: (tensor<1x28x28x64xf32>, tensor<1x28x28x64xf32>,
tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
-> (tensor<1x28x28x64xf32>,
tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
return %x_backprop : tensor<1x28x28x64xf32>
}
// CHECK-LABEL: func @transposeFusedBatchNormGradV3_f16
func @transposeFusedBatchNormGradV3_f16(
%arg0: tensor<1x64x28x28xf16>,
%arg1: tensor<1x64x28x28xf16>,
%arg2: tensor<64xf32>
) -> tensor<1x64x28x28xf16> {
// CHECK: "tf.FusedBatchNormGradV3"
// CHECK-SAME: (%[[X_TRANSPOSE:[0-9]*]], %[[Y_TRANSPOSE:[0-9]*]],
// CHECK-SAME: data_format = "NHWC"
%x_backprop, %scale_backprop, %offset_backprop, %reserve_1, %reserve_2
= "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg2, %arg2, %arg2)
{
data_format = "NCHW",
epsilon = 1.001 : f32,
exponential_avg_factor = 1.0 : f32,
is_training = true
}
: (tensor<1x64x28x28xf16>, tensor<1x64x28x28xf16>,
tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
-> (tensor<1x64x28x28xf16>,
tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
return %x_backprop : tensor<1x64x28x28xf16>
}
}

View File

@ -146,3 +146,82 @@ func @transposeConv2DBackpropInput(
return %0 : tensor<1x32x32x3xf32>
}
// CHECK-LABEL: func @transposeFusedBatchNormV3
func @transposeFusedBatchNormV3(
%arg0: tensor<1x28x28x64xf32>,
%arg1: tensor<64xf32>
) -> tensor<1x28x28x64xf32> {
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"()
// CHECK-SAME: {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
// CHECK: "tf.FusedBatchNormV3"
// CHECK-SAME: (%[[ARG_TRANSPOSE]], %arg1, %arg1, %arg1, %arg1)
// CHECK-SAME: data_format = "NCHW"
// CHECK-SAME: (tensor<1x64x28x28xf32>, tensor<64xf32>,
// CHECK-SAME: -> (tensor<1x64x28x28xf32>, tensor<64xf32>,
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"()
// CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi32>}
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%y, %[[RES_PERM]])
// CHECK: return %[[RES_TRANSPOSE]]
%y, %batch_mean, %batch_var, %reserve_1, %reserve_2, %reserve_3
= "tf.FusedBatchNormV3"(%arg0, %arg1, %arg1, %arg1, %arg1)
{
data_format = "NHWC",
epsilon = 1.001 : f32,
exponential_avg_factor = 1.0 : f32,
is_training = true
}
: (tensor<1x28x28x64xf32>, tensor<64xf32>, tensor<64xf32>,
tensor<64xf32>, tensor<64xf32>)
-> (tensor<1x28x28x64xf32>, tensor<64xf32>, tensor<64xf32>,
tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
return %y : tensor<1x28x28x64xf32>
}
// CHECK-LABEL: func @transposeFusedBatchNormGradV3
func @transposeFusedBatchNormGradV3(
%arg0: tensor<1x28x28x64xf32>,
%arg1: tensor<1x28x28x64xf32>,
%arg2: tensor<64xf32>
) -> tensor<1x28x28x64xf32> {
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"()
// CHECK-SAME: {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
// CHECK: %[[ARG0_TPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
// CHECK: %[[ARG1_TPOSE:[0-9]*]] = "tf.Transpose"(%arg1, %[[ARG_PERM]])
// CHECK: "tf.FusedBatchNormGradV3"
// CHECK-SAME: (%[[ARG0_TPOSE]], %[[ARG1_TPOSE]], %arg2, %arg2, %arg2, %arg2)
// CHECK-SAME: data_format = "NCHW"
// CHECK-SAME: (tensor<1x64x28x28xf32>, tensor<1x64x28x28xf32>,
// CHECK-SAME: -> (tensor<1x64x28x28xf32>,
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"()
// CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi32>}
// CHECK: %[[RES_TPOSE:[0-9]*]] = "tf.Transpose"
// CHECK-SAME: (%x_backprop, %[[RES_PERM]])
// CHECK: return %[[RES_TPOSE]]
%x_backprop, %scale_backprop, %offset_backprop, %reserve_1, %reserve_2
= "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg2, %arg2, %arg2)
{
data_format = "NHWC",
epsilon = 1.001 : f32,
exponential_avg_factor = 1.0 : f32,
is_training = true
}
: (tensor<1x28x28x64xf32>, tensor<1x28x28x64xf32>,
tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
-> (tensor<1x28x28x64xf32>,
tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
return %x_backprop : tensor<1x28x28x64xf32>
}

View File

@ -33,3 +33,40 @@ func @transposeConv2D(%input: tensor<1x3x32x32xf32>, %filter: tensor<1x1x3x8xf32
return %0 : tensor<1x8x32x32xf32>
}
// CHECK-LABEL: func @transposeFusedBatchNormV3
func @transposeFusedBatchNormV3(
%arg0: tensor<1x64x28x28xf32>,
%arg1: tensor<64xf32>
) -> tensor<1x64x28x28xf32> {
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"()
// CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi32>}
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
// CHECK: "tf.FusedBatchNormV3"
// CHECK-SAME: (%[[ARG_TRANSPOSE]], %arg1, %arg1, %arg1, %arg1)
// CHECK-SAME: data_format = "NHWC"
// CHECK-SAME: (tensor<1x28x28x64xf32>, tensor<64xf32>,
// CHECK-SAME: -> (tensor<1x28x28x64xf32>, tensor<64xf32>,
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"()
// CHECK-SAME: {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%y, %[[RES_PERM]])
// CHECK: return %[[RES_TRANSPOSE]]
%y, %batch_mean, %batch_var, %reserve_1, %reserve_2, %reserve_3
= "tf.FusedBatchNormV3"(%arg0, %arg1, %arg1, %arg1, %arg1)
{
data_format = "NCHW",
epsilon = 1.001 : f32,
exponential_avg_factor = 1.0 : f32,
is_training = true
}
: (tensor<1x64x28x28xf32>, tensor<64xf32>, tensor<64xf32>,
tensor<64xf32>, tensor<64xf32>)
-> (tensor<1x64x28x28xf32>, tensor<64xf32>, tensor<64xf32>,
tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
return %y : tensor<1x64x28x28xf32>
}

File diff suppressed because it is too large

View File

@ -6,7 +6,7 @@ func @main() {
// CHECK-NEXT: name: "predicate"
// CHECK-NEXT: op: "Const"
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "dtype"
// CHECK: key: "dtype"
// CHECK-NEXT: value {
// CHECK-NEXT: type: DT_INT32
// CHECK-NEXT: }

View File

@ -22,7 +22,7 @@ func @main() {
// CHECK-NEXT: name: "tf.Const"
// CHECK-NEXT: op: "Const"
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "dtype"
// CHECK: key: "dtype"
// CHECK-NEXT: value {
// CHECK-NEXT: type: DT_INT32
// CHECK-NEXT: }
@ -43,8 +43,8 @@ func @main() {
// CHECK-NEXT: name: "tf.Empty"
// CHECK-NEXT: op: "Empty"
// CHECK-NEXT: input: "tf.Const:output:0"
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "dtype"
// CHECK-NEXT: attr {
// CHECK: key: "dtype"
// CHECK-NEXT: value {
// CHECK-NEXT: type: DT_FLOAT
// CHECK-NEXT: }

View File

@ -0,0 +1,64 @@
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
func @main(%arg0: tensor<10xi32>) -> tensor<10xi32>
attributes {tf.entry_function = {inputs = "input0", outputs = "Placeholder"}} {
%graph = tf_executor.graph {
%0:2 = tf_executor.island wraps "tf.Placeholder.input"(%arg0) {device = "", dtype = "tfdtype$DT_INT32", shape = "tfshape$dim { size: 10 }"} : (tensor<10xi32>) -> tensor<10xi32>
tf_executor.fetch %0 : tensor<10xi32>
}
return %graph : tensor<10xi32>
}
// CHECK: node {
// CHECK-NEXT: name: "Placeholder"
// CHECK-NEXT: op: "Placeholder"
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "_output_shapes"
// CHECK-NEXT: value {
// CHECK-NEXT: list {
// CHECK-NEXT: shape {
// CHECK-NEXT: dim {
// CHECK-NEXT: size: 10
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "dtype"
// CHECK-NEXT: value {
// CHECK-NEXT: type: DT_INT32
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "shape"
// CHECK-NEXT: value {
// CHECK-NEXT: shape {
// CHECK-NEXT: dim {
// CHECK-NEXT: size: 10
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: experimental_debug_info {
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: node {
// CHECK-NEXT: name: "main"
// CHECK-NEXT: op: "_Retval"
// CHECK-NEXT: input: "Placeholder"
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "T"
// CHECK-NEXT: value {
// CHECK-NEXT: type: DT_INT32
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "index"
// CHECK-NEXT: value {
// CHECK-NEXT: i: 0
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: library {
// CHECK-NEXT: }

View File

@ -34,7 +34,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "dense_shapes"
// CHECK: key: "dense_shapes"
// CHECK-NEXT: value {
// CHECK-NEXT: list {
// CHECK-NEXT: shape {

View File

@ -2,7 +2,7 @@
// CHECK: attr {
// CHECK-NEXT: key: "dtypes"
// CHECK: key: "dtypes"
// CHECK-NEXT: value {
// CHECK-NEXT: list {
// CHECK-NEXT: type: DT_INT32

View File

@ -5,7 +5,7 @@ func @main() {
// CHECK-NEXT: name: "Empty/shape"
// CHECK-NEXT: op: "Const"
// CHECK: attr {
// CHECK-NEXT: key: "dtype"
// CHECK: key: "dtype"
// CHECK-NEXT: value {
// CHECK-NEXT: type: DT_INT32
// CHECK-NEXT: }

View File

@ -4,8 +4,8 @@ func @main() {
^bb0:
// CHECK: name: "node_name"
// CHECK-NEXT: op: "Const"
// CHECK: attr {
// CHECK-NEXT: key: "dtype"
// CHECK-NEXT: attr {
// CHECK: key: "dtype"
// CHECK-NEXT: value {
// CHECK-NEXT: type: DT_INT32
// CHECK-NEXT: }

View File

@ -6,7 +6,7 @@ func @main() {
// CHECK-NEXT: name: "Const"
// CHECK-NEXT: op: "Const"
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "dtype"
// CHECK: key: "dtype"
// CHECK-NEXT: value {
// CHECK-NEXT: type: DT_FLOAT
// CHECK-NEXT: }
@ -31,7 +31,7 @@ func @main() {
// CHECK-NEXT: name: "foo"
// CHECK-NEXT: op: "foo"
// CHECK-NEXT: input: "Const"
// CHECK-NEXT: experimental_debug_info {
// CHECK: experimental_debug_info {
// CHECK-NEXT: }
// CHECK-NEXT: }
%1:2 = tf_executor.island wraps "tf.foo"(%0#0) {device = ""} : (tensor<f32>) -> tensor<*xf32> loc("foo")

View File

@ -18,6 +18,11 @@ func @foo0(%arg0: tensor<*xi32>) -> tensor<*xi32> {
// CHECK: node {
// CHECK: name: "tf.LegacyCall"
// CHECK-NEXT: op: "foo0"
// CHECK: attr {
// CHECK-NEXT: key: "_output_shapes"
// CHECK-NEXT: value {
// CHECK-NEXT: list {
// CHECK-NEXT: shape {
// CHECK: library {
// CHECK-NEXT: function {

View File

@ -14,8 +14,8 @@ attributes {tf.entry_function = {inputs = "input0,input1", outputs = "Add"}} {
// CHECK: node {
// CHECK-NEXT: name: "input0"
// CHECK-NEXT: op: "Placeholder"
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "dtype"
// CHECK-NEXT: attr {
// CHECK: key: "dtype"
// CHECK-NEXT: value {
// CHECK-NEXT: type: DT_INT32
// CHECK-NEXT: }
@ -36,8 +36,8 @@ attributes {tf.entry_function = {inputs = "input0,input1", outputs = "Add"}} {
// CHECK-NEXT: node {
// CHECK-NEXT: name: "input1"
// CHECK-NEXT: op: "Placeholder"
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "dtype"
// CHECK-NEXT: attr {
// CHECK: key: "dtype"
// CHECK-NEXT: value {
// CHECK-NEXT: type: DT_INT32
// CHECK-NEXT: }
@ -66,7 +66,7 @@ attributes {tf.entry_function = {inputs = "input0,input1", outputs = "Add"}} {
// CHECK-NEXT: type: DT_INT32
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: experimental_debug_info {
// CHECK: experimental_debug_info {
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: node {

View File

@ -124,3 +124,36 @@ func @nested_ops(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) {
// CHECK-NEXT: tf_device.return %[[OP_C]]
// CHECK-NEXT: device = "c"
// CHECK-NEXT: tf_device.return %[[SHAPE]], %[[LAUNCH_A]], %[[LAUNCH_B]], %[[LAUNCH_C]]
// CHECK-LABEL: func @do_not_hoist_ops_with_virtual_device
// CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf32>, [[VAL_1:%.*]]: tensor<*xf32>)
func @do_not_hoist_ops_with_virtual_device(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) {
%0:8 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<*xf32>) {devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2: i32} {
%1 = "tf.Shape"(%ri) {device = "", T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor<*xf32>) -> tensor<?xi32>
%2 = "tf.opA"(%1) {device = "TPU_REPLICATED_CORE_0"} : (tensor<?xi32>) -> tensor<*xi32>
%3 = "tf_device.launch"() ( {
%b = "tf.opB"(%1) : (tensor<?xi32>) -> tensor<*xi32>
tf_device.return %b : tensor<*xi32>
}) {device = "TPU_REPLICATED_CORE_0"} : () -> tensor<*xi32>
%4 = "tf_device.launch"() ( {
%c = "tf.opC"(%1) {device = "TPU_REPLICATED_CORE_0"} : (tensor<?xi32>) -> tensor<*xi32>
tf_device.return %c : tensor<*xi32>
}) {device = "c"} : () -> tensor<*xi32>
tf_device.return %1, %2, %3, %4 : tensor<?xi32>, tensor<*xi32>, tensor<*xi32>, tensor<*xi32>
}
return
}
// CHECK: [[SHAPE:%.*]] = "tf.Shape"([[VAL_0]])
// CHECK: tf_device.replicate({{\[}}[[VAL_0]], [[VAL_1]]] as [[VAL_4:%.*]]: tensor<*xf32>) {devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
// CHECK: [[OP_A:%.*]] = "tf.opA"([[SHAPE]]) {device = "TPU_REPLICATED_CORE_0"} : (tensor<?xi32>) -> tensor<*xi32>
// CHECK: [[LAUNCH_B:%.*]] = "tf_device.launch"() ( {
// CHECK: [[OP_B:%.*]] = "tf.opB"([[SHAPE]]) : (tensor<?xi32>) -> tensor<*xi32>
// CHECK: tf_device.return [[OP_B]] : tensor<*xi32>
// CHECK: }) {device = "TPU_REPLICATED_CORE_0"} : () -> tensor<*xi32>
// CHECK: [[LAUNCH_C:%.*]] = "tf_device.launch"() ( {
// CHECK: [[OP_C:%.*]] = "tf.opC"([[SHAPE]]) {device = "TPU_REPLICATED_CORE_0"} : (tensor<?xi32>) -> tensor<*xi32>
// CHECK: tf_device.return [[OP_C]] : tensor<*xi32>
// CHECK: }) {device = "c"} : () -> tensor<*xi32>
// CHECK: tf_device.return [[SHAPE]], [[OP_A]], [[LAUNCH_B]], [[LAUNCH_C]]

View File

@ -183,7 +183,7 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
// CHECK-LABEL: func @reused_if_then_branch
// CHECK-SAME: (%arg0: tensor<*xf32>) -> tensor<*xf32>
// expected-error @+1 {{expected control flow function reused_if_then_branch to have exactly 1 use}}
// expected-warning @+1 {{expected control flow function reused_if_then_branch to have exactly 1 use}}
func @reused_if_then_branch(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: return
// CHECK-SAME: tensor<*xf32>
@ -192,7 +192,7 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
// CHECK-LABEL: func @reused_if_else_branch
// CHECK-SAME: (%arg0: tensor<*xf32>) -> tensor<*xf32>
// expected-error @+1 {{expected control flow function reused_if_else_branch to have exactly 1 use}}
// expected-warning @+1 {{expected control flow function reused_if_else_branch to have exactly 1 use}}
func @reused_if_else_branch(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: "tf.Identity"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "tf.Identity"(%arg0) : (tensor<*xf32>) -> (tensor<*xf32>)
@ -278,4 +278,23 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
func @variant_body_func(%arg0: tensor<!tf.variant<tensor<16x1xf32>>>) -> tensor<!tf.variant<tensor<16x1xf32>>> {
return %arg0 : tensor<!tf.variant<tensor<16x1xf32>>>
}
// Test propagation from called functions to the call site.
// CHECK-LABEL: func @stateful_partitioned_call(
// CHECK-SAME: -> tensor<20xi32>
func @stateful_partitioned_call(%arg0: tensor<20xi32>) -> tensor<*xi32> {
// CHECK: tf.PartitionedCall
// CHECK-SAME: (tensor<20xi32>) -> tensor<20xi32>
%0 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @a_called_func} : (tensor<20xi32>) -> (tensor<*xi32>)
// CHECK: tf.StatefulPartitionedCall
// CHECK-SAME: (tensor<20xi32>) -> tensor<20xi32>
%1 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @stateful_partitioned_call_func} : (tensor<20xi32>) -> (tensor<*xi32>)
return %0 : tensor<*xi32>
}
func @a_called_func(%arg0: tensor<?xi32>) -> (tensor<?xi32>) {
return %arg0 : tensor<?xi32>
}
func @stateful_partitioned_call_func(%arg0: tensor<?xi32>) -> (tensor<?xi32>) {
return %arg0 : tensor<?xi32>
}
}

View File

@ -187,3 +187,199 @@ func @main() {
%write3 = "tf.TensorArrayWriteV3"(%grad3#0, %index, %value, %grad3#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
return
}
// -----
// Tests while loop with access to the tensor array defined outside and its
// gradient defined inside. The gradient creation should be moved outside.
// CHECK-LABEL: func @main
func @main() -> () {
// CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
%size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
%index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
// CHECK: %[[GVAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
%ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
// CHECK: "tf.While"(%[[VAR]], %[[SIZE]], %[[GVAR]])
%1:2 = "tf.While"(%ta#0, %size) {
body = @while_body, cond = @while_cond, device = "", is_stateless = false}
: (tensor<!tf.resource>, tensor<i32>) -> (tensor<!tf.resource>, tensor<i32>)
// CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
// CHECK: "tf.Slice"(%[[READ]],
%read = "tf.TensorArrayReadV3"(%1#0, %index, %ta#1) : (tensor<!tf.resource>, tensor<i32>, tensor<f32>) -> tensor<3xf32>
return
}
// CHECK: func @while_body(%[[BARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[BARG1:.*]]: tensor<i32>, %[[BARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
func @while_body(%arg0: tensor<!tf.resource>, %arg1: tensor<i32>) -> (tensor<!tf.resource>, tensor<i32>) {
// CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[SUB:.*]] = "tf.Sub"(%[[BARG1]], %[[CONST1]])
%sub = "tf.Sub"(%arg1, %const1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
%elem = "tf._SomeOp"() : () -> tensor<3xf32>
%flow = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[BARG0]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
// CHECK: %[[UPDATE1:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ1]],
// CHECK: "tf.AssignVariableOp"(%[[BARG0]], %[[UPDATE1]])
%write = "tf.TensorArrayWriteV3"(%arg0, %sub, %elem, %flow) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
%grad:2 = "tf.TensorArrayGradV3"(%arg0, %write) {source = "a"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
// CHECK: %[[READ2:.*]] = "tf.ReadVariableOp"(%[[BARG2]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
// CHECK: %[[UPDATE2:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ2]],
// CHECK: "tf.AssignVariableOp"(%[[BARG2]], %[[UPDATE2]])
%gwrite = "tf.TensorArrayWriteV3"(%grad#0, %sub, %elem, %grad#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
// CHECK: return %[[BARG0]], %[[SUB]], %[[BARG2]]
return %arg0, %sub : tensor<!tf.resource>, tensor<i32>
}
// CHECK: func @while_cond(%[[CARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG1:.*]]: tensor<i32>, %[[CARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
func @while_cond(%arg0: tensor<!tf.resource>, %arg1: tensor<i32>) -> tensor<i32> {
// CHECK-NEXT: return %[[CARG1]]
return %arg1 : tensor<i32>
}
// -----
// Tests If op with access to the tensor array defined outside and its gradient
// defined inside. The gradient creation should be moved outside.
// CHECK-LABEL: func @main
func @main() -> () {
// CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
%size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
%index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
%ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
// CHECK: %[[COND:.*]] = "tf._SomeOp"() : () -> tensor<i1>
%cond = "tf._SomeOp"() : () -> tensor<i1>
// CHECK: %[[GVAR1:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
// CHECK: %[[GVAR2:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
// CHECK: "tf.If"(%[[COND]], %[[VAR]], %[[GVAR1]], %[[GVAR2]])
%1 = "tf.If"(%cond, %ta#0) {
then_branch = @then_branch, else_branch = @else_branch, device = "", is_stateless = false}
: (tensor<i1>, tensor<!tf.resource>) -> tensor<!tf.resource>
// CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
// CHECK: "tf.Slice"(%[[READ]],
%read = "tf.TensorArrayReadV3"(%1, %index, %ta#1) : (tensor<!tf.resource>, tensor<i32>, tensor<f32>) -> tensor<3xf32>
return
}
// CHECK: func @then_branch(%[[TARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[TARG1:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[TARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
func @then_branch(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
%const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%elem = "tf._SomeOp"() : () -> tensor<3xf32>
%flow = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[TARG1]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
// CHECK: %[[UPDATE1:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ1]],
// CHECK: "tf.AssignVariableOp"(%[[TARG1]], %[[UPDATE1]])
%grad:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "a"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
%gwrite = "tf.TensorArrayWriteV3"(%grad#0, %const1, %elem, %grad#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
// CHECK: return %[[TARG0]]
return %arg0 : tensor<!tf.resource>
}
// CHECK: func @else_branch(%[[EARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[EARG1:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[EARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
func @else_branch(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
// CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%elem = "tf._SomeOp"() : () -> tensor<3xf32>
%flow = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[READ2:.*]] = "tf.ReadVariableOp"(%[[EARG2]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
// CHECK: %[[UPDATE2:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ2]],
// CHECK: "tf.AssignVariableOp"(%[[EARG2]], %[[UPDATE2]])
%grad:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "b"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
%gwrite = "tf.TensorArrayWriteV3"(%grad#0, %const1, %elem, %grad#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
// CHECK: return %[[EARG0]]
return %arg0 : tensor<!tf.resource>
}
// -----
// Tests (Stateful)PartitionedCall op with access to the tensor array defined
// outside and its gradient defined inside. The gradient creation should be
// moved outside.
// CHECK-LABEL: func @main
func @main() -> () {
// CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
%size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
%index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
%ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
// CHECK: %[[COND:.*]] = "tf._SomeOp"() : () -> tensor<i1>
%cond = "tf._SomeOp"() : () -> tensor<i1>
// CHECK: %[[GVAR1:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
%grad:2 = "tf.TensorArrayGradV3"(%ta#0, %ta#1) {source = "a"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
// CHECK: %[[GVAR2:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
// CHECK: "tf.StatefulPartitionedCall"(%[[VAR]], %[[GVAR1]], %[[GVAR2]])
// CHECK-SAME: f = @callee_tensorarray_decomposed
%call = "tf.StatefulPartitionedCall"(%ta#0) {f = @callee, config = "", config_proto = "", executor_type = ""}
: (tensor<!tf.resource>) -> tensor<!tf.resource>
// CHECK: "tf.PartitionedCall"(%[[VAR]], %[[GVAR1]], %[[GVAR2]])
// CHECK-SAME: f = @callee_tensorarray_decomposed
%call2 = "tf.PartitionedCall"(%call) {f = @callee, config = "", config_proto = "", executor_type = ""}
: (tensor<!tf.resource>) -> tensor<!tf.resource>
// CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
// CHECK: "tf.Slice"(%[[READ]],
%read = "tf.TensorArrayReadV3"(%call2, %index, %ta#1) : (tensor<!tf.resource>, tensor<i32>, tensor<f32>) -> tensor<3xf32>
return
}
// CHECK-LABEL: func @callee
// CHECK-SAME: (%[[OCARG0:.*]]: tensor<!tf.resource>) -> tensor<!tf.resource>
func @callee(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
%const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%elem = "tf._SomeOp"() : () -> tensor<3xf32>
%flow = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
%grad:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "a"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
%gwrite = "tf.TensorArrayWriteV3"(%grad#0, %const1, %elem, %grad#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
%grad2:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "b"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
%gwrite2 = "tf.TensorArrayWriteV3"(%grad2#0, %const1, %elem, %grad2#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
return %arg0 : tensor<!tf.resource>
}
// CHECK: func @callee_tensorarray_decomposed(%[[CARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG1:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
// CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[CARG1]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
// CHECK: %[[UPDATE1:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ1]],
// CHECK: "tf.AssignVariableOp"(%[[CARG1]], %[[UPDATE1]])
// CHECK: %[[READ2:.*]] = "tf.ReadVariableOp"(%[[CARG2]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
// CHECK: %[[UPDATE2:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ2]],
// CHECK: "tf.AssignVariableOp"(%[[CARG2]], %[[UPDATE2]])
// CHECK: return %[[CARG0]]
// -----
// Tests that the pass reports failure on an unknown size.
func @main(%arg0: tensor<i32>) -> () {
// expected-error @+1 {{unknown max element count}}
%ta:2 = "tf.TensorArrayV3"(%arg0) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
return
}
// -----
// Tests that the pass reports failure on an unknown shape.
func @main(%arg0: tensor<i32>) -> () {
%size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
// expected-error @+1 {{unknown element shape}}
%ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$unknown_rank: true", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
return
}
// -----
// Tests that the pass reports error on ambiguous tensor array.
func @main(%arg0: tensor<i1>) -> () {
%size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
%ta0:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
%ta1:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
%if_op = "tf.If"(%arg0, %ta0#0, %ta1#0) {then_branch = @if_then, else_branch = @if_else, is_stateless = false}
: (tensor<i1>, tensor<!tf.resource>, tensor<!tf.resource>) -> tensor<!tf.resource>
%index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// expected-error @+1 {{unknown tensor array}}
%read = "tf.TensorArrayReadV3"(%if_op, %index, %ta0#1) : (tensor<!tf.resource>, tensor<i32>, tensor<f32>) -> tensor<3xf32>
return
}
func @if_then(%arg0: tensor<!tf.resource>, %arg1: tensor<!tf.resource>) -> tensor<!tf.resource> {
return %arg0 : tensor<!tf.resource>
}
func @if_else(%arg0: tensor<!tf.resource>, %arg1: tensor<!tf.resource>) -> tensor<!tf.resource> {
return %arg1 : tensor<!tf.resource>
}

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h"
#include <algorithm>
#include <climits>
#include <cstdint>
@ -49,6 +50,9 @@ enum EinsumEquation {
FourDMatrixDotProd,
ThreeDReshapeTail,
FourDBatchMatMul,
BroadcastMatMul,
ReduceSum,
TransposeMatMul,
UnsupportedEquation
};
@ -121,6 +125,18 @@ EinsumEquation parseEquation(const std::vector<EquationToken>& eqn) {
if (is_equal(eqn, {A, B, C, COMMA, C, D, E, ARROW, A, B, D, E})) {
return EinsumEquation::ThreeDReshapeTail;
}
// BFH,HO->BFO
if (is_equal(eqn, {A, B, C, COMMA, C, D, ARROW, A, B, D})) {
return EinsumEquation::BroadcastMatMul;
}
// LBH,BL->BH
if (is_equal(eqn, {A, B, C, COMMA, B, A, ARROW, B, C})) {
return EinsumEquation::ReduceSum;
}
// LBH,BKL->BKH
if (is_equal(eqn, {A, B, C, COMMA, B, D, A, ARROW, B, D, C})) {
return EinsumEquation::TransposeMatMul;
}
return EinsumEquation::UnsupportedEquation;
}
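The new templates above match a letter-canonicalized form of the equation, where indices are renamed A, B, C, ... in order of first appearance. A small stand-alone sketch of that canonicalization (a hypothetical helper, not the tokenizer used by the pass):

#include <iostream>
#include <map>
#include <string>

std::string Canonicalize(const std::string& eqn) {
  std::map<char, char> rename;
  std::string out;
  char next = 'A';
  for (char c : eqn) {
    if (c == ',' || c == '-' || c == '>') { out += c; continue; }
    if (!rename.count(c)) rename[c] = next++;
    out += rename[c];
  }
  return out;
}

int main() {
  std::cout << Canonicalize("bfh,ho->bfo") << "\n";   // ABC,CD->ABD  (BroadcastMatMul)
  std::cout << Canonicalize("lbh,bl->bh") << "\n";    // ABC,BA->BC   (ReduceSum)
  std::cout << Canonicalize("lbh,bkl->bkh") << "\n";  // ABC,BDA->BDC (TransposeMatMul)
  return 0;
}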
@ -151,6 +167,28 @@ TF::TransposeOp createTransposeOp(Value value, Location loc,
perm_op);
}
TF::SumOp createSumOp(Value value, Location loc,
llvm::ArrayRef<int32_t> redux_axes,
PatternRewriter* rewriter) {
auto value_type = value.getType().cast<RankedTensorType>();
auto shape = value_type.getShape();
auto redux_type = RankedTensorType::get(
{static_cast<int32_t>(redux_axes.size())}, rewriter->getIntegerType(32));
auto redux_attr = DenseElementsAttr::get(redux_type, redux_axes);
auto redux_op = rewriter->create<ConstantOp>(loc, redux_type, redux_attr);
std::vector<int64_t> sum_shape(shape.size() - redux_axes.size());
int count = 0;
for (int i = 0; i < shape.size(); ++i) {
if (std::find(redux_axes.begin(), redux_axes.end(), i) ==
redux_axes.end()) {
sum_shape[count] = shape[i];
count++;
}
}
auto sum_type = RankedTensorType::get(sum_shape, value_type.getElementType());
return rewriter->create<TF::SumOp>(loc, sum_type, value, redux_op);
}
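createSumOp derives the result shape by dropping every reduced axis; the same computation in stand-alone C++ (illustrative only):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Keep each dimension whose index is not listed in redux_axes, preserving order.
std::vector<int64_t> ReducedShape(const std::vector<int64_t>& shape,
                                  const std::vector<int32_t>& redux_axes) {
  std::vector<int64_t> out;
  for (int32_t i = 0; i < static_cast<int32_t>(shape.size()); ++i) {
    if (std::find(redux_axes.begin(), redux_axes.end(), i) == redux_axes.end())
      out.push_back(shape[i]);
  }
  return out;
}

int main() {
  // Summing a 5x7x2 tensor over axis 2 yields a 5x7 result, matching the
  // einsum_reducesum test above.
  assert((ReducedShape({5, 7, 2}, {2}) == std::vector<int64_t>{5, 7}));
  return 0;
}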
TF::ReshapeOp createReshapeOp(Value value, ArrayRef<int64_t> shape,
Type element_type, Location loc,
PatternRewriter* rewriter) {
@ -173,7 +211,6 @@ LogicalResult ConvertTFEinsumOp::matchAndRewrite(
Value lhs = op.getOperand(0);
Value rhs = op.getOperand(1);
Location loc = op.getLoc();
if (!lhs.getType().isa<RankedTensorType>()) {
// LHS must be a ranked tensor type
return failure();
@ -193,10 +230,10 @@ LogicalResult ConvertTFEinsumOp::matchAndRewrite(
return failure();
}
// Currently support use cases of LHS, RHS dims = 3 or 4
// Currently supported use cases: LHS dims in {3, 4}, RHS dims in {2, 3, 4}.
const int dims_lhs = lhs_shape.size();
const int dims_rhs = rhs_shape.size();
if (dims_rhs < 3 || dims_rhs > 4 || dims_lhs < 3 || dims_lhs > 4) {
if (dims_lhs < 3 || dims_lhs > 4 || dims_rhs < 2 || dims_rhs > 4) {
return failure();
}
@ -209,6 +246,46 @@ LogicalResult ConvertTFEinsumOp::matchAndRewrite(
rewriter.replaceOp(op, bmm_op.getResult());
return success();
}
if (einsum_eqn == EinsumEquation::BroadcastMatMul) {
// Case "BFH,HO->BFO"
auto bmm_op = rewriter.create<TF::BatchMatMulV2Op>(
loc, ArrayRef<Type>{output_type}, lhs, rhs, rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false));
rewriter.replaceOp(op, bmm_op.getResult());
return success();
}
if (einsum_eqn == EinsumEquation::ReduceSum) {
// Case "LBH,BL->BH"
// Transpose LHS
lhs = createTransposeOp(lhs, loc, {1, 2, 0}, &rewriter);
// Reshape RHS
auto rhs_element_type = rhs_type.getElementType();
const int rhs_dim0 = rhs_shape[0];
const int rhs_dim1 = rhs_shape[1];
auto reshaped_rhs = createReshapeOp(rhs, {rhs_dim0, 1, rhs_dim1},
rhs_element_type, loc, &rewriter);
auto mul_op = rewriter.create<TF::MulOp>(loc, lhs, reshaped_rhs);
auto sum_op = createSumOp(mul_op, loc, {2}, &rewriter);
rewriter.replaceOp(op, {sum_op.getResult()});
return success();
}
if (einsum_eqn == EinsumEquation::TransposeMatMul) {
// Case "LBH,BKL->BKH"
// Transpose LHS
lhs = createTransposeOp(lhs, loc, {1, 2, 0}, &rewriter);
// Transpose RHS
rhs = createTransposeOp(rhs, loc, {0, 2, 1}, &rewriter);
std::vector<int64_t> bmm_shape = {lhs_shape[1], lhs_shape[2], rhs_shape[1]};
auto bmm_type = RankedTensorType::get(bmm_shape, rhs_type.getElementType());
auto bmm_op = rewriter.create<TF::BatchMatMulV2Op>(
loc, ArrayRef<Type>{bmm_type}, lhs, rhs, rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false));
auto trans_bmm = createTransposeOp(bmm_op, loc, {0, 2, 1}, &rewriter);
rewriter.replaceOp(op, {trans_bmm.getResult()});
return success();
}
if (einsum_eqn == EinsumEquation::ThreeDReshapeTail) {
// Case "BFD,DNH->BFNH"
auto lhs_type = lhs.getType().cast<RankedTensorType>();
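As a sanity check on the ReduceSum lowering above ("LBH,BL->BH"), the transpose/reshape/broadcast-multiply/sum pipeline reproduces the direct einsum definition; a tiny numeric sketch (illustrative only):

#include <cassert>

int main() {
  constexpr int L = 2, B = 3, H = 4;
  double lhs[L][B][H], rhs[B][L];
  for (int l = 0; l < L; ++l)
    for (int b = 0; b < B; ++b)
      for (int h = 0; h < H; ++h) lhs[l][b][h] = 1.0 + l + 10.0 * b + 100.0 * h;
  for (int b = 0; b < B; ++b)
    for (int l = 0; l < L; ++l) rhs[b][l] = 0.5 * b - l;

  // Lowered form: tf.Transpose(lhs, [1, 2, 0]) and tf.Reshape(rhs, [B, 1, L]).
  double transposed[B][H][L];
  for (int l = 0; l < L; ++l)
    for (int b = 0; b < B; ++b)
      for (int h = 0; h < H; ++h) transposed[b][h][l] = lhs[l][b][h];
  double reshaped[B][1][L];
  for (int b = 0; b < B; ++b)
    for (int l = 0; l < L; ++l) reshaped[b][0][l] = rhs[b][l];

  // Broadcast multiply over H, then tf.Sum over axis 2 (the L axis).
  for (int b = 0; b < B; ++b)
    for (int h = 0; h < H; ++h) {
      double direct = 0.0, lowered = 0.0;
      for (int l = 0; l < L; ++l) {
        direct += lhs[l][b][h] * rhs[b][l];                  // einsum definition
        lowered += transposed[b][h][l] * reshaped[b][0][l];  // decomposed form
      }
      assert(direct == lowered);
    }
  return 0;
}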

View File

@ -67,7 +67,7 @@ class SwitchFoldPass : public mlir::FunctionPass<SwitchFoldPass> {
// Returns the defining op for a value looking through islands.
static Operation* GetDefiningOp(Value val) {
Operation* op = val.getDefiningOp();
auto island_op = dyn_cast<tf_executor::IslandOp>(op);
auto island_op = dyn_cast_or_null<tf_executor::IslandOp>(op);
if (!island_op) return op;
auto yield_op = island_op.GetYield();
auto index = val.cast<mlir::OpResult>().getResultNumber();
@ -84,7 +84,8 @@ static Operation* GetDefiningOp(Value val) {
static Value LookThroughIdentityOp(Value pred_val) {
if (!pred_val) return pred_val;
auto op = GetDefiningOp(pred_val);
if (auto id_op = dyn_cast<TF::IdentityOp>(op)) pred_val = id_op.input();
if (auto id_op = dyn_cast_or_null<TF::IdentityOp>(op))
pred_val = id_op.input();
return pred_val;
}
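The switch to dyn_cast_or_null above matters because a value such as a block argument has no defining op, so GetDefiningOp may be handed a null pointer. A plain-C++ analogue of the distinction (using dynamic_cast in place of LLVM's casting utilities):

#include <cassert>

struct Operation { virtual ~Operation() = default; };
struct IslandOp : Operation {};

// Analogue of dyn_cast: assumes the pointer is non-null.
IslandOp* DynCast(Operation* op) {
  assert(op && "dyn_cast requires a non-null operation");
  return dynamic_cast<IslandOp*>(op);
}

// Analogue of dyn_cast_or_null: tolerates null and simply returns null.
IslandOp* DynCastOrNull(Operation* op) {
  return op ? dynamic_cast<IslandOp*>(op) : nullptr;
}

int main() {
  Operation* no_defining_op = nullptr;  // e.g. the predicate is a block argument
  assert(DynCastOrNull(no_defining_op) == nullptr);  // safe, as in the fix above
  // DynCast(no_defining_op) would trip the assertion, mirroring the original crash.
  IslandOp island;
  assert(DynCast(&island) != nullptr);
  return 0;
}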

Some files were not shown because too many files have changed in this diff