Merge branch 'master' into add_xnnpack_to_label_image_again
commit 9d022cb885

.bazelrc (4 changes)
@@ -19,10 +19,10 @@
# Compiler options:
# cuda_clang: Use clang when building CUDA code.
# c++17: Build with C++17 options
# C++1z: Build with C++17 options
# c++1z: Build with C++17 options
# avx_linux: Build with avx instruction set on linux.
# avx2_linux: Build with avx2 instruction set on linux.
# arch_native_linux: Build with instruction sets available to the host machine on linux
# native_arch_linux: Build with instruction sets available to the host machine on linux
# avx_win: Build with avx instruction set on windows
# avx2_win: Build with avx2 instruction set on windows
#
@@ -88,6 +88,9 @@ TensorFlow coding style.
submitting PRs to fix one typo, one warning, etc. We recommend fixing the
same issue at the file level at least (e.g.: fix all typos in a file, fix
all compiler warnings in a file, etc.)
* Tests should follow the
  [testing best practices](https://www.tensorflow.org/community/contribute/tests)
  guide.

#### License
@@ -135,6 +135,7 @@ Build Type | Status
* [TensorFlow Examples](https://github.com/tensorflow/examples)
* [TensorFlow in Practice from Coursera](https://www.coursera.org/specializations/tensorflow-in-practice)
* [TensorFlow: Data and Deployment from Coursera](https://www.coursera.org/specializations/tensorflow-data-and-deployment)
* [Getting Started with TensorFlow 2 from Coursera](https://www.coursera.org/learn/getting-started-with-tensor-flow2)
* [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187)
* [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190)
* [TensorFlow Blog](https://blog.tensorflow.org)
@@ -41,6 +41,11 @@ import sys as _sys
from tensorflow.python.tools import module_util as _module_util
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader

# Make sure code inside the TensorFlow codebase can use tf2.enabled() at import.
_os.environ['TF2_BEHAVIOR'] = '1'
from tensorflow.python import tf2 as _tf2
_tf2.enable()

# API IMPORTS PLACEHOLDER

# WRAPPER_PLACEHOLDER
@@ -683,7 +683,11 @@ TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type,

  tensorflow::Tensor tensor(dtype, tensorflow::TensorShape({}));
  std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len);
  return TFE_TensorHandle::CreateLocalHandle(tensor, status);

  status->status = tensorflow::Status::OK();
  return new TFE_TensorHandle{
      std::make_unique<tensorflow::TensorHandleInterface>(
          tensorflow::TensorHandle::CreateLocalHandle(tensor))};
}

namespace {
@ -703,7 +707,8 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
|
||||
} while (0);
|
||||
|
||||
// New server created for new server_def. Unused if updating server_def.
|
||||
tensorflow::EagerContext* context = ctx->context;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::GrpcServer* grpc_server =
|
||||
dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
|
||||
if (grpc_server == nullptr) {
|
||||
@ -822,8 +827,7 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
node_def.add_input("dummy_input");
|
||||
}
|
||||
tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||
tfe_op->operation.get())
|
||||
OperationFromInterface(tfe_op->operation)
|
||||
->Attrs()
|
||||
.FillAttrValueMap(node_def.mutable_attr());
|
||||
|
||||
|
@ -28,6 +28,9 @@ tf_cuda_library(
|
||||
"c_api_debug.cc",
|
||||
"c_api_experimental.h",
|
||||
"c_api_internal.h",
|
||||
"c_api_unified_experimental.h",
|
||||
"context_interface.cc",
|
||||
"context_interface.h",
|
||||
"operation_interface.cc",
|
||||
"operation_interface.h",
|
||||
"tensor_handle_interface.h",
|
||||
@ -62,6 +65,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core/platform:errors",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"@com_google_absl//absl/types:variant",
|
||||
],
|
||||
}) + select({
|
||||
"//tensorflow:with_xla_support": [
|
||||
@ -95,6 +99,8 @@ filegroup(
|
||||
srcs = [
|
||||
"c_api_experimental.h",
|
||||
"c_api_internal.h",
|
||||
"c_api_unified_experimental.h",
|
||||
"context_interface.h",
|
||||
"dlpack.h",
|
||||
"operation_interface.h",
|
||||
"tensor_handle_interface.h",
|
||||
@ -109,6 +115,8 @@ tf_cuda_library(
|
||||
name = "c_api_internal",
|
||||
srcs = [
|
||||
"c_api_experimental.h",
|
||||
"c_api_unified_experimental.h",
|
||||
"context_interface.h",
|
||||
"operation_interface.h",
|
||||
"tensor_handle_interface.h",
|
||||
],
|
||||
@ -206,7 +214,6 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
"//tensorflow/core/platform:casts",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
@ -215,8 +222,12 @@ tf_cuda_library(
|
||||
name = "c_api_experimental",
|
||||
srcs = [
|
||||
"c_api_experimental.cc",
|
||||
"c_api_unified_experimental.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"c_api_experimental.h",
|
||||
"c_api_unified_experimental.h",
|
||||
],
|
||||
hdrs = ["c_api_experimental.h"],
|
||||
copts = tf_copts() + tfe_xla_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
@ -242,6 +253,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/types:variant",
|
||||
],
|
||||
}) + select({
|
||||
"//tensorflow:with_xla_support": [
|
||||
@ -293,6 +305,30 @@ tf_cuda_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "c_api_unified_experimental_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"c_api_unified_experimental_test.cc",
|
||||
],
|
||||
args = ["--heap_check=local"],
|
||||
extra_copts = tfe_xla_copts(),
|
||||
linkstatic = tf_kernel_tests_linkstatic(),
|
||||
tags = tf_cuda_tests_tags() + ["nomac"],
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
":c_api_test_util",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/cc/profiler",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "custom_device_test",
|
||||
size = "small",
|
||||
|
@ -305,7 +305,9 @@ tensorflow::Status CreateRemoteContexts(
|
||||
server_def.default_session_config());
|
||||
|
||||
std::vector<bool> filtered_device_mask;
|
||||
ctx->context->FilterDevicesForRemoteWorkers(
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
context->FilterDevicesForRemoteWorkers(
|
||||
remote_worker, base_request.cluster_device_attributes(),
|
||||
&filtered_device_mask);
|
||||
DCHECK_EQ(filtered_device_mask.size(),
|
||||
@ -388,7 +390,9 @@ tensorflow::Status UpdateRemoteContexts(
|
||||
}
|
||||
|
||||
std::vector<bool> filtered_device_mask;
|
||||
ctx->context->FilterDevicesForRemoteWorkers(
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
context->FilterDevicesForRemoteWorkers(
|
||||
remote_worker, base_request.cluster_device_attributes(),
|
||||
&filtered_device_mask);
|
||||
DCHECK_EQ(filtered_device_mask.size(), cluster_device_count);
|
||||
@ -467,7 +471,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
|
||||
// New server created for new server_def. Unused if updating server_def.
|
||||
std::unique_ptr<tensorflow::ServerInterface> new_server;
|
||||
tensorflow::EagerContext* context = ctx->context;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::GrpcServer* grpc_server;
|
||||
if (reset_context) {
|
||||
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
|
||||
@ -696,14 +701,16 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
||||
tensorflow::Rendezvous* r =
|
||||
new tensorflow::IntraProcessRendezvous(device_mgr.get());
|
||||
|
||||
return new TFE_Context{new tensorflow::EagerContext(
|
||||
opts->session_options.options,
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
||||
opts->device_placement_policy),
|
||||
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
|
||||
opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
|
||||
/*device_mgr_owned*/ true, r,
|
||||
tensorflow::GetDefaultCustomKernelCreator())};
|
||||
return new TFE_Context{std::make_unique<tensorflow::ContextInterface>(
|
||||
new tensorflow::EagerContext(
|
||||
opts->session_options.options,
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
||||
opts->device_placement_policy),
|
||||
static_cast<tensorflow::ContextMirroringPolicy>(
|
||||
opts->mirroring_policy),
|
||||
opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
|
||||
/*device_mgr_owned*/ true, r,
|
||||
tensorflow::GetDefaultCustomKernelCreator()))};
|
||||
}
|
||||
|
||||
TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
|
||||
@ -714,20 +721,24 @@ TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
|
||||
tensorflow::Rendezvous* r =
|
||||
new tensorflow::IntraProcessRendezvous(device_mgr);
|
||||
|
||||
return new TFE_Context{new tensorflow::EagerContext(
|
||||
opts->session_options.options,
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
||||
opts->device_placement_policy),
|
||||
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
|
||||
opts->async, opts->lazy_remote_inputs_copy, device_mgr,
|
||||
/*device_mgr_owned*/ false, r,
|
||||
tensorflow::GetDefaultCustomKernelCreator())};
|
||||
return new TFE_Context{std::make_unique<tensorflow::ContextInterface>(
|
||||
new tensorflow::EagerContext(
|
||||
opts->session_options.options,
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
||||
opts->device_placement_policy),
|
||||
static_cast<tensorflow::ContextMirroringPolicy>(
|
||||
opts->mirroring_policy),
|
||||
opts->async, opts->lazy_remote_inputs_copy, device_mgr,
|
||||
/*device_mgr_owned*/ false, r,
|
||||
tensorflow::GetDefaultCustomKernelCreator()))};
|
||||
}
|
||||
|
||||
void TFE_DeleteContext(TFE_Context* ctx) {
  // context->RefCountIsOne() should be true here.
  // TODO(iga): Remove EagerContext refcounting.
  ctx->context->Unref();
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(ctx->context);
  context->Unref();

  delete ctx;
}
@@ -739,7 +750,9 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
}

void TFE_ContextClearCaches(TFE_Context* ctx) {
  ctx->context->ClearCachesAndThreadExecutors();
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(ctx->context);
  context->ClearCachesAndThreadExecutors();
}
|
||||
|
||||
// Set server_def on the context, possibly updating it.
|
||||
@ -769,8 +782,10 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
|
||||
device_filters[i] = tdf.second.device_filters(i);
|
||||
}
|
||||
const string remote_worker = remote_prefix + std::to_string(task_index);
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
status->status =
|
||||
ctx->context->SetRemoteDeviceFilters(remote_worker, device_filters);
|
||||
context->SetRemoteDeviceFilters(remote_worker, device_filters);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -789,11 +804,13 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
|
||||
"TFE_ContextSetServerDef not supported on mobile");
|
||||
#else // !defined(IS_MOBILE_PLATFORM)
|
||||
tensorflow::ServerDef server_def;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
if (!server_def.ParseFromArray(proto, proto_len)) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Invalid tensorflow.ServerDef protocol buffer");
|
||||
return;
|
||||
} else if (ctx->context->GetContextId() ==
|
||||
} else if (context->GetContextId() ==
|
||||
tensorflow::EagerContext::kInvalidContextId) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Trying to update a context with invalid context id.");
|
||||
@ -817,7 +834,8 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
|
||||
"TFE_ContextSetServerDef not supported on mobile");
|
||||
return false;
|
||||
#else // !defined(IS_MOBILE_PLATFORM)
|
||||
tensorflow::EagerContext* context = ctx->context;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::GrpcServer* grpc_server =
|
||||
static_cast<tensorflow::GrpcServer*>(context->GetServer());
|
||||
|
||||
@ -872,13 +890,17 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
|
||||
#if defined(IS_MOBILE_PLATFORM)
|
||||
status->status = tensorflow::Status::OK();
|
||||
#else // !defined(IS_MOBILE_PLATFORM)
|
||||
status->status = ctx->context->SyncExecutors();
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
status->status = context->SyncExecutors();
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
|
||||
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
||||
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
|
||||
ctx->context->SetThreadLocalDevicePlacementPolicy(
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
context->SetThreadLocalDevicePlacementPolicy(
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
|
||||
}
|
||||
|
||||
@ -887,15 +909,20 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
||||
// safe to call this function from the async EagerExecutor threads.
|
||||
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
|
||||
TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
return static_cast<TFE_ContextDevicePlacementPolicy>(
|
||||
ctx->context->GetDevicePlacementPolicy());
|
||||
context->GetDevicePlacementPolicy());
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
|
||||
tensorflow::Tensor tensor;
|
||||
status->status = tensorflow::TF_TensorToTensor(t, &tensor);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
return TFE_TensorHandle::CreateLocalHandle(tensor, status);
|
||||
|
||||
return new TFE_TensorHandle{
|
||||
std::make_unique<tensorflow::TensorHandleInterface>(
|
||||
tensorflow::TensorHandle::CreateLocalHandle(tensor))};
|
||||
}
|
||||
|
||||
void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
|
||||
@ -1050,10 +1077,12 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
|
||||
}
|
||||
|
||||
return new TFE_TensorHandle{
|
||||
std::unique_ptr<AbstractTensorHandleInterface>(h->handle->Copy())};
|
||||
std::unique_ptr<tensorflow::AbstractTensorHandleInterface>(
|
||||
h->handle->Copy())};
|
||||
}
|
||||
|
||||
AbstractTensorHandleInterface* tensorflow::TensorHandleInterface::Copy() {
|
||||
tensorflow::AbstractTensorHandleInterface*
|
||||
tensorflow::TensorHandleInterface::Copy() {
|
||||
handle_->Ref();
|
||||
return new TensorHandleInterface(handle_);
|
||||
}
|
||||
@ -1069,13 +1098,22 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return h->handle->Resolve(&status->status);
|
||||
std::unique_ptr<tensorflow::AbstractTensorInterface> t =
|
||||
h->handle->Resolve(&status->status);
|
||||
if (t == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
tensorflow::Tensor tensor = tensorflow::TensorFromInterface(t);
|
||||
return tensorflow::TF_TensorFromTensor(tensor, &status->status);
|
||||
}
|
||||
|
||||
TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
|
||||
std::unique_ptr<tensorflow::AbstractTensorInterface>
|
||||
tensorflow::TensorHandleInterface::Resolve(Status* status) {
|
||||
if (!IsValid(status)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (VariantDeviceIsCustom(handle_->device())) {
|
||||
tensorflow::CustomDevice* custom_device =
|
||||
absl::get<tensorflow::CustomDevice*>(handle_->device());
|
||||
@ -1104,7 +1142,7 @@ TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
|
||||
h_cpu->Unref();
|
||||
return nullptr;
|
||||
}
|
||||
TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, status);
|
||||
auto retval = std::make_unique<tensorflow::TensorInterface>(*t);
|
||||
h_cpu->Unref();
|
||||
return retval;
|
||||
} else {
|
||||
@ -1131,7 +1169,7 @@ TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
|
||||
if (!status->ok()) return nullptr;
|
||||
}
|
||||
}
|
||||
return tensorflow::TF_TensorFromTensor(tensor, status);
|
||||
return std::make_unique<tensorflow::TensorInterface>(std::move(tensor));
|
||||
}
|
||||
}
|
||||
|
||||
@ -1142,8 +1180,7 @@ void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
|
||||
return nullptr;
|
||||
}
|
||||
tensorflow::TensorHandle* handle =
|
||||
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
|
||||
->Handle();
|
||||
tensorflow::TensorHandleFromInterface(h->handle);
|
||||
if (VariantDeviceIsCustom(handle->device())) {
|
||||
const tensorflow::Tensor* t;
|
||||
status->status = handle->Tensor(&t);
|
||||
@ -1178,7 +1215,8 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
||||
void (*deallocator)(void* data, size_t len, void* arg),
|
||||
void* deallocator_arg, TF_Status* status) {
|
||||
tensorflow::Device* device = nullptr;
|
||||
tensorflow::EagerContext* context = ctx->context;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
status->status = context->FindDeviceFromName(device_name, &device);
|
||||
tensorflow::CustomDevice* custom_device = nullptr;
|
||||
if (!status->status.ok()) {
|
||||
@ -1203,19 +1241,17 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
||||
tensorflow::Tensor t(static_cast<tensorflow::DataType>(dtype),
|
||||
tensorflow::TensorShape(dimvec), buf);
|
||||
buf->Unref();
|
||||
tensorflow::TensorHandle* ret_handle;
|
||||
if (custom_device == nullptr) {
|
||||
status->status = tensorflow::TensorHandle::CreateLocalHandle(
|
||||
std::move(t), device, device, context, &ret_handle);
|
||||
return new TFE_TensorHandle{
|
||||
std::make_unique<tensorflow::TensorHandleInterface>(
|
||||
tensorflow::TensorHandle::CreateLocalHandle(std::move(t), device,
|
||||
device, context))};
|
||||
} else {
|
||||
status->status = tensorflow::TensorHandle::CreateLocalHandle(
|
||||
std::move(t), custom_device, context, &ret_handle);
|
||||
return new TFE_TensorHandle{
|
||||
std::make_unique<tensorflow::TensorHandleInterface>(
|
||||
tensorflow::TensorHandle::CreateLocalHandle(
|
||||
std::move(t), custom_device, context))};
|
||||
}
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return new TFE_TensorHandle{
|
||||
std::make_unique<tensorflow::TensorHandleInterface>(ret_handle)};
|
||||
}
|
||||
|
||||
// This function will block till the operation that produces `h` has
|
||||
@ -1229,9 +1265,7 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
|
||||
return 0;
|
||||
}
|
||||
tensorflow::TensorHandle* handle =
|
||||
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
|
||||
->Handle();
|
||||
|
||||
tensorflow::TensorHandleFromInterface(h->handle);
|
||||
if (handle->IsRemote()) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"TFE_TensorHandleDeviceMemorySize may not be called on a remote tensor "
|
||||
@ -1248,8 +1282,7 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
|
||||
|
||||
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op> new_op(
|
||||
new TFE_Op{std::make_unique<tensorflow::OperationInterface>(ctx)});
|
||||
std::unique_ptr<TFE_Op> new_op(new TFE_Op{ctx->context->CreateOperation()});
|
||||
status->status = new_op->operation->Reset(op_or_function_name, nullptr);
|
||||
if (!status->status.ok()) {
|
||||
new_op.reset();
|
||||
@ -1285,8 +1318,8 @@ void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
|
||||
|
||||
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
|
||||
TF_Status* status) {
|
||||
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>> handles(
|
||||
num_inputs);
|
||||
absl::FixedArray<std::unique_ptr<tensorflow::AbstractTensorHandleInterface>>
|
||||
handles(num_inputs);
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
handles[i].reset(inputs[i]->handle->Copy());
|
||||
}
|
||||
@ -1383,7 +1416,10 @@ void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
|
||||
|
||||
void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
|
||||
TF_Status* status) {
|
||||
status->status = op->operation->SetAttrTensor(attr_name, tensor);
|
||||
tensorflow::Tensor t;
|
||||
status->status = TF_TensorToTensor(tensor, &t);
|
||||
status->status = op->operation->SetAttrTensor(
|
||||
attr_name, std::make_unique<tensorflow::TensorInterface>(t));
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
|
||||
@ -1480,8 +1516,8 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
|
||||
|
||||
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
|
||||
TF_Status* status) {
|
||||
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>> handles(
|
||||
*num_retvals);
|
||||
absl::FixedArray<std::unique_ptr<tensorflow::AbstractTensorHandleInterface>>
|
||||
handles(*num_retvals);
|
||||
status->status = op->operation->Execute(&handles, num_retvals);
|
||||
if (!status->status.ok()) {
|
||||
return;
|
||||
@ -1497,17 +1533,15 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
|
||||
TF_Status* status) {
|
||||
tensorflow::TensorHandle* handle = nullptr;
|
||||
tensorflow::Device* device;
|
||||
tensorflow::EagerContext* context = ctx->context;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
status->status = context->FindDeviceFromName(device_name, &device);
|
||||
if (!status->status.ok()) {
|
||||
tensorflow::CustomDevice* dev;
|
||||
status->status = context->FindCustomDeviceFromName(device_name, &dev);
|
||||
if (status->status.ok()) {
|
||||
status->status = dev->CopyTensorToDevice(
|
||||
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||
h->handle.get())
|
||||
->Handle(),
|
||||
&handle);
|
||||
tensorflow::TensorHandleFromInterface(h->handle), &handle);
|
||||
if (status->status.ok()) {
|
||||
return new TFE_TensorHandle{
|
||||
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
|
||||
@ -1524,10 +1558,7 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
|
||||
status->status = context->FindCustomDeviceFromName(handle_device_name, &dev);
|
||||
if (status->status.ok()) {
|
||||
status->status = dev->CopyTensorFromDevice(
|
||||
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||
h->handle.get())
|
||||
->Handle(),
|
||||
device_name, &handle);
|
||||
tensorflow::TensorHandleFromInterface(h->handle), device_name, &handle);
|
||||
if (status->status.ok()) {
|
||||
return new TFE_TensorHandle{
|
||||
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
|
||||
@ -1537,9 +1568,8 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
|
||||
|
||||
// Handle regular case.
|
||||
status->status = tensorflow::EagerCopyToDevice(
|
||||
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
|
||||
->Handle(),
|
||||
context, &context->Executor(), device, false, &handle);
|
||||
tensorflow::TensorHandleFromInterface(h->handle), context,
|
||||
&context->Executor(), device, false, &handle);
|
||||
if (status->status.ok()) {
|
||||
return new TFE_TensorHandle{
|
||||
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
|
||||
@ -1556,41 +1586,56 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
|
||||
tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
|
||||
return;
|
||||
}
|
||||
status->status = ctx->context->AddFunctionDef(function_def);
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
status->status = context->AddFunctionDef(function_def);
|
||||
}
|
||||
|
||||
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
|
||||
TF_Status* status) {
|
||||
status->status = ctx->context->AddFunctionDef(function->fdef);
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
status->status = context->AddFunctionDef(function->fdef);
|
||||
}
|
||||
|
||||
void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
|
||||
TF_Status* status) {
|
||||
status->status = ctx->context->RemoveFunction(name);
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
status->status = context->RemoveFunction(name);
|
||||
}
|
||||
|
||||
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
|
||||
return ctx->context->FindFunctionDef(name) != nullptr;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
return context->FindFunctionDef(name) != nullptr;
|
||||
}
|
||||
|
||||
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
|
||||
ctx->context->SetShouldStoreGraphs(true);
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
context->SetShouldStoreGraphs(true);
|
||||
}
|
||||
|
||||
void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
|
||||
ctx->context->SetShouldStoreGraphs(false);
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
context->SetShouldStoreGraphs(false);
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
|
||||
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
|
||||
TF_Status* status) {
|
||||
return TFE_TensorHandle::CreateLocalHandle(t, status);
|
||||
return new TFE_TensorHandle{
|
||||
std::make_unique<tensorflow::TensorHandleInterface>(
|
||||
tensorflow::TensorHandle::CreateLocalHandle(t))};
|
||||
}
|
||||
|
||||
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context = ctx->context;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
status->status = context->Executor().WaitForAllPendingNodes();
|
||||
if (!status->status.ok()) return;
|
||||
tensorflow::mutex_lock ml(*context->MetadataMu());
|
||||
@ -1611,21 +1656,27 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context->StartStep(); }
|
||||
void TFE_ContextStartStep(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
context->StartStep();
|
||||
}
|
||||
|
||||
void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context->EndStep(); }
|
||||
void TFE_ContextEndStep(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
context->EndStep();
|
||||
}
|
||||
|
||||
void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs) {
|
||||
auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||
op->operation.get());
|
||||
*attrs = TFE_OpAttrs(&operation->Attrs(), op->operation->Name().c_str());
|
||||
tensorflow::EagerOperation* operation = OperationFromInterface(op->operation);
|
||||
*attrs = TFE_OpAttrs(&operation->Attrs(), operation->Name().c_str());
|
||||
}
|
||||
|
||||
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
|
||||
tensorflow::AttrValueMap m;
|
||||
attrs->attributes->FillAttrValueMap(&m);
|
||||
auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||
op->operation.get());
|
||||
tensorflow::EagerOperation* operation = OperationFromInterface(op->operation);
|
||||
tensorflow::AttrBuilder* destination = operation->MutableAttrs();
|
||||
for (auto attribute : m) {
|
||||
destination->Set(attribute.first, attribute.second);
|
||||
@ -1721,9 +1772,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
|
||||
TFE_TensorHandle* result_handle =
|
||||
device_.copy_tensor_to_device(context_, &tensor_handle, &status, info_);
|
||||
if (!status.status.ok()) return status.status;
|
||||
*result = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||
result_handle->handle.get())
|
||||
->Handle();
|
||||
*result = tensorflow::TensorHandleFromInterface(result_handle->handle);
|
||||
(*result)->Ref();
|
||||
delete result_handle;
|
||||
return status.status;
|
||||
@ -1740,9 +1789,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
|
||||
TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
|
||||
context_, &tensor_handle, target_device_name.c_str(), &status, info_);
|
||||
if (!status.status.ok()) return status.status;
|
||||
*result = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||
result_handle->handle.get())
|
||||
->Handle();
|
||||
*result = tensorflow::TensorHandleFromInterface(result_handle->handle);
|
||||
(*result)->Ref();
|
||||
delete result_handle;
|
||||
return status.status;
|
||||
@ -1766,9 +1813,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
|
||||
&attributes, num_retvals, outputs.data(), &status, info_);
|
||||
if (status.status.ok()) {
|
||||
for (int i = 0; i < *num_retvals; ++i) {
|
||||
retvals[i] = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||
outputs[i]->handle.get())
|
||||
->Handle();
|
||||
retvals[i] = tensorflow::TensorHandleFromInterface(outputs[i]->handle);
|
||||
retvals[i]->Ref();
|
||||
delete outputs[i];
|
||||
}
|
||||
@ -1793,6 +1838,8 @@ void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
|
||||
TF_Status* status) {
|
||||
auto custom_device =
|
||||
std::make_unique<CustomDeviceAPI>(ctx, device, device_info, device_name);
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
status->status =
|
||||
ctx->context->RegisterCustomDevice(device_name, std::move(custom_device));
|
||||
context->RegisterCustomDevice(device_name, std::move(custom_device));
|
||||
}
|
||||
|
@ -54,36 +54,32 @@ extern "C" {
|
||||
|
||||
TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
|
||||
TFE_TensorHandle* h, TF_Status* status) {
|
||||
return h->handle->TensorDebugInfo(&status->status);
|
||||
}
|
||||
|
||||
TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
|
||||
Status* status) {
|
||||
tensorflow::TensorHandle* handle = TensorHandleFromInterface(h->handle);
|
||||
const tensorflow::Tensor* tensor;
|
||||
*status = handle_->Tensor(&tensor);
|
||||
if (!status->ok()) {
|
||||
status->status = handle->Tensor(&tensor);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
#ifdef TENSORFLOW_EAGER_USE_XLA
|
||||
tensorflow::Device* device = absl::get<Device*>(handle_->device());
|
||||
auto* device = absl::get<tensorflow::Device*>(handle->device());
|
||||
|
||||
// If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
|
||||
tensorflow::XlaDevice* xla_device =
|
||||
dynamic_cast<tensorflow::XlaDevice*>(device);
|
||||
auto* xla_device = dynamic_cast<tensorflow::XlaDevice*>(device);
|
||||
if (xla_device != nullptr) {
|
||||
tensorflow::XlaDevice::PaddedShapeFn shape_fn =
|
||||
xla_device->metadata().padded_shape_fn();
|
||||
xla::Shape padded_shape;
|
||||
*status = shape_fn(*tensor, &padded_shape);
|
||||
if (!status->ok()) {
|
||||
status->status = shape_fn(*tensor, &padded_shape);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
if (VLOG_IS_ON(3)) {
|
||||
std::vector<int64> shape_to_log = TensorShapeAsVector(*handle_, status);
|
||||
if (!status->ok()) {
|
||||
std::vector<int64> shape_to_log =
|
||||
TensorShapeAsVector(*handle, &status->status);
|
||||
if (!status->status.ok()) {
|
||||
// Ignore the status here as we are simply logging.
|
||||
*status = tensorflow::Status::OK();
|
||||
status->status = tensorflow::Status::OK();
|
||||
} else {
|
||||
VLOG(3) << "Fully padded shape of ["
|
||||
<< absl::StrJoin(shape_to_log, ", ") << "] is "
|
||||
@ -96,7 +92,7 @@ TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
|
||||
// Currently, the only case of XlaTensor containing a tuple shape is to
|
||||
// represent 64 bit ints, doubles, and complex numbers (we don't support
|
||||
// 64bit complex numbers).
|
||||
*status = tensorflow::errors::InvalidArgument(
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"XlaTensors should only contain tuples of size 2. Shape: ",
|
||||
padded_shape.DebugString());
|
||||
return nullptr;
|
||||
@ -108,13 +104,13 @@ TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
|
||||
const xla::Shape& shape1 =
|
||||
xla::ShapeUtil::GetTupleElementShape(padded_shape, 1);
|
||||
if (shape0.IsTuple() || shape1.IsTuple()) {
|
||||
*status = tensorflow::errors::InvalidArgument(
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"XlaTensors should not contain nested tuples. Shape: ",
|
||||
padded_shape.DebugString());
|
||||
return nullptr;
|
||||
}
|
||||
if (!xla::ShapeUtil::Equal(shape0, shape1)) {
|
||||
*status = tensorflow::errors::InvalidArgument(
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Subshapes of XlaTensors should be the same. Shape: ",
|
||||
padded_shape.DebugString());
|
||||
return nullptr;
|
||||
@ -139,15 +135,15 @@ TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
|
||||
dev_dims.push_back(padded_shape.dimensions(dim_index));
|
||||
}
|
||||
}
|
||||
*status = tensorflow::Status::OK();
|
||||
status->status = tensorflow::Status::OK();
|
||||
return new TFE_TensorDebugInfo(dev_dims);
|
||||
}
|
||||
#endif // TENSORFLOW_EAGER_USE_XLA
|
||||
|
||||
// If the tensor is not an XLA tensor, the device shape is
|
||||
// the same as regular tensor shape.
|
||||
std::vector<int64> dev_dims = TensorShapeAsVector(*handle_, status);
|
||||
if (!status->ok()) {
|
||||
std::vector<int64> dev_dims = TensorShapeAsVector(*handle, &status->status);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return new TFE_TensorDebugInfo(dev_dims);
|
||||
|
@ -31,6 +31,7 @@ using tensorflow::string;
|
||||
void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
|
||||
const char* raw_device_name, TF_Status* status) {
|
||||
if (op_to_reset) {
|
||||
op_to_reset->operation->Clear();
|
||||
status->status =
|
||||
op_to_reset->operation->Reset(op_or_function_name, raw_device_name);
|
||||
} else {
|
||||
@ -40,11 +41,15 @@ void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
|
||||
}
|
||||
|
||||
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
|
||||
ctx->context->SetShouldStoreGraphs(true);
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
context->SetShouldStoreGraphs(true);
|
||||
}
|
||||
|
||||
void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
|
||||
ctx->context->SetShouldStoreGraphs(false);
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
context->SetShouldStoreGraphs(false);
|
||||
}
|
||||
|
||||
void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
|
||||
@ -474,7 +479,9 @@ void TFE_ContextOptionsSetMirroringPolicy(TFE_ContextOptions* options,
|
||||
|
||||
void TFE_ContextSetThreadLocalMirroringPolicy(
|
||||
TFE_Context* ctx, TFE_ContextMirroringPolicy policy) {
|
||||
ctx->context->SetThreadLocalMirroringPolicy(
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
context->SetThreadLocalMirroringPolicy(
|
||||
static_cast<tensorflow::ContextMirroringPolicy>(policy));
|
||||
}
|
||||
|
||||
@ -483,8 +490,9 @@ void TFE_ContextSetThreadLocalMirroringPolicy(
|
||||
// safe to call this function from the async EagerExecutor threads.
|
||||
extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
|
||||
TFE_Context* ctx) {
|
||||
return static_cast<TFE_ContextMirroringPolicy>(
|
||||
ctx->context->GetMirroringPolicy());
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
return static_cast<TFE_ContextMirroringPolicy>(context->GetMirroringPolicy());
|
||||
}
|
||||
|
||||
void TFE_ContextOptionsSetLazyRemoteInputsCopy(TFE_ContextOptions* options,
|
||||
@ -492,6 +500,10 @@ void TFE_ContextOptionsSetLazyRemoteInputsCopy(TFE_ContextOptions* options,
|
||||
options->lazy_remote_inputs_copy = lazy_copy;
|
||||
}
|
||||
|
||||
void TFE_ContextOptionsSetTfrt(TFE_ContextOptions* options, bool use_tfrt) {
|
||||
options->use_tfrt = use_tfrt;
|
||||
}
|
||||
|
||||
TFE_CancellationManager* TFE_NewCancellationManager() {
|
||||
return new TFE_CancellationManager;
|
||||
}
|
||||
@ -514,7 +526,11 @@ void TFE_DeleteCancellationManager(
|
||||
void TFE_OpSetCancellationManager(TFE_Op* op,
|
||||
TFE_CancellationManager* cancellation_manager,
|
||||
TF_Status* status) {
|
||||
status->status = op->operation->SetCancellationManager(cancellation_manager);
|
||||
tensorflow::EagerOperation* operation =
|
||||
tensorflow::OperationFromInterface(op->operation);
|
||||
operation->SetCancellationManager(
|
||||
&cancellation_manager->cancellation_manager);
|
||||
status->status = tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
TFE_Executor* TFE_NewExecutor(bool is_async) {
|
||||
@ -537,16 +553,22 @@ void TFE_ExecutorClearError(TFE_Executor* executor) {
|
||||
}
|
||||
|
||||
void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
|
||||
ctx->context->SetExecutorForThread(executor->executor());
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
context->SetExecutorForThread(executor->executor());
|
||||
}
|
||||
|
||||
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
|
||||
return new TFE_Executor(&ctx->context->Executor());
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
return new TFE_Executor(&context->Executor());
|
||||
}
|
||||
|
||||
void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
|
||||
ctx->context->HostCPU()->parsed_name());
|
||||
context->HostCPU()->parsed_name());
|
||||
auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
|
||||
void* data = tensorflow::port::Malloc(str.length());
|
||||
str.copy(static_cast<char*>(data), str.length(), 0);
|
||||
@ -565,7 +587,9 @@ void TFE_TensorHandleEnableImplicitMirroring(TFE_TensorHandle* h,
|
||||
|
||||
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
|
||||
TF_Buffer* buf, TF_Status* status) {
|
||||
auto* function_def = ctx->context->FindFunctionDef(function_name);
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
auto* function_def = context->FindFunctionDef(function_name);
|
||||
if (function_def == nullptr) {
|
||||
status->status = tensorflow::errors::NotFound(
|
||||
"Unable to find FunctionDef with name: ", function_name);
|
||||
|
@@ -296,6 +296,10 @@ TF_CAPI_EXPORT extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetLazyRemoteInputsCopy(
    TFE_ContextOptions*, bool lazy_copy);

// Sets whether to use TFRT
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetTfrt(TFE_ContextOptions*,
                                                     bool use_tfrt);

// -----------------------------------------------------------------------------
// Cancellation APIs.
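For orientation, a minimal sketch of how a client would flip the new switch, assuming the pre-existing context-options API from c_api.h (TFE_NewContextOptions / TFE_NewContext / TFE_DeleteContextOptions); note that in this commit the flag is only recorded on TFE_ContextOptions::use_tfrt and does not yet select a different backend in TFE_NewContext:

  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetTfrt(opts, true);  // request the TFRT backend (recorded only, for now)
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);
  // ... use ctx ...
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);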
|
@ -455,6 +455,7 @@ TEST(CAPI, TensorHandleOnDeviceMemory) {
|
||||
TFE_DeleteTensorHandle(copy_aliased); // Note that this will delete copy.
|
||||
TFE_DeleteTensorHandle(on_host);
|
||||
}
|
||||
TF_DeleteDeviceList(devices);
|
||||
TF_DeleteTensor(m_data);
|
||||
TFE_DeleteTensorHandle(m);
|
||||
TFE_DeleteContext(ctx);
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/context_interface.h"
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
@@ -59,25 +60,16 @@ struct TFE_ContextOptions {
  TFE_ContextMirroringPolicy mirroring_policy{TFE_MIRRORING_NONE};
  // If true, lazily copy the remote inputs of a function to the target devices.
  bool lazy_remote_inputs_copy = true;
  // If true, use TFRT backend
  bool use_tfrt = false;
};

struct TFE_Context {
  tensorflow::EagerContext* context;
  std::unique_ptr<tensorflow::AbstractContextInterface> context;
};

struct TFE_TensorHandle {
  static TFE_TensorHandle* CreateLocalHandle(const class tensorflow::Tensor& t,
                                             TF_Status* s) {
    tensorflow::TensorHandle* handle;
    s->status = tensorflow::TensorHandle::CreateLocalHandle(t, &handle);
    if (!s->status.ok()) {
      return nullptr;
    }
    return new TFE_TensorHandle{
        std::make_unique<tensorflow::TensorHandleInterface>(handle)};
  }

  std::unique_ptr<AbstractTensorHandleInterface> handle;
  std::unique_ptr<tensorflow::AbstractTensorHandleInterface> handle;
};
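Taken together, the two struct changes above define the recurring idioms of this refactor: TFE_Context now stores an AbstractContextInterface that has to be unwrapped back to an EagerContext, and TFE_TensorHandle stores an AbstractTensorHandleInterface that is unwrapped with TensorHandleFromInterface and re-wrapped via TensorHandleInterface. A condensed, hypothetical helper illustrating both (ContextFromInterface / TensorHandleFromInterface are the accessors used throughout this commit; their definitions live with the new interface headers, not in these hunks):

  // Hypothetical illustration only; mirrors the pattern of the C-API functions in this commit.
  TFE_TensorHandle* RewrapForExample(TFE_Context* ctx, TFE_TensorHandle* h) {
    // Idiom 1: resolve the concrete EagerContext behind the interface.
    tensorflow::EagerContext* context =
        tensorflow::ContextFromInterface(ctx->context);
    (void)context;  // would be passed to EagerContext APIs as before
    // Idiom 2: unwrap the handle, take a reference, and re-wrap it.
    tensorflow::TensorHandle* handle =
        tensorflow::TensorHandleFromInterface(h->handle);
    handle->Ref();
    return new TFE_TensorHandle{
        std::make_unique<tensorflow::TensorHandleInterface>(handle)};
  }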
|
||||
|
||||
struct TFE_TensorDebugInfo {
|
||||
@ -89,7 +81,7 @@ struct TFE_TensorDebugInfo {
|
||||
};
|
||||
|
||||
struct TFE_Op {
|
||||
std::unique_ptr<AbstractOperationInterface> operation;
|
||||
std::unique_ptr<tensorflow::AbstractOperationInterface> operation;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringCounterCell {
|
||||
|
@ -184,9 +184,7 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
|
||||
|
||||
// TODO(gjn): Add support for waiting on async local mirrors
|
||||
if (!async) {
|
||||
auto remote_arg = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||
h1_task2->handle.get())
|
||||
->Handle();
|
||||
auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle);
|
||||
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||
matmul->operation.get());
|
||||
// The input handles should never change since they have been mirrored.
|
||||
|
@ -409,13 +409,8 @@ void TensorHandleSilentCopy(bool async,
|
||||
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Validate if the input was replaced with a different TensorHandle
|
||||
auto arg0 = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||
hcpu->handle.get())
|
||||
->Handle();
|
||||
auto arg1 = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||
hgpu->handle.get())
|
||||
->Handle();
|
||||
|
||||
auto arg0 = tensorflow::TensorHandleFromInterface(hcpu->handle);
|
||||
auto arg1 = tensorflow::TensorHandleFromInterface(hgpu->handle);
|
||||
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||
matmul->operation.get());
|
||||
|
||||
|
tensorflow/c/eager/c_api_unified_experimental.cc (new file, 261 lines)
@@ -0,0 +1,261 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
|
||||
#include "absl/types/variant.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||
#include "tensorflow/core/lib/monitoring/gauge.h"
|
||||
#include "tensorflow/core/lib/monitoring/sampler.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
|
||||
using tensorflow::string;
|
||||
|
||||
// =============================================================================
|
||||
// Unified Execution APIs for Eager and tracing backends.
|
||||
// =============================================================================
|
||||
|
||||
typedef void (*ExecuteOperation)(TF_AbstractOp* op, int num_inputs,
|
||||
TF_AbstractTensor* const* inputs,
|
||||
TF_OutputList* o, TF_ExecutionContext* ctx,
|
||||
TF_Status* s);
|
||||
struct TF_ExecutionContext {
|
||||
explicit TF_ExecutionContext() {}
|
||||
absl::variant<TFE_Context*, TF_GraphContext*> ctx;
|
||||
ExecuteOperation execution_callback;
|
||||
};
|
||||
|
||||
struct TF_AbstractTensor {
|
||||
absl::variant<TFE_TensorHandle*, TF_GraphTensor*> t;
|
||||
};
|
||||
|
||||
struct TF_AbstractOp {
|
||||
string op_type;
|
||||
string op_name;
|
||||
};
|
||||
|
||||
TF_ExecutionContext* TF_NewExecutionContext() {
|
||||
return new TF_ExecutionContext();
|
||||
}
|
||||
|
||||
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete c; }
|
||||
|
||||
TF_AbstractOp* TF_NewAbstractOp() {
|
||||
TF_AbstractOp* op = new TF_AbstractOp;
|
||||
return op;
|
||||
}
|
||||
|
||||
void TF_DeleteAbstractOp(TF_AbstractOp* op) { delete op; }
|
||||
|
||||
TF_AbstractTensor* TF_NewAbstractTensor() {
|
||||
TF_AbstractTensor* t = new TF_AbstractTensor;
|
||||
return t;
|
||||
}
|
||||
|
||||
void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { delete t; }
|
||||
|
||||
struct TF_GraphContext {
|
||||
TF_Graph* graph;
|
||||
// TODO(srbs): Handle captures.
|
||||
};
|
||||
|
||||
TF_GraphContext* TF_NewGraphContext(TF_Graph* g) {
|
||||
auto ctx = new TF_GraphContext;
|
||||
ctx->graph = g;
|
||||
return ctx;
|
||||
}
|
||||
|
||||
void TF_DeleteGraphContext(TF_GraphContext* ctx) { delete ctx; }
|
||||
|
||||
struct TF_GraphTensor {
|
||||
TF_Output output;
|
||||
TF_GraphContext* ctx;
|
||||
};
|
||||
TF_GraphTensor* TF_NewGraphTensor(TF_GraphContext* ctx, TF_Output output,
|
||||
TF_Status* s) {
|
||||
TF_GraphTensor* t = new TF_GraphTensor;
|
||||
t->output = output;
|
||||
t->ctx = ctx;
|
||||
return t;
|
||||
}
|
||||
TF_Output TF_GraphTensorToOutput(const TF_GraphTensor* const t, TF_Status* s) {
|
||||
return t->output;
|
||||
}
|
||||
void TF_DeleteGraphTensor(TF_GraphTensor* t) { delete t; }
|
||||
void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t,
|
||||
TF_Status* s) {
|
||||
at->t = t;
|
||||
}
|
||||
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
|
||||
TF_Status* s) {
|
||||
if (!absl::holds_alternative<TFE_TensorHandle*>(at->t)) {
|
||||
string msg = absl::StrCat("Not an eager tensor handle.",
|
||||
reinterpret_cast<uintptr_t>(at));
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
return absl::get<TFE_TensorHandle*>(at->t);
|
||||
}
|
||||
void TF_AbstractTensorSetGraphTensor(TF_AbstractTensor* at, TF_GraphTensor* t,
|
||||
TF_Status* s) {
|
||||
at->t = t;
|
||||
}
|
||||
TF_GraphTensor* TF_AbstractTensorGetGraphTensor(TF_AbstractTensor* at,
|
||||
TF_Status* s) {
|
||||
if (!absl::holds_alternative<TF_GraphTensor*>(at->t)) {
|
||||
string msg = absl::StrCat("Not an graph tensor handle.");
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
return absl::get<TF_GraphTensor*>(at->t);
|
||||
}
|
||||
|
||||
bool IsEagerTensor(const TF_AbstractTensor* const t) {
|
||||
return absl::holds_alternative<TFE_TensorHandle*>(t->t);
|
||||
}
|
||||
|
||||
struct TF_OutputList {
|
||||
std::vector<TF_AbstractTensor*> outputs;
|
||||
int expected_num_outputs = -1;
|
||||
};
|
||||
|
||||
TF_OutputList* TF_NewOutputList() { return new TF_OutputList; }
|
||||
void TF_DeleteOutputList(TF_OutputList* o) { delete o; }
|
||||
void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs,
|
||||
TF_Status* s) {
|
||||
o->expected_num_outputs = num_outputs;
|
||||
}
|
||||
int TF_OutputListNumOutputs(TF_OutputList* o) { return o->outputs.size(); }
|
||||
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) {
|
||||
return o->outputs[i];
|
||||
}
|
||||
|
||||
void ExecuteOperationEager(TF_AbstractOp* op, int num_inputs,
|
||||
TF_AbstractTensor* const* inputs, TF_OutputList* o,
|
||||
TF_ExecutionContext* ctx, TF_Status* s) {
|
||||
auto* tfe_op =
|
||||
TFE_NewOp(absl::get<TFE_Context*>(ctx->ctx), op->op_type.c_str(), s);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
if (!IsEagerTensor(inputs[i])) {
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor.");
|
||||
return;
|
||||
}
|
||||
TFE_OpAddInput(tfe_op, absl::get<TFE_TensorHandle*>(inputs[i]->t), s);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
}
|
||||
if (o->expected_num_outputs == -1) {
|
||||
string msg =
|
||||
"The number of outputs must be provided in eager mode. Use "
|
||||
"TF_OutputListSetNumOutputs.";
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
|
||||
return;
|
||||
}
|
||||
tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals;
|
||||
int num_retvals = o->expected_num_outputs;
|
||||
retvals.resize(num_retvals);
|
||||
TFE_Execute(tfe_op, retvals.data(), &num_retvals, s);
|
||||
TFE_DeleteOp(tfe_op);
|
||||
if (TF_GetCode(s) != TF_OK) {
|
||||
return;
|
||||
}
|
||||
o->outputs.clear();
|
||||
o->outputs.reserve(num_retvals);
|
||||
for (int i = 0; i < num_retvals; ++i) {
|
||||
auto* t = TF_NewAbstractTensor();
|
||||
t->t = retvals[i];
|
||||
o->outputs.push_back(t);
|
||||
}
|
||||
}
|
||||
|
||||
TF_GraphContext* GetGraphContext(TF_AbstractTensor const* t) {
|
||||
return absl::get<TF_GraphTensor*>(t->t)->ctx;
|
||||
}
|
||||
|
||||
void ExecuteOperationGraph(TF_AbstractOp* op, int num_inputs,
|
||||
TF_AbstractTensor* const* inputs, TF_OutputList* o,
|
||||
TF_ExecutionContext* ctx, TF_Status* s) {
|
||||
TF_GraphContext* graph_ctx = absl::get<TF_GraphContext*>(ctx->ctx);
|
||||
TF_Graph* g = graph_ctx->graph;
|
||||
auto* tf_opdesc =
|
||||
TF_NewOperation(g, op->op_type.c_str(), op->op_name.c_str());
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
auto* input = inputs[i];
|
||||
if (IsEagerTensor(input)) {
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
||||
"Capturing eager tensors is not supported yet.");
|
||||
return;
|
||||
} else {
|
||||
if (GetGraphContext(input) != graph_ctx) {
|
||||
TF_SetStatus(
|
||||
s, TF_INVALID_ARGUMENT,
|
||||
"Capturing tensors from other graphs is not supported yet.");
|
||||
return;
|
||||
}
|
||||
TF_AddInput(tf_opdesc, absl::get<TF_GraphTensor*>(input->t)->output);
|
||||
}
|
||||
}
|
||||
auto* operation = TF_FinishOperation(tf_opdesc, s);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
int num_outputs = TF_OperationNumOutputs(operation);
|
||||
o->outputs.clear();
|
||||
o->outputs.reserve(num_outputs);
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
auto* t = TF_NewAbstractTensor();
|
||||
TF_GraphTensor* output_t = TF_NewGraphTensor(graph_ctx, {operation, i}, s);
|
||||
if (TF_GetCode(s) != TF_OK) {
|
||||
return;
|
||||
}
|
||||
t->t = output_t;
|
||||
o->outputs.push_back(t);
|
||||
}
|
||||
}
|
||||
|
||||
void TF_ExecutionContextSetEagerContext(TF_ExecutionContext* context,
|
||||
TFE_Context* eager_context,
|
||||
TF_Status* s) {
|
||||
context->ctx = eager_context;
|
||||
context->execution_callback = &ExecuteOperationEager;
|
||||
}
|
||||
|
||||
void TF_ExecutionContextSetGraphContext(TF_ExecutionContext* context,
|
||||
TF_GraphContext* graph_context,
|
||||
TF_Status* s) {
|
||||
context->ctx = graph_context;
|
||||
context->execution_callback = &ExecuteOperationGraph;
|
||||
}
|
||||
|
||||
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
|
||||
TF_Status* s) {
|
||||
op->op_type = op_type;
|
||||
}
|
||||
|
||||
void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
|
||||
TF_Status* s) {
|
||||
op->op_name = op_name;
|
||||
}
|
||||
|
||||
void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
|
||||
TF_AbstractTensor* const* inputs, TF_OutputList* o,
|
||||
TF_ExecutionContext* ctx, TF_Status* s) {
|
||||
ctx->execution_callback(op, num_inputs, inputs, o, ctx, s);
|
||||
}
|
tensorflow/c/eager/c_api_unified_experimental.h (new file, 119 lines)
@@ -0,0 +1,119 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_
|
||||
#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// =============================================================================
|
||||
// Unified Execution APIs for Eager and tracing backends.
|
||||
// =============================================================================
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Core APIs
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
// A TF_ExecutionContext stores knowledge about how to execute an operation.
|
||||
// E.g. it could know whether we're in eager mode or in graph mode, keeps track
|
||||
// of gradient tapes, etc.
|
||||
typedef struct TF_ExecutionContext TF_ExecutionContext;
|
||||
// A TF_AbstractTensor is an input to an operation. E.g. it could be a union
|
||||
// type of eager and graph tensors.
|
||||
typedef struct TF_AbstractTensor TF_AbstractTensor;
|
||||
// A TF_AbstractOp is the metadata we need to execute an operation. E.g. this
|
||||
// could contain the op type and other attributes.
|
||||
typedef struct TF_AbstractOp TF_AbstractOp;
|
||||
|
||||
TF_ExecutionContext* TF_NewExecutionContext();
|
||||
void TF_DeleteExecutionContext(TF_ExecutionContext*);
|
||||
|
||||
TF_AbstractOp* TF_NewAbstractOp();
|
||||
void TF_DeleteAbstractOp(TF_AbstractOp*);
|
||||
|
||||
TF_AbstractTensor* TF_NewAbstractTensor();
|
||||
void TF_DeleteAbstractTensor(TF_AbstractTensor*);
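Each of these objects is created and destroyed through a matching New/Delete pair. In C++ callers, one option is to hold them in std::unique_ptr with the matching Delete function as the deleter so cleanup survives early returns; a sketch, assuming only the declarations in this header.

#include <memory>

#include "tensorflow/c/eager/c_api_unified_experimental.h"

// Aliases tying each handle type to its Delete function.
using ExecutionContextPtr =
    std::unique_ptr<TF_ExecutionContext, decltype(&TF_DeleteExecutionContext)>;
using AbstractOpPtr =
    std::unique_ptr<TF_AbstractOp, decltype(&TF_DeleteAbstractOp)>;
using AbstractTensorPtr =
    std::unique_ptr<TF_AbstractTensor, decltype(&TF_DeleteAbstractTensor)>;

void Example() {
  // Each wrapper frees its object when it goes out of scope.
  ExecutionContextPtr ctx(TF_NewExecutionContext(), TF_DeleteExecutionContext);
  AbstractOpPtr op(TF_NewAbstractOp(), TF_DeleteAbstractOp);
  AbstractTensorPtr tensor(TF_NewAbstractTensor(), TF_DeleteAbstractTensor);
}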
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// APIs for Eager and graph modes
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
// Keeps track of the current graph and other state e.g. captures etc.
|
||||
typedef struct TF_GraphContext TF_GraphContext;
|
||||
TF_GraphContext* TF_NewGraphContext(TF_Graph*);
|
||||
void TF_DeleteGraphContext(TF_GraphContext*);
|
||||
|
||||
// `eager_context` must outlive `context`.
|
||||
void TF_ExecutionContextSetEagerContext(TF_ExecutionContext* context,
|
||||
TFE_Context* eager_context, TF_Status*);
|
||||
// `graph_context` must outlive `context`.
|
||||
void TF_ExecutionContextSetGraphContext(TF_ExecutionContext* context,
|
||||
TF_GraphContext* graph_context,
|
||||
TF_Status*);
|
||||
|
||||
// TODO(srbs): Add APIs for specifying attrs etc.
|
||||
// `op_type` must outlive `op`.
|
||||
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
|
||||
TF_Status* s);
|
||||
// `op_name` must outlive `op`.
|
||||
void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
|
||||
TF_Status* s);
|
||||
|
||||
// Wrapper for TF_Output but contains a pointer to TF_GraphContext as well.
|
||||
typedef struct TF_GraphTensor TF_GraphTensor;
|
||||
TF_GraphTensor* TF_NewGraphTensor(TF_GraphContext* c, TF_Output t,
|
||||
TF_Status* s);
|
||||
TF_Output TF_GraphTensorToOutput(const TF_GraphTensor* const t, TF_Status* s);
|
||||
void TF_DeleteGraphTensor(TF_GraphTensor* t);
|
||||
|
||||
// `t` must outlive `at`.
|
||||
void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t,
|
||||
TF_Status* s);
|
||||
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
|
||||
TF_Status* s);
|
||||
|
||||
// `t` must outlive `at`.
|
||||
void TF_AbstractTensorSetGraphTensor(TF_AbstractTensor* at, TF_GraphTensor* t,
|
||||
TF_Status* s);
|
||||
TF_GraphTensor* TF_AbstractTensorGetGraphTensor(TF_AbstractTensor* at,
|
||||
TF_Status* s);
|
||||
|
||||
// TF_OutputList just lets us not specify the number of outputs of an operation
|
||||
// beforehand. This forces a memory allocation in the runtime, which is bad, but
|
||||
// it allows for generic code.
|
||||
typedef struct TF_OutputList TF_OutputList;
|
||||
TF_OutputList* TF_NewOutputList();
|
||||
void TF_DeleteOutputList(TF_OutputList* o);
|
||||
void TF_OutputListSetNumOutputs(TF_OutputList* o, int, TF_Status*);
|
||||
int TF_OutputListNumOutputs(TF_OutputList* o);
|
||||
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i);
|
||||
|
||||
// TF_ExecuteOperation will, if in eager mode, execute, if in graph mode, maybe
|
||||
// capture some inputs and then add a node in the graph, and after
|
||||
// execution/node creation it'll go and record things that happened in any tape
|
||||
// which happens to be active.
|
||||
void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
|
||||
TF_AbstractTensor* const* inputs, TF_OutputList* o,
|
||||
TF_ExecutionContext* ctx, TF_Status* s);
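Because TF_ExecuteOperation only consults the callback installed on the context, the same driver code serves both backends. A hedged sketch of a small helper built purely from the declarations above; the helper name is illustrative and error handling is kept minimal.

// Configures and runs a single op on whatever backend `ctx` was set up with.
// In graph mode the op also needs a node name, hence the extra parameter.
void RunOp(const char* op_type, const char* op_name, int num_inputs,
           TF_AbstractTensor* const* inputs, TF_OutputList* outputs,
           TF_ExecutionContext* ctx, TF_Status* s) {
  TF_AbstractOp* op = TF_NewAbstractOp();
  TF_AbstractOpSetOpType(op, op_type, s);
  if (TF_GetCode(s) == TF_OK && op_name != nullptr) {
    TF_AbstractOpSetOpName(op, op_name, s);
  }
  if (TF_GetCode(s) == TF_OK) {
    TF_ExecuteOperation(op, num_inputs, inputs, outputs, ctx, s);
  }
  TF_DeleteAbstractOp(op);
}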
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_
|
tensorflow/c/eager/c_api_unified_experimental_test.cc (new file, 204 lines)
@ -0,0 +1,204 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/cc/profiler/profiler.h"
|
||||
#include "tensorflow/core/lib/monitoring/collection_registry.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
|
||||
using tensorflow::string;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
TEST(UnifiedCAPI, TestBasicEager) {
|
||||
TF_ExecutionContext* ctx = TF_NewExecutionContext();
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* eager_ctx = TFE_NewContext(opts, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
// Enter the eager context.
|
||||
TF_ExecutionContextSetEagerContext(ctx, eager_ctx, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build an abstract input tensor.
|
||||
TFE_TensorHandle* t = TestScalarTensorHandle(2.0f);
|
||||
TF_AbstractTensor* at = TF_NewAbstractTensor();
|
||||
TF_AbstractTensorSetEagerTensor(at, t, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build an abstract operation.
|
||||
auto* op = TF_NewAbstractOp();
|
||||
TF_AbstractOpSetOpType(op, "Add", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build inputs and outputs.
|
||||
TF_AbstractTensor* inputs[2] = {at, at};
|
||||
TF_OutputList* o = TF_NewOutputList();
|
||||
TF_OutputListSetNumOutputs(o, 1, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Execute.
|
||||
TF_ExecuteOperation(op, 2, inputs, o, ctx, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Clean up operation and inputs.
|
||||
TF_DeleteAbstractOp(op);
|
||||
TF_DeleteAbstractTensor(at);
|
||||
TFE_DeleteTensorHandle(t);
|
||||
|
||||
// Verify the results.
|
||||
ASSERT_EQ(1, TF_OutputListNumOutputs(o));
|
||||
TF_AbstractTensor* result = TF_OutputListGet(o, 0);
|
||||
TFE_TensorHandle* result_t =
|
||||
TF_AbstractTensorGetEagerTensor(result, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_Tensor* result_tensor = TFE_TensorHandleResolve(result_t, status.get());
|
||||
float* result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
||||
EXPECT_EQ(*result_value, 4.0);
|
||||
|
||||
TF_DeleteTensor(result_tensor);
|
||||
TF_DeleteAbstractTensor(result);
|
||||
TFE_DeleteTensorHandle(result_t);
|
||||
TF_DeleteOutputList(o);
|
||||
TFE_DeleteContext(eager_ctx);
|
||||
TF_DeleteExecutionContext(ctx);
|
||||
}
|
||||
|
||||
TEST(UnifiedCAPI, TestBasicGraph) {
|
||||
TF_ExecutionContext* ctx = TF_NewExecutionContext();
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
|
||||
// Enter a graph context.
|
||||
TF_Graph* g = TF_NewGraph();
|
||||
TF_GraphContext* graph_context = TF_NewGraphContext(g);
|
||||
TF_ExecutionContextSetGraphContext(ctx, graph_context, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Add a placeholder to the graph.
|
||||
auto* placeholder_op = TF_NewOperation(g, "Placeholder", "Placeholder");
|
||||
TF_SetAttrType(placeholder_op, "dtype", TF_FLOAT);
|
||||
auto* operation = TF_FinishOperation(placeholder_op, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_Output placeholder_t = {operation, 0};
|
||||
TF_GraphTensor* graph_t =
|
||||
TF_NewGraphTensor(graph_context, placeholder_t, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_AbstractTensor* t = TF_NewAbstractTensor();
|
||||
TF_AbstractTensorSetGraphTensor(t, graph_t, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build an abstract operation.
|
||||
auto* op = TF_NewAbstractOp();
|
||||
TF_AbstractOpSetOpType(op, "Add", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_AbstractOpSetOpName(op, "my_add", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build inputs and outputs.
|
||||
TF_AbstractTensor* inputs[2] = {t, t};
|
||||
TF_OutputList* o = TF_NewOutputList();
|
||||
|
||||
// Execute.
|
||||
TF_ExecuteOperation(op, 2, inputs, o, ctx, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Clean up operation and inputs.
|
||||
TF_DeleteAbstractOp(op);
|
||||
TF_DeleteAbstractTensor(t);
|
||||
TF_DeleteGraphTensor(graph_t);
|
||||
|
||||
TF_AbstractTensor* result = TF_OutputListGet(o, 0);
|
||||
TF_GraphTensor* result_graph_tensor =
|
||||
TF_AbstractTensorGetGraphTensor(result, status.get());
|
||||
TF_DeleteAbstractTensor(result);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_Output result_output =
|
||||
TF_GraphTensorToOutput(result_graph_tensor, status.get());
|
||||
TF_DeleteGraphTensor(result_graph_tensor);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
string fn_name = "double";
|
||||
TF_Function* f = TF_GraphToFunction(
|
||||
g, fn_name.c_str(), 0, -1, nullptr, 1, &placeholder_t, 1, &result_output,
|
||||
nullptr, nullptr, fn_name.c_str(), status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build an eager context to run the function.
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* eager_ctx = TFE_NewContext(opts, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
// Build the abstract op to run the function.
|
||||
TFE_ContextAddFunction(eager_ctx, f, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_AbstractOp* fn_op = TF_NewAbstractOp();
|
||||
TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build an abstract input tensor.
|
||||
TFE_TensorHandle* input_eager = TestScalarTensorHandle(2.0f);
|
||||
TF_AbstractTensor* input_t = TF_NewAbstractTensor();
|
||||
TF_AbstractTensorSetEagerTensor(input_t, input_eager, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Enter the eager context.
|
||||
TF_ExecutionContextSetEagerContext(ctx, eager_ctx, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_OutputListSetNumOutputs(o, 1, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_ExecuteOperation(fn_op, 1, &input_t, o, ctx, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
ASSERT_EQ(1, TF_OutputListNumOutputs(o));
|
||||
TF_AbstractTensor* final_result = TF_OutputListGet(o, 0);
|
||||
TFE_TensorHandle* final =
|
||||
TF_AbstractTensorGetEagerTensor(final_result, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_Tensor* f_t = TFE_TensorHandleResolve(final, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
float* f_value = static_cast<float*>(TF_TensorData(f_t));
|
||||
ASSERT_EQ(*f_value, 4.0);
|
||||
|
||||
TF_DeleteOutputList(o);
|
||||
TF_DeleteAbstractOp(fn_op);
|
||||
TF_DeleteAbstractTensor(input_t);
|
||||
TFE_DeleteTensorHandle(input_eager);
|
||||
TF_DeleteAbstractTensor(final_result);
|
||||
TFE_DeleteTensorHandle(final);
|
||||
TF_DeleteTensor(f_t);
|
||||
TF_DeleteFunction(f);
|
||||
|
||||
TF_DeleteGraphContext(graph_context);
|
||||
TF_DeleteGraph(g);
|
||||
TFE_DeleteContext(eager_ctx);
|
||||
TF_DeleteExecutionContext(ctx);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
tensorflow/c/eager/context_interface.cc (new file, 143 lines)
@ -0,0 +1,143 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/eager/context_interface.h"
|
||||
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/core/framework/tensor_interface.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateInt64Scalar(
|
||||
int64 value) {
|
||||
return std::make_unique<TensorInterface>(Tensor(value));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateUint64Scalar(
|
||||
uint64 value) {
|
||||
return std::make_unique<TensorInterface>(Tensor(value));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateInt32Scalar(
|
||||
int32 value) {
|
||||
return std::make_unique<TensorInterface>(Tensor(value));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateFloatScalar(
|
||||
float value) {
|
||||
return std::make_unique<TensorInterface>(Tensor(value));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateDoubleScalar(
|
||||
double value) {
|
||||
return std::make_unique<TensorInterface>(Tensor(value));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateHalfScalar(
|
||||
Eigen::half value) {
|
||||
return std::make_unique<TensorInterface>(Tensor(value));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateStringScalar(
|
||||
tstring value) {
|
||||
return std::make_unique<TensorInterface>(Tensor(value));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface>
|
||||
ContextInterface::CreateComplex128Scalar(complex128 value) {
|
||||
return std::make_unique<TensorInterface>(Tensor(value));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateBoolScalar(
|
||||
bool value) {
|
||||
return std::make_unique<TensorInterface>(Tensor(value));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateInt64Tensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
return std::make_unique<TensorInterface>(
|
||||
Tensor(DT_INT64, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateUint64Tensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
return std::make_unique<TensorInterface>(
|
||||
Tensor(DT_UINT64, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateInt32Tensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
return std::make_unique<TensorInterface>(
|
||||
Tensor(DT_INT32, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateFloatTensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
return std::make_unique<TensorInterface>(
|
||||
Tensor(DT_FLOAT, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateDoubleTensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
return std::make_unique<TensorInterface>(
|
||||
Tensor(DT_DOUBLE, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateHalfTensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
return std::make_unique<TensorInterface>(
|
||||
Tensor(DT_HALF, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateStringTensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
return std::make_unique<TensorInterface>(
|
||||
Tensor(DT_STRING, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface>
|
||||
ContextInterface::CreateComplex128Tensor(absl::Span<const int64> dim_sizes) {
|
||||
return std::make_unique<TensorInterface>(
|
||||
Tensor(DT_COMPLEX128, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateBoolTensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
return std::make_unique<TensorInterface>(
|
||||
Tensor(DT_BOOL, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorHandleInterface>
|
||||
ContextInterface::CreateLocalHandle(
|
||||
const std::unique_ptr<AbstractTensorInterface> t) {
|
||||
Tensor tensor = tensorflow::down_cast<TensorInterface*>(t.get())->Tensor();
|
||||
return std::make_unique<TensorHandleInterface>(
|
||||
TensorHandle::CreateLocalHandle(std::move(tensor), /*d=*/ctx_->HostCPU(),
|
||||
/*op_device=*/nullptr, ctx_));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractOperationInterface>
|
||||
ContextInterface::CreateOperation() {
|
||||
return std::make_unique<tensorflow::OperationInterface>(ctx_);
|
||||
}
|
||||
|
||||
void ContextInterface::ListDevices(
|
||||
std::vector<tensorflow::DeviceAttributes>* devices) {
|
||||
ctx_->ListDevices(devices);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
tensorflow/c/eager/context_interface.h (new file, 155 lines)
@ -0,0 +1,155 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
|
||||
#define TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/core/framework/numeric_types.h"
|
||||
#include "tensorflow/core/framework/tensor_interface.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/tstring.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Abstract interface to a context.
|
||||
//
|
||||
// A context is responsible for creating key objects such as Tensors,
|
||||
// TensorHandles & Operations.
|
||||
class AbstractContextInterface {
|
||||
public:
|
||||
virtual ~AbstractContextInterface() {}
|
||||
|
||||
// Scalar creation functions
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateInt64Scalar(
|
||||
int64 value) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateUint64Scalar(
|
||||
uint64 value) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateInt32Scalar(
|
||||
int32 value) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateFloatScalar(
|
||||
float value) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateDoubleScalar(
|
||||
double value) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateHalfScalar(
|
||||
Eigen::half value) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateStringScalar(
|
||||
tstring value) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateComplex128Scalar(
|
||||
complex128 value) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateBoolScalar(
|
||||
bool value) = 0;
|
||||
|
||||
// Tensor creation functions
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateInt64Tensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateUint64Tensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateInt32Tensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateFloatTensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateDoubleTensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateHalfTensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateStringTensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateComplex128Tensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateBoolTensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
|
||||
// Create a handle to wrap and manage a Tensor
|
||||
virtual std::unique_ptr<AbstractTensorHandleInterface> CreateLocalHandle(
|
||||
const std::unique_ptr<AbstractTensorInterface> t) = 0;
|
||||
|
||||
// Create an operation to perform op execution
|
||||
virtual std::unique_ptr<AbstractOperationInterface> CreateOperation() = 0;
|
||||
|
||||
// List attributes of available devices
|
||||
virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
|
||||
};
|
||||
|
||||
// TODO(gjn): Try to move these all to EagerContext and make it implement
|
||||
// AbstractContextInterface. Currently, this is not so straightforward because
|
||||
// of various BUILD file dependencies.
|
||||
class ContextInterface : public AbstractContextInterface {
|
||||
public:
|
||||
explicit ContextInterface(EagerContext* ctx) : ctx_(ctx) {}
|
||||
~ContextInterface() override {}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> CreateInt64Scalar(
|
||||
int64 value) override;
|
||||
std::unique_ptr<AbstractTensorInterface> CreateUint64Scalar(
|
||||
uint64 value) override;
|
||||
std::unique_ptr<AbstractTensorInterface> CreateInt32Scalar(
|
||||
int32 value) override;
|
||||
std::unique_ptr<AbstractTensorInterface> CreateFloatScalar(
|
||||
float value) override;
|
||||
std::unique_ptr<AbstractTensorInterface> CreateDoubleScalar(
|
||||
double value) override;
|
||||
std::unique_ptr<AbstractTensorInterface> CreateHalfScalar(
|
||||
Eigen::half value) override;
|
||||
std::unique_ptr<AbstractTensorInterface> CreateStringScalar(
|
||||
tensorflow::tstring value) override;
|
||||
std::unique_ptr<AbstractTensorInterface> CreateComplex128Scalar(
|
||||
tensorflow::complex128 value) override;
|
||||
std::unique_ptr<AbstractTensorInterface> CreateBoolScalar(
|
||||
bool value) override;
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> CreateInt64Tensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
std::unique_ptr<AbstractTensorInterface> CreateUint64Tensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
std::unique_ptr<AbstractTensorInterface> CreateInt32Tensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
std::unique_ptr<AbstractTensorInterface> CreateFloatTensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
std::unique_ptr<AbstractTensorInterface> CreateDoubleTensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
std::unique_ptr<AbstractTensorInterface> CreateHalfTensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
std::unique_ptr<AbstractTensorInterface> CreateStringTensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
std::unique_ptr<AbstractTensorInterface> CreateComplex128Tensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
std::unique_ptr<AbstractTensorInterface> CreateBoolTensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
|
||||
std::unique_ptr<AbstractTensorHandleInterface> CreateLocalHandle(
|
||||
const std::unique_ptr<AbstractTensorInterface> t) override;
|
||||
std::unique_ptr<AbstractOperationInterface> CreateOperation() override;
|
||||
|
||||
void ListDevices(std::vector<DeviceAttributes>* devices) override;
|
||||
|
||||
// For runtime specific APIs, provide ability to get the underlying context.
|
||||
EagerContext* Context() const { return ctx_; }
|
||||
|
||||
private:
|
||||
EagerContext* ctx_;
|
||||
};
|
||||
|
||||
inline EagerContext* ContextFromInterface(
|
||||
const std::unique_ptr<AbstractContextInterface>& context) {
|
||||
return down_cast<ContextInterface*>(context.get())->Context();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
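A sketch of how this interface splits responsibilities: backend-neutral code creates values through AbstractContextInterface, while runtime-specific code unwraps the concrete EagerContext via ContextFromInterface. Both functions below are illustrative, live inside namespace tensorflow, and use only the declarations above.

// Generic code: only sees the abstract interface.
std::unique_ptr<AbstractTensorHandleInterface> MakeScalarHandle(
    AbstractContextInterface* ctx, float value) {
  // Build a host tensor, then wrap it in a handle owned by the context.
  std::unique_ptr<AbstractTensorInterface> scalar =
      ctx->CreateFloatScalar(value);
  return ctx->CreateLocalHandle(std::move(scalar));
}

// Runtime-specific code: recovers the EagerContext when it needs internals.
int CountDevices(const std::unique_ptr<AbstractContextInterface>& ctx) {
  std::vector<DeviceAttributes> devices;
  ContextFromInterface(ctx)->ListDevices(&devices);
  return static_cast<int>(devices.size());
}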
|
@ -20,7 +20,6 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_reference.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -47,9 +46,7 @@ const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
|
||||
return nullptr;
|
||||
}
|
||||
tensorflow::TensorHandle* handle =
|
||||
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
|
||||
->Handle();
|
||||
|
||||
tensorflow::TensorHandleFromInterface(h->handle);
|
||||
if (handle->IsRemote()) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"DLPack doesn't support remote tensor");
|
||||
@ -289,9 +286,8 @@ void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
|
||||
return static_cast<void*>(dlm_tensor);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status,
|
||||
TFE_Context* ctx) {
|
||||
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlm);
|
||||
DLTensor* dl_tensor = &dlmt->dl_tensor;
|
||||
absl::optional<std::string> device_name =
|
||||
|
@ -30,7 +30,8 @@ TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h,
|
||||
|
||||
// Converts DLPack (DLManagedTensor*) to eager tensor handle.
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm,
|
||||
TF_Status* status);
|
||||
TF_Status* status,
|
||||
TFE_Context* ctx);
|
||||
|
||||
// Calls the destructor of DLManagedTensor, used in the destructor of PyCapsule.
|
||||
TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr);
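With this change the importer no longer creates a throwaway eager context internally; the caller supplies one. A sketch of a round trip through DLPack under the new signature; status checks and DLManagedTensor ownership on error paths are deliberately glossed over here.

// Exports `handle` to DLPack and re-imports it into `ctx`.
TFE_TensorHandle* RoundTripThroughDLPack(TFE_TensorHandle* handle,
                                         TFE_Context* ctx, TF_Status* status) {
  void* dlm = TFE_HandleToDLPack(handle, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  // New signature: the caller supplies the eager context to import into.
  return TFE_HandleFromDLPack(dlm, status, ctx);
}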
|
||||
|
@ -26,8 +26,7 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
OperationInterface::OperationInterface(TFE_Context* ctx)
|
||||
: operation_(ctx->context) {}
|
||||
OperationInterface::OperationInterface(EagerContext* ctx) : operation_(ctx) {}
|
||||
|
||||
const string& OperationInterface::DeviceName() const {
|
||||
absl::variant<Device*, CustomDevice*> variant_device =
|
||||
@ -99,9 +98,8 @@ Status OperationInterface::SetAttrFunction(
|
||||
AttrValue attr_value;
|
||||
NameAttrList* func = attr_value.mutable_func();
|
||||
func->set_name(value->Name());
|
||||
OperationInterface* value_operation =
|
||||
tensorflow::down_cast<OperationInterface*>(value.get());
|
||||
value_operation->operation_.Attrs().FillAttrValueMap(func->mutable_attr());
|
||||
EagerOperation* value_operation = OperationFromInterface(value);
|
||||
value_operation->Attrs().FillAttrValueMap(func->mutable_attr());
|
||||
operation_.MutableAttrs()->Set(attr_name, attr_value);
|
||||
return Status::OK();
|
||||
}
|
||||
@ -116,10 +114,9 @@ Status OperationInterface::SetAttrFunctionName(const char* attr_name,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrTensor(const char* attr_name,
|
||||
TF_Tensor* tensor) {
|
||||
Tensor t;
|
||||
TF_RETURN_IF_ERROR(TF_TensorToTensor(tensor, &t));
|
||||
Status OperationInterface::SetAttrTensor(
|
||||
const char* attr_name, std::unique_ptr<AbstractTensorInterface> tensor) {
|
||||
Tensor t = TensorFromInterface(tensor);
|
||||
operation_.MutableAttrs()->Set(attr_name, t);
|
||||
return Status::OK();
|
||||
}
|
||||
@ -209,11 +206,10 @@ Status OperationInterface::SetAttrFunctionList(const char* attr_name,
|
||||
int num_values) {
|
||||
std::unique_ptr<NameAttrList[]> funcs(new NameAttrList[num_values]);
|
||||
for (int i = 0; i < num_values; i++) {
|
||||
auto value_operation =
|
||||
tensorflow::down_cast<OperationInterface*>(value[i]->operation.get());
|
||||
funcs[i].set_name(value_operation->operation_.Name());
|
||||
value_operation->operation_.Attrs().FillAttrValueMap(
|
||||
funcs[i].mutable_attr());
|
||||
EagerOperation* value_operation =
|
||||
OperationFromInterface(value[i]->operation);
|
||||
funcs[i].set_name(value_operation->Name());
|
||||
value_operation->Attrs().FillAttrValueMap(funcs[i].mutable_attr());
|
||||
}
|
||||
operation_.MutableAttrs()->Set(
|
||||
attr_name, gtl::ArraySlice<const NameAttrList>(funcs.get(), num_values));
|
||||
@ -267,8 +263,7 @@ Status OperationInterface::OutputLength(const char* output_name, int* length) {
|
||||
|
||||
Status OperationInterface::AddInput(
|
||||
const std::unique_ptr<AbstractTensorHandleInterface>& input) {
|
||||
TensorHandle* h =
|
||||
tensorflow::down_cast<TensorHandleInterface*>(input.get())->Handle();
|
||||
TensorHandle* h = TensorHandleFromInterface(input);
|
||||
operation_.AddInput(h);
|
||||
return operation_.MaybeInferSingleInputAttrs(h);
|
||||
}
|
||||
@ -277,8 +272,7 @@ Status OperationInterface::AddInputList(
|
||||
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
|
||||
inputs) {
|
||||
for (auto& input : inputs) {
|
||||
TensorHandle* h =
|
||||
tensorflow::down_cast<TensorHandleInterface*>(input.get())->Handle();
|
||||
TensorHandle* h = TensorHandleFromInterface(input);
|
||||
operation_.AddInput(h);
|
||||
}
|
||||
return operation_.InferInputListAttrs(inputs.size());
|
||||
@ -297,13 +291,6 @@ Status OperationInterface::Execute(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetCancellationManager(
|
||||
TFE_CancellationManager* cancellation_manager) {
|
||||
operation_.SetCancellationManager(
|
||||
&cancellation_manager->cancellation_manager);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetUseXla(bool enable) {
|
||||
operation_.SetUseXla(enable);
|
||||
return Status::OK();
|
||||
|
@ -19,9 +19,14 @@ limitations under the License.
|
||||
|
||||
#include "absl/container/fixed_array.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/core/framework/tensor_interface.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Abstract interface to an operation.
|
||||
class AbstractOperationInterface {
|
||||
@ -29,90 +34,67 @@ class AbstractOperationInterface {
|
||||
virtual ~AbstractOperationInterface() {}
|
||||
|
||||
virtual void Clear() = 0;
|
||||
virtual tensorflow::Status Reset(const char* op,
|
||||
const char* raw_device_name) = 0;
|
||||
virtual Status Reset(const char* op, const char* raw_device_name) = 0;
|
||||
|
||||
virtual const tensorflow::string& Name() const = 0;
|
||||
virtual const tensorflow::string& DeviceName() const = 0;
|
||||
virtual tensorflow::Status SetDeviceName(const char* name) = 0;
|
||||
virtual const string& Name() const = 0;
|
||||
virtual const string& DeviceName() const = 0;
|
||||
virtual Status SetDeviceName(const char* name) = 0;
|
||||
|
||||
virtual tensorflow::Status AddInput(
|
||||
virtual Status AddInput(
|
||||
const std::unique_ptr<AbstractTensorHandleInterface>& input) = 0;
|
||||
virtual tensorflow::Status AddInputList(
|
||||
virtual Status AddInputList(
|
||||
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
|
||||
inputs) = 0;
|
||||
virtual tensorflow::Status Execute(
|
||||
virtual Status Execute(
|
||||
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
|
||||
int* num_retvals) = 0;
|
||||
virtual const tensorflow::OpDef* OpDef() const = 0;
|
||||
|
||||
virtual tensorflow::Status SetAttrString(const char* attr_name,
|
||||
const char* data, size_t length) = 0;
|
||||
virtual tensorflow::Status SetAttrInt(const char* attr_name,
|
||||
int64_t value) = 0;
|
||||
virtual tensorflow::Status SetAttrFloat(const char* attr_name,
|
||||
float value) = 0;
|
||||
virtual tensorflow::Status SetAttrBool(const char* attr_name, bool value) = 0;
|
||||
virtual tensorflow::Status SetAttrType(const char* attr_name,
|
||||
TF_DataType value) = 0;
|
||||
virtual tensorflow::Status SetAttrShape(const char* attr_name,
|
||||
const int64_t* dims,
|
||||
const int num_dims) = 0;
|
||||
virtual tensorflow::Status SetAttrFunction(
|
||||
virtual Status SetAttrString(const char* attr_name, const char* data,
|
||||
size_t length) = 0;
|
||||
virtual Status SetAttrInt(const char* attr_name, int64_t value) = 0;
|
||||
virtual Status SetAttrFloat(const char* attr_name, float value) = 0;
|
||||
virtual Status SetAttrBool(const char* attr_name, bool value) = 0;
|
||||
virtual Status SetAttrType(const char* attr_name, TF_DataType value) = 0;
|
||||
virtual Status SetAttrShape(const char* attr_name, const int64_t* dims,
|
||||
const int num_dims) = 0;
|
||||
virtual Status SetAttrFunction(
|
||||
const char* attr_name,
|
||||
const std::unique_ptr<AbstractOperationInterface>& value) = 0;
|
||||
virtual tensorflow::Status SetAttrFunctionName(const char* attr_name,
|
||||
const char* value,
|
||||
size_t length) = 0;
|
||||
virtual tensorflow::Status SetAttrTensor(const char* attr_name,
|
||||
TF_Tensor* tensor) = 0;
|
||||
virtual tensorflow::Status SetAttrStringList(const char* attr_name,
|
||||
const void* const* values,
|
||||
const size_t* lengths,
|
||||
int num_values) = 0;
|
||||
virtual tensorflow::Status SetAttrFloatList(const char* attr_name,
|
||||
const float* values,
|
||||
int num_values) = 0;
|
||||
virtual tensorflow::Status SetAttrIntList(const char* attr_name,
|
||||
const int64_t* values,
|
||||
int num_values) = 0;
|
||||
virtual tensorflow::Status SetAttrTypeList(const char* attr_name,
|
||||
const TF_DataType* values,
|
||||
int num_values) = 0;
|
||||
virtual tensorflow::Status SetAttrBoolList(const char* attr_name,
|
||||
const unsigned char* values,
|
||||
int num_values) = 0;
|
||||
virtual tensorflow::Status SetAttrShapeList(const char* attr_name,
|
||||
const int64_t** dims,
|
||||
const int* num_dims,
|
||||
int num_values) = 0;
|
||||
virtual tensorflow::Status SetAttrFunctionList(const char* attr_name,
|
||||
const TFE_Op** value,
|
||||
int num_values) = 0;
|
||||
virtual Status SetAttrFunctionName(const char* attr_name, const char* value,
|
||||
size_t length) = 0;
|
||||
virtual Status SetAttrTensor(
|
||||
const char* attr_name,
|
||||
std::unique_ptr<AbstractTensorInterface> tensor) = 0;
|
||||
virtual Status SetAttrStringList(const char* attr_name,
|
||||
const void* const* values,
|
||||
const size_t* lengths, int num_values) = 0;
|
||||
virtual Status SetAttrFloatList(const char* attr_name, const float* values,
|
||||
int num_values) = 0;
|
||||
virtual Status SetAttrIntList(const char* attr_name, const int64_t* values,
|
||||
int num_values) = 0;
|
||||
virtual Status SetAttrTypeList(const char* attr_name,
|
||||
const TF_DataType* values, int num_values) = 0;
|
||||
virtual Status SetAttrBoolList(const char* attr_name,
|
||||
const unsigned char* values,
|
||||
int num_values) = 0;
|
||||
virtual Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
|
||||
const int* num_dims, int num_values) = 0;
|
||||
virtual Status SetAttrFunctionList(const char* attr_name,
|
||||
const TFE_Op** value, int num_values) = 0;
|
||||
|
||||
virtual tensorflow::Status InputLength(const char* input_name,
|
||||
int* length) = 0;
|
||||
virtual tensorflow::Status OutputLength(const char* output_name,
|
||||
int* length) = 0;
|
||||
virtual Status InputLength(const char* input_name, int* length) = 0;
|
||||
virtual Status OutputLength(const char* output_name, int* length) = 0;
|
||||
|
||||
// Experimental
|
||||
virtual tensorflow::Status SetUseXla(bool enable) {
|
||||
return tensorflow::errors::Unimplemented("SetUseXla not implemented");
|
||||
}
|
||||
virtual tensorflow::Status SetCancellationManager(
|
||||
TFE_CancellationManager* cancellation_manager) {
|
||||
return tensorflow::errors::Unimplemented(
|
||||
"SetCancellationManager not implemented");
|
||||
}
|
||||
virtual Status SetUseXla(bool enable) = 0;
|
||||
};
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class OpDef;
|
||||
|
||||
class OperationInterface : public AbstractOperationInterface {
|
||||
public:
|
||||
explicit OperationInterface(TFE_Context* ctx);
|
||||
explicit OperationInterface(EagerContext* ctx);
|
||||
~OperationInterface() override{};
|
||||
|
||||
void Clear() override { operation_.Clear(); }
|
||||
@ -149,7 +131,9 @@ class OperationInterface : public AbstractOperationInterface {
|
||||
const std::unique_ptr<AbstractOperationInterface>& value) override;
|
||||
Status SetAttrFunctionName(const char* attr_name, const char* data,
|
||||
size_t length) override;
|
||||
Status SetAttrTensor(const char* attr_name, TF_Tensor* tensor) override;
|
||||
Status SetAttrTensor(
|
||||
const char* attr_name,
|
||||
std::unique_ptr<AbstractTensorInterface> tensor) override;
|
||||
Status SetAttrStringList(const char* attr_name, const void* const* values,
|
||||
const size_t* lengths, int num_values) override;
|
||||
Status SetAttrFloatList(const char* attr_name, const float* values,
|
||||
@ -169,20 +153,25 @@ class OperationInterface : public AbstractOperationInterface {
|
||||
Status OutputLength(const char* output_name, int* length) override;
|
||||
|
||||
Status SetUseXla(bool enable) override;
|
||||
Status SetCancellationManager(
|
||||
TFE_CancellationManager* cancellation_manager) override;
|
||||
|
||||
// TODO(gjn): Remove once TFE_InferShapes is removed
|
||||
const tensorflow::AttrBuilder& Attrs() const { return operation_.Attrs(); }
|
||||
tensorflow::AttrBuilder* MutableAttrs() { return operation_.MutableAttrs(); }
|
||||
const AttrBuilder& Attrs() const { return operation_.Attrs(); }
|
||||
AttrBuilder* MutableAttrs() { return operation_.MutableAttrs(); }
|
||||
|
||||
const TensorHandle* GetInput(int i) const { return operation_.Inputs()[i]; }
|
||||
|
||||
EagerOperation* Operation() { return &operation_; }
|
||||
|
||||
private:
|
||||
const tensorflow::OpDef* GetOpDef(Status* status);
|
||||
EagerOperation operation_;
|
||||
};
|
||||
|
||||
inline EagerOperation* OperationFromInterface(
|
||||
const std::unique_ptr<AbstractOperationInterface>& operation) {
|
||||
return down_cast<OperationInterface*>(operation.get())->Operation();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
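Taken together with the context interface, a backend-neutral caller can build and run an op without touching EagerOperation directly. A sketch under the declarations above (an "Add" of two handles); it assumes the usual tensorflow Status macros and that a null raw_device_name means "let the runtime pick", and it lives inside namespace tensorflow.

// Runs `a + b`; the caller provides `outputs` sized for one return value.
Status RunAdd(
    AbstractContextInterface* ctx,
    const std::unique_ptr<AbstractTensorHandleInterface>& a,
    const std::unique_ptr<AbstractTensorHandleInterface>& b,
    absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* outputs) {
  std::unique_ptr<AbstractOperationInterface> op = ctx->CreateOperation();
  TF_RETURN_IF_ERROR(op->Reset("Add", /*raw_device_name=*/nullptr));
  TF_RETURN_IF_ERROR(op->AddInput(a));
  TF_RETURN_IF_ERROR(op->AddInput(b));
  int num_retvals = 1;
  return op->Execute(outputs, &num_retvals);
}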
|
||||
|
@ -15,10 +15,13 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
|
||||
#define TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/framework/tensor_interface.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Abstract interface to a TensorHandle.
|
||||
//
|
||||
@ -34,24 +37,22 @@ class AbstractTensorHandleInterface {
|
||||
virtual ~AbstractTensorHandleInterface() {}
|
||||
|
||||
// Check if the handle is in a valid initialized state.
|
||||
virtual bool IsValid(tensorflow::Status* status) const = 0;
|
||||
virtual bool IsValid(Status* status) const = 0;
|
||||
// Returns tensor dtype.
|
||||
virtual TF_DataType DataType() const = 0;
|
||||
// Returns number of dimensions.
|
||||
virtual int NumDims(tensorflow::Status* status) const = 0;
|
||||
virtual int NumDims(Status* status) const = 0;
|
||||
// Returns number of elements across all dimensions.
|
||||
virtual int64_t NumElements(tensorflow::Status* status) const = 0;
|
||||
virtual int64_t NumElements(Status* status) const = 0;
|
||||
// Returns size of specified dimension
|
||||
virtual int64_t Dim(int dim_index, tensorflow::Status* status) const = 0;
|
||||
virtual int64_t Dim(int dim_index, Status* status) const = 0;
|
||||
|
||||
// Returns the device which created the handle.
|
||||
virtual const char* DeviceName(tensorflow::Status* status) const = 0;
|
||||
virtual const char* DeviceName(Status* status) const = 0;
|
||||
// Returns the device where the tensor was placed.
|
||||
virtual const char* BackingDeviceName(tensorflow::Status* status) const = 0;
|
||||
virtual const char* BackingDeviceName(Status* status) const = 0;
|
||||
// Returns a tensor for the handle. If tensor is remote, it will be copied.
|
||||
virtual TF_Tensor* Resolve(tensorflow::Status* status) = 0;
|
||||
// Returns debug information about the tensor.
|
||||
virtual TFE_TensorDebugInfo* TensorDebugInfo(tensorflow::Status* status) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> Resolve(Status* status) = 0;
|
||||
|
||||
// Return a copy of the handle.
|
||||
virtual AbstractTensorHandleInterface* Copy() = 0;
|
||||
@ -65,8 +66,9 @@ class AbstractTensorHandleInterface {
|
||||
virtual void EnableImplicitMirroring() = 0;
|
||||
};
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// TODO(gjn): Try to move these all to TensorHandle and make it implement
|
||||
// AbstractTensorHandleInterface. Currently, this is not so straightforward
|
||||
// because of various BUILD file dependencies.
|
||||
class TensorHandleInterface : public AbstractTensorHandleInterface {
|
||||
public:
|
||||
explicit TensorHandleInterface(TensorHandle* h) : handle_(h) {}
|
||||
@ -80,21 +82,24 @@ class TensorHandleInterface : public AbstractTensorHandleInterface {
|
||||
|
||||
const char* DeviceName(Status* status) const override;
|
||||
const char* BackingDeviceName(Status* status) const override;
|
||||
TF_Tensor* Resolve(Status* status) override;
|
||||
TFE_TensorDebugInfo* TensorDebugInfo(Status* status) override;
|
||||
std::unique_ptr<AbstractTensorInterface> Resolve(Status* status) override;
|
||||
|
||||
AbstractTensorHandleInterface* Copy() override;
|
||||
|
||||
void EnableImplicitMirroring() override;
|
||||
|
||||
// TODO(gjn): This is not a very generic interface, but is needed for specific
|
||||
// use cases.
|
||||
// For runtime specific APIs, provide ability to get the underlying handle.
|
||||
TensorHandle* Handle() { return handle_; }
|
||||
|
||||
private:
|
||||
TensorHandle* handle_;
|
||||
};
|
||||
|
||||
inline TensorHandle* TensorHandleFromInterface(
|
||||
const std::unique_ptr<AbstractTensorHandleInterface>& handle) {
|
||||
return down_cast<TensorHandleInterface*>(handle.get())->Handle();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
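TensorHandleFromInterface mirrors ContextFromInterface and OperationFromInterface: a one-line down_cast helper for code that owns the abstract pointer but needs the concrete runtime type. A small illustrative use, grounded only in the helper shown above.

// True when two abstract handles wrap the same underlying TensorHandle.
bool SameUnderlyingHandle(
    const std::unique_ptr<AbstractTensorHandleInterface>& a,
    const std::unique_ptr<AbstractTensorHandleInterface>& b) {
  return TensorHandleFromInterface(a) == TensorHandleFromInterface(b);
}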
|
||||
|
@ -16,8 +16,8 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
#include "tensorflow/c/tf_status_internal.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/error.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
using ::tensorflow::IOError;
|
||||
using ::tensorflow::Status;
|
||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_C_TF_STATUS_HELPER_H_
|
||||
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_C_TF_STATUS_INTERNAL_H_
|
||||
#define TENSORFLOW_C_TF_STATUS_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
// Internal structures used by the status C API. These are likely to change
|
||||
// and should not be depended on.
|
||||
|
@ -381,7 +381,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
|
||||
->ToTensor(dst);
|
||||
}
|
||||
|
||||
Status TensorInterface::ToTensor(Tensor* dst) const {
|
||||
Status TensorInterface::ToTensor(tensorflow::Tensor* dst) const {
|
||||
if (tensor_.dtype() == DT_RESOURCE) {
|
||||
if (tensor_.dims() != 0) {
|
||||
return InvalidArgument(
|
||||
@ -389,7 +389,7 @@ Status TensorInterface::ToTensor(Tensor* dst) const {
|
||||
"shape ",
|
||||
tensor_.shape().DebugString());
|
||||
}
|
||||
*dst = Tensor(tensorflow::DT_RESOURCE, tensor_.shape());
|
||||
*dst = tensorflow::Tensor(tensorflow::DT_RESOURCE, tensor_.shape());
|
||||
if (!dst->scalar<tensorflow::ResourceHandle>()().ParseFromString(
|
||||
string(static_cast<const char*>(Data()), ByteSize()))) {
|
||||
return InvalidArgument(
|
||||
@ -414,7 +414,7 @@ Status TensorInterface::ToTensor(Tensor* dst) const {
|
||||
const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
|
||||
const char* limit = input + src_size;
|
||||
|
||||
*dst = Tensor(tensor_.dtype(), tensor_.shape());
|
||||
*dst = tensorflow::Tensor(tensor_.dtype(), tensor_.shape());
|
||||
auto dstarray = dst->flat<tstring>();
|
||||
for (tensorflow::int64 i = 0; i < num_elements; ++i) {
|
||||
tensorflow::uint64 offset =
|
||||
|
@ -31,7 +31,7 @@ limitations under the License.
|
||||
// passed to or returned from C functions *by pointer*. Otherwise, changes to
|
||||
// its internal structure will break the C API's binary interface.
|
||||
typedef struct TF_Tensor {
|
||||
std::unique_ptr<AbstractTensorInterface> tensor;
|
||||
std::unique_ptr<tensorflow::AbstractTensorInterface> tensor;
|
||||
} TF_Tensor;
|
||||
|
||||
class TF_ManagedBuffer : public tensorflow::TensorBuffer {
|
||||
|
@ -124,7 +124,7 @@ Status LoadSavedModel(const SessionOptions& session_options,
|
||||
/// the export directory definitely does not contain a SavedModel. If the method
|
||||
/// returns `true`, the export directory may contain a SavedModel but provides
|
||||
/// no guarantee that it can be loaded.
|
||||
bool MaybeSavedModelDirectory(const string& export_dir);
|
||||
bool MaybeSavedModelDirectory(const std::string& export_dir);
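A small usage sketch: probe the directory before committing to a full load. The LoadSavedModel call, the errors:: helper, and kSavedModelTagServe are assumed from the usual SavedModel loader API rather than shown in this change.

Status TryLoad(const std::string& export_dir, SavedModelBundle* bundle) {
  // Cheap pre-check: false definitely rules out a SavedModel; true only
  // means it is worth attempting the real load.
  if (!MaybeSavedModelDirectory(export_dir)) {
    return errors::NotFound("No SavedModel found at ", export_dir);
  }
  return LoadSavedModel(SessionOptions(), RunOptions(), export_dir,
                        {kSavedModelTagServe}, bundle);
}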
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -115,11 +115,10 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -12,7 +12,7 @@ load("@bazel_skylib//lib:paths.bzl", "paths")
|
||||
_default_test_file_exts = ["mlir", ".pbtxt", ".td"]
|
||||
_default_driver = "@llvm-project//mlir:run_lit.sh"
|
||||
_default_size = "small"
|
||||
_default_tags = ["no_rocm"]
|
||||
_default_tags = []
|
||||
|
||||
# These are patterns which we should never match, for tests, subdirectories, or
|
||||
# test input data files.
|
||||
|
@ -17,8 +17,7 @@ package_group(
|
||||
name = "friends",
|
||||
includes = ["//third_party/mlir:subpackages"],
|
||||
packages = [
|
||||
"//learning/brain/experimental/mlir/...",
|
||||
"//learning/brain/google/xla/...",
|
||||
"//learning/brain/mlir/...",
|
||||
"//tensorflow/compiler/mlir/...",
|
||||
],
|
||||
)
|
||||
@ -205,8 +204,6 @@ cc_library(
|
||||
cc_library(
|
||||
name = "tensorflow_lite",
|
||||
srcs = [
|
||||
"experimental/estimators/estimator.h",
|
||||
"experimental/estimators/gpu_estimator.h.inc",
|
||||
"ir/tfl_ops.cc",
|
||||
"ir/tfl_ops.cc.inc",
|
||||
"ir/tfl_ops.h.inc",
|
||||
@ -216,7 +213,6 @@ cc_library(
|
||||
"utils/attribute_utils.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"experimental/estimators/hardware.h",
|
||||
"ir/tfl_ops.h",
|
||||
"transforms/passes.h",
|
||||
"utils/attribute_utils.h",
|
||||
@ -226,6 +222,7 @@ cc_library(
|
||||
deps = [
|
||||
":tensorflow_lite_ops_inc_gen",
|
||||
":validators",
|
||||
"//tensorflow/compiler/mlir/lite/experimental/estimators:cost_estimators",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"@llvm-project//llvm:support",
|
||||
|
tensorflow/compiler/mlir/lite/experimental/estimators/BUILD (new file, 15 lines)
@ -0,0 +1,15 @@
package(
|
||||
default_visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cost_estimators",
|
||||
textual_hdrs = [
|
||||
"estimator.h",
|
||||
"gpu_estimators.h",
|
||||
"hardware.h",
|
||||
],
|
||||
)
|
@ -1,72 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATOR_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATOR_H_
|
||||
|
||||
// tfl.average_pool_2d
|
||||
template <>
|
||||
class TFLiteCostEstimator<AveragePool2DOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
llvm::errs() << "No defined cost function for op: "
|
||||
<< op->getName().getStringRef().str();
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.conv_2d
|
||||
template <>
|
||||
class TFLiteCostEstimator<Conv2DOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
llvm::errs() << "No defined cost function for op: "
|
||||
<< op->getName().getStringRef().str();
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// TODO(renjieliu): We probably need to check for dynamic weights.
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.depthwise_conv_2d
|
||||
template <>
|
||||
class TFLiteCostEstimator<DepthwiseConv2DOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
llvm::errs() << "No defined cost function for op: "
|
||||
<< op->getName().getStringRef().str();
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.max_pool_2d
|
||||
template <>
|
||||
class TFLiteCostEstimator<MaxPool2DOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
llvm::errs() << "No defined cost function for op: "
|
||||
<< op->getName().getStringRef().str();
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATOR_H_
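The replacement gpu_estimators.h that follows keeps this per-op specialization pattern, so supporting an additional op on the GPU hardware target is just one more specialization of TFLiteCostEstimator. A sketch for a relu op; the op is illustrative and is not part of this change.

// tfl.relu (illustrative; follows the same boilerplate as the ops above)
template <>
class TFLiteCostEstimator<ReluOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};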
|
@ -0,0 +1,231 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_

// tfl.add
template <>
class TFLiteCostEstimator<AddOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.average_pool_2d
template <>
class TFLiteCostEstimator<AveragePool2DOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.concatenation
template <>
class TFLiteCostEstimator<ConcatenationOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  // TODO(renjieliu): We probably need to check for dynamic weights.
  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.conv_2d
template <>
class TFLiteCostEstimator<Conv2DOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  // TODO(renjieliu): We probably need to check for dynamic weights.
  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.depthwise_conv_2d
template <>
class TFLiteCostEstimator<DepthwiseConv2DOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.fully_connected
template <>
class TFLiteCostEstimator<FullyConnectedOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  // TODO(renjieliu): we need to check for dynamic weights.
  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.logistic
template <>
class TFLiteCostEstimator<LogisticOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.max_pool_2d
template <>
class TFLiteCostEstimator<MaxPool2DOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.mirror_pad
template <>
class TFLiteCostEstimator<MirrorPadOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.maximum
template <>
class TFLiteCostEstimator<MaximumOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.minimum
template <>
class TFLiteCostEstimator<MinimumOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.mul
template <>
class TFLiteCostEstimator<MulOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.relu
template <>
class TFLiteCostEstimator<ReluOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.relu6
template <>
class TFLiteCostEstimator<Relu6Op, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.reshape
template <>
class TFLiteCostEstimator<ReshapeOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.softmax
template <>
class TFLiteCostEstimator<SoftmaxOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_
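Note that every specialization in the two headers above is a placeholder: GetCost only reports that no cost function is defined and returns 0.0, and IsSupported accepts every op. As a rough sketch of what a non-placeholder hook could look like, the snippet below derives a cost from the number of output elements; the per-element constant and the choice of the result shape as a work proxy are assumptions for illustration, not part of this change.

// Illustrative sketch only (not from this commit): a crude cost for
// tfl.conv_2d based on the number of output elements.
template <>
class TFLiteCostEstimator<Conv2DOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    auto output_type =
        op->getResult(0).getType().dyn_cast<mlir::ShapedType>();
    // Fall back to 0.0 when the shape is unknown, as the placeholder does.
    if (!output_type || !output_type.hasStaticShape()) return 0.0;
    constexpr double kAssumedCostPerElement = 1.0;  // assumed constant
    return kAssumedCostPerElement * output_type.getNumElements();
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};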
@ -54,7 +54,7 @@ class TensorFlowLiteDialect : public Dialect {

#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc"
// Include all specializes estimators below this line
#include "tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimator.h.inc"
#include "tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h"

}  // end namespace TFL
}  // end namespace mlir
@ -139,8 +139,7 @@ def TFL_Uint8 : UI<8>;
def TFL_Int32Or64 : SignlessIntOfWidths<[32, 64]>;

def TFL_BoolTensor : TFL_TensorOf<[I1]>;
def TFL_FpOrI32OrI64Tensor : TFL_TensorOf<[AnyFloat, TFL_Int32Or64]>;
def TFL_FpTensor : TFL_TensorOf<[AnyFloat]>;
def TFL_FpTensor : TFL_TensorOf<[F32]>;
def TFL_I32OrI64Tensor : TFL_TensorOf<[TFL_Int32Or64]>;
def TFL_I32Tensor : TFL_TensorOf<[I32]>;
def TFL_I64Tensor : TFL_TensorOf<[I64]>;
@ -324,9 +323,9 @@ class TFL_ConvOp<string mnemonic, string opSummary, int index> :
  }];

  let arguments = (
    ins AnyTensor:$input,
    AnyTensor:$filter,
    TFL_TensorOfOrNone<[AnyType]>:$bias,
    ins TFL_TensorOf<[F32, QI8, QUI8]>:$input,
    TFL_TensorOf<[F32, QI8, QUI8]>:$filter,
    TFL_TensorOfOrNone<[F32, I32]>:$bias,
    I32Attr:$dilation_h_factor,
    I32Attr:$dilation_w_factor,
    TFL_AFAttr:$fused_activation_function,
@ -335,7 +334,7 @@ class TFL_ConvOp<string mnemonic, string opSummary, int index> :
    I32Attr:$stride_w
  );

  let results = (outs AnyTensor:$output);
  let results = (outs TFL_TensorOf<[F32, QI8, QUI8]>:$output);

  let hasOptions = 0b1;
}
@ -361,7 +360,10 @@ an output element, this operation computes \\(y = |x|\\).
  let hasFolder = 1;
}

def TFL_AddOp : TFL_Op<"add", [ResultsBroadcastableShape, NoSideEffect, Commutative]> {
def TFL_AddOp : TFL_Op<"add", [ResultsBroadcastableShape,
                               NoSideEffect,
                               Commutative,
                               TFL_GpuTargetOp]> {
  let summary = "Addition operator";

  let description = [{
@ -394,11 +396,11 @@ def TFL_AddNOp : TFL_Op<"add_n", [Commutative, NoSideEffect, SameOperandsAndResu
  }];

  let arguments = (ins
    TFL_VariadicTensorOf<[F32, I32, QI16, QUI16]>:$inputs
    TFL_VariadicTensorOf<[F32, I32]>:$inputs
  );

  let results = (outs
    TFL_TensorOf<[F32, I32, QI16, QUI16]>:$sum
    TFL_TensorOf<[F32, I32]>:$sum
  );
}

@ -492,7 +494,7 @@ def TFL_AveragePool2DOp:
  }];

  let arguments = (
    ins AnyTensor:$input,
    ins TFL_TensorOf<[F32, QI8, QUI8]>:$input,
    I32Attr:$filter_height,
    I32Attr:$filter_width,
    TFL_PaddingAttr:$padding,
@ -501,7 +503,7 @@ def TFL_AveragePool2DOp:
    TFL_AFAttr:$fused_activation_function
  );

  let results = (outs AnyTensor:$output);
  let results = (outs TFL_TensorOf<[F32, QI8, QUI8]>:$output);

  let hasOptions = 1;
  let customOption = "Pool2DOptions";
@ -577,7 +579,8 @@ def TFL_ConcatenationOp : TFL_Op<"concatenation",
    NoSideEffect,
    PredOpTrait<"values and output must have same element type",
                TFL_TCresVTEtIsSameAsOp<0, 0>>,
    SameOperandsAndResultsScale
    SameOperandsAndResultsScale,
    TFL_GpuTargetOp
  ]> {
  let summary = "Concatenation operator";

@ -719,7 +722,18 @@ def TFL_CosOp: TFL_Op<"cos", [

def TFL_DepthwiseConv2DOp :
    TFL_ConvOp<"depthwise_conv_2d", "Depthwise-separable convolution", 3> {
  let arguments = !con(TFL_Conv2DOp.arguments, (ins I32Attr:$depth_multiplier));
  let arguments = (
    ins TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$input,
    TFL_TensorOf<[F32, QI8, QUI8]>:$filter,
    TFL_TensorOfOrNone<[F32, I32, I64]>:$bias,
    I32Attr:$dilation_h_factor,
    I32Attr:$dilation_w_factor,
    TFL_AFAttr:$fused_activation_function,
    TFL_PaddingAttr:$padding,
    I32Attr:$stride_h,
    I32Attr:$stride_w,
    I32Attr:$depth_multiplier
  );

  let extraClassDeclaration = [{
    // ChannelDimIndexInterface:
@ -741,7 +755,8 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
    NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
    TFL_ChannelDimIndexInterface,
    AffineOpCoefficient<-1, 1>,
    TFL_SparseOp]> {
    TFL_SparseOp,
    TFL_GpuTargetOp]> {
  let summary = "Fully connected op";

  let arguments = (ins
@ -1091,6 +1106,8 @@ def TFL_NotEqualOp : TFL_Op<"not_equal", [
}

def TFL_DivOp : TFL_Op<"div", [
  // TODO(fengliuai): NoQuantizableResult is only correct for int8
  // quantization. update to handle Uint8 quantization.
  ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
  let summary = "Division operator";

@ -1099,11 +1116,11 @@ def TFL_DivOp : TFL_Op<"div", [
  }];

  let arguments = (
    ins AnyTensor:$lhs,
    AnyTensor:$rhs,
    ins TFL_TensorOf<[F32, I32, QUI8]>:$lhs,
    TFL_TensorOf<[F32, I32, TFL_Uint8]>:$rhs,
    TFL_AFAttr:$fused_activation_function);

  let results = (outs AnyTensor:$output);
  let results = (outs TFL_TensorOf<[F32, I32, TFL_Uint8]>:$output);

  let builders = [TFL_FusedBroadcastableBinaryBuilder];

@ -1126,7 +1143,7 @@ def TFL_EluOp: TFL_Op<"elu", [NoSideEffect, SameOperandsAndResultType]> {

  let arguments = (ins TFL_FpTensor:$x);

  let results = (outs AnyTensor:$y);
  let results = (outs TFL_FpTensor:$y);

  let hasOptions = 0;
}
@ -1487,16 +1504,17 @@ def TFL_LogisticOp: TFL_Op<"logistic", [
    // zero_point = 0
    // scale = 1. / (max_value + 1)
    FixedResultScale<Int8UniformQuantizedType<-128, 390625, -8>>,
    FixedResultScale<UInt8UniformQuantizedType<0, 390625, -8>>]> {
    FixedResultScale<UInt8UniformQuantizedType<0, 390625, -8>>,
    TFL_GpuTargetOp]> {
  let summary = "Logistic operator";

  let description = [{
    Computes element-wise Sigmoid of input
  }];

  let arguments = (ins TFL_TensorOf<[AnyFloat, QI8, QUI8, QI16, QUI16]>:$x);
  let arguments = (ins TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$x);

  let results = (outs TFL_TensorOf<[AnyFloat, QI8, QUI8, QI16, QUI16]>:$y);
  let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$y);
}

def TFL_LogOp: TFL_Op<"log", [
@ -1639,19 +1657,23 @@ def TFL_MaxUnpooling2DOp :
}

def TFL_MaximumOp : TFL_Op<"maximum", [
    ResultsBroadcastableShape, NoSideEffect, Commutative, SameOperandsAndResultsScale]> {
    ResultsBroadcastableShape,
    NoSideEffect,
    Commutative,
    SameOperandsAndResultsScale,
    TFL_GpuTargetOp]> {
  let summary = "Max operator";
  let description = [{
    Element-wise max operation.
  }];

  let arguments = (
    ins TFL_TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$lhs,
    TFL_TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$rhs
    ins TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$lhs,
    TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$rhs
  );

  let results = (outs
    TFL_TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$max
    TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$max
  );

  let builders = [TFL_BroadcastableBinaryBuilder];
@ -1837,19 +1859,23 @@ def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [NoSideEffect]> {
}

def TFL_MinimumOp : TFL_Op<"minimum", [
    ResultsBroadcastableShape, NoSideEffect, Commutative, SameOperandsAndResultsScale]> {
    ResultsBroadcastableShape,
    NoSideEffect,
    Commutative,
    SameOperandsAndResultsScale,
    TFL_GpuTargetOp]> {
  let summary = "Min operator";
  let description = [{
    Element-wise min operation.
  }];

  let arguments = (
    ins TFL_TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$lhs,
    TFL_TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$rhs
    ins TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$lhs,
    TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$rhs
  );

  let results = (outs
    TFL_TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$min
    TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$min
  );

  let builders = [TFL_BroadcastableBinaryBuilder];
@ -1857,7 +1883,10 @@ def TFL_MinimumOp : TFL_Op<"minimum", [
  let hasOptions = 0;
}

def TFL_MulOp : TFL_Op<"mul", [ResultsBroadcastableShape, NoSideEffect, Commutative]> {
def TFL_MulOp : TFL_Op<"mul", [ResultsBroadcastableShape,
                               NoSideEffect,
                               Commutative,
                               TFL_GpuTargetOp]> {
  let summary = "Multiplication operator";

  let description = [{
@ -2090,7 +2119,8 @@ def TFL_RankOp: TFL_Op<"rank", [NoSideEffect]> {

def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect,
                                SameOperandsAndResultShape,
                                SameOperandsAndResultsScale]> {
                                SameOperandsAndResultsScale,
                                TFL_GpuTargetOp]> {
  let summary = "Relu operator";

  let description = [{
@ -2105,7 +2135,8 @@ def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect,

def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect,
                                  SameOperandsAndResultShape,
                                  SameOperandsAndResultsScale]> {
                                  SameOperandsAndResultsScale,
                                  TFL_GpuTargetOp]> {
  let summary = "Relu6 operator";

  let description = [{
@ -2134,7 +2165,7 @@ def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [NoSideEffect,
}

def TFL_ReshapeOp: TFL_Op<"reshape", [
    NoSideEffect, SameOperandsAndResultsScale]> {
    NoSideEffect, SameOperandsAndResultsScale, TFL_GpuTargetOp]> {
  let summary = "Reshape operator";

  let description = [{
@ -2351,7 +2382,8 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [
    // zero_point = 0
    // scale = 1. / (max_value + 1)
    FixedResultScale<Int8UniformQuantizedType<-128, 390625, -8>>,
    FixedResultScale<UInt8UniformQuantizedType<0, 390625, -8>>]> {
    FixedResultScale<UInt8UniformQuantizedType<0, 390625, -8>>,
    TFL_GpuTargetOp]> {
  let summary = "Softmax operator";

  let description = [{
@ -2882,7 +2914,7 @@ def TFL_CastOp : TFL_Op<"cast", [
    TFL_TensorOf<[F32, I1, I32, I64, TFL_Quint8, TFL_Uint8, Complex<F<32>>]>:$input
  );

  let results = (outs TFL_TensorOf<[F32, I1, I32, I64, Complex<F<32>>]>:$output);
  let results = (outs TFL_TensorOf<[F32, I1, I32, I64, TFL_Quint8, TFL_Uint8, Complex<F<32>>]>:$output);

  // TFLite's cast op does not utilize CastOptions, instead derives types
  // from the TfLiteTensors.
@ -2891,7 +2923,7 @@ def TFL_CastOp : TFL_Op<"cast", [


def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [
    NoSideEffect, TFL_OperandHasRank<1, 2>]> {
    NoSideEffect, TFL_OperandHasRank<1, 2>, TFL_GpuTargetOp]> {
  let summary = "MirrorPad Operator. Pads a tensor with mirrored values.";

  let description = [{
@ -3400,7 +3432,7 @@ def TFL_BidirectionalSequenceLSTMOp :
  let summary = "Bidirectional sequence lstm operator";

  let description = [{
    Bidirectional lstm is essentiallay two lstms, one running forward & the
    Bidirectional lstm is essentially two lstms, one running forward & the
    other running backward. And the output is the concatenation of the two
    lstms.
  }];
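The recurring change throughout this TableGen file is the extra TFL_GpuTargetOp trait on ops that have a GPU delegate kernel. A pass that wants to branch on that annotation could test the trait roughly as below; the C++ trait class name is an assumption (the TableGen definition of TFL_GpuTargetOp is not part of this hunk), so treat this as a sketch rather than the actual API.

// Sketch; assumes TFL_GpuTargetOp lowers to a native op trait class named
// mlir::OpTrait::TFL::GpuTargetOp (name not confirmed by this diff).
bool IsMarkedForGpu(mlir::Operation* op) {
  return op->hasTrait<mlir::OpTrait::TFL::GpuTargetOp>();
}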
@ -111,3 +111,31 @@ tf_native_cc_binary(
        "@llvm-project//mlir:TableGen",
    ],
)

cc_library(
    name = "device_target",
    srcs = ["device_target.cc"],
    hdrs = ["device_target.h"],
    deps = [
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:Support",
    ],
)

cc_library(
    name = "quantization_context",
    srcs = ["quantization_context.cc"],
    hdrs = ["quantization_context.h"],
    deps = [
        ":device_target",
        ":quantization_lib",
        "@com_google_absl//absl/memory",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Support",
    ],
)
82  tensorflow/compiler/mlir/lite/quantization/device_target.cc  Normal file
@ -0,0 +1,82 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/quantization/device_target.h"

#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project

namespace mlir {
namespace quant {

constexpr int k8Bits = 8;
constexpr unsigned kSigned = quant::QuantizationFlags::Signed;

DeviceTarget::DeviceTarget(MLIRContext* ctx) : ctx_(ctx) {
  f32_ = FloatType::getF32(ctx_);
  i8_ = IntegerType::get(k8Bits, ctx_);
  i8_min_ = QuantizedType::getDefaultMinimumForInteger(kSigned, k8Bits);
  i8_max_ = QuantizedType::getDefaultMaximumForInteger(kSigned, k8Bits);
  any_ = AnyQuantizedType();
  qi8_ = AnyQuantizedType::get(kSigned, i8_, f32_, i8_min_, i8_max_);
  qi8n_ = AnyQuantizedType::get(kSigned, i8_, f32_, i8_min_ + 1, i8_max_);
  assert(qi8n_ == qi8n_);
}

Optional<KernelSpec> DeviceTarget::Get(QuantizeRegionOp op) const {
  auto kernel_specs_it = specs_.find(op.logical_kernel());
  if (kernel_specs_it == specs_.end()) return llvm::None;

  KernelSpecs::Signature signature;
  signature.reserve(op.input_specs().size() + op.output_specs().size());
  AppendToSignature(op.input_specs(), &signature);
  AppendToSignature(op.output_specs(), &signature);
  return kernel_specs_it->getValue().Find(signature);
}

LogicalResult DeviceTarget::RegisterKernel(
    llvm::StringRef kernel, const KernelSpecs::Signature& signature,
    const ScaleFn& fn) {
  return specs_[kernel].Add(signature, {ScaleConstraintType::CustomScale, fn});
}

LogicalResult DeviceTarget::RegisterKernel(
    llvm::StringRef kernel, const KernelSpecs::Signature& signature,
    const ScaleConstraintType constraint) {
  return specs_[kernel].Add(signature, {constraint, {}});
}

void DeviceTarget::AppendToSignature(ArrayAttr specs_attr,
                                     KernelSpecs::Signature* signature) const {
  for (auto attr : specs_attr) {
    Type spec = attr.cast<TypeAttr>().getValue();
    if (auto quant = spec.dyn_cast<UniformQuantizedType>()) {
      signature->push_back(AnyQuantizedType::get(
          quant.getFlags(), quant.getStorageType(), quant.getExpressedType(),
          quant.getStorageTypeMin(), quant.getStorageTypeMax()));
    } else if (auto any = spec.dyn_cast<AnyQuantizedType>()) {
      signature->push_back(any);
    } else {  // float
      signature->push_back({});
    }
  }
}

}  // namespace quant
}  // namespace mlir
147  tensorflow/compiler/mlir/lite/quantization/device_target.h  Normal file
@ -0,0 +1,147 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_DEVICE_TARGET_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_DEVICE_TARGET_H_

#include <functional>
#include <ostream>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project

namespace mlir {
namespace quant {

class QuantizeContext;

using AdjacentOperations = llvm::SmallVectorImpl<Operation*>;
using ScaleFn = std::function<LogicalResult(QuantizeContext*, Operation*,
                                            AdjacentOperations*, bool*)>;

enum class ScaleConstraintType {
  OutputInputSameScale,
  OutputInputFreeScale,
  CustomScale,
};

// Each kernel signature has its own specification for scales.
struct KernelSpec {
  // Scale constraint
  ScaleConstraintType type;

  // Custom function to derive the scales. Only available when the scale
  // constraint is `CustomScale`.
  ScaleFn scale_fn;
};

class KernelSpecs {
 public:
  using Signature = llvm::SmallVector<quant::AnyQuantizedType, 4>;

  // Returns the kernel specification for the kernel signature.
  Optional<KernelSpec> Find(const Signature& signature) const {
    auto spec_it = all_signatures_.find(signature);
    if (spec_it != all_signatures_.end()) {
      return spec_it->second;
    } else {
      return llvm::None;
    }
  }

  // Adds the kernel signature with the kernel specification.
  LogicalResult Add(const Signature& signature, const KernelSpec& spec) {
    if (all_signatures_.insert({signature, spec}).second) return success();
    return failure();
  }

 private:
  // The signature is pattern match based.
  struct SignatureInfo : public llvm::DenseMapInfo<Signature> {
    static inline Signature getEmptyKey() { return {}; }
    static inline Signature getTombstoneKey() { return {nullptr}; }
    static unsigned getHashValue(Signature val) {
      return llvm::hash_combine_range(val.begin(), val.end());
    }
    static bool isEqual(Signature LHS, Signature RHS) {
      if (RHS == getEmptyKey()) return LHS == getEmptyKey();
      if (RHS == getTombstoneKey()) return LHS == getTombstoneKey();
      if (LHS.size() != RHS.size()) return false;
      for (auto arg : llvm::zip(LHS, RHS)) {
        if (std::get<0>(arg) != std::get<1>(arg)) return false;
      }
      return true;
    }
  };

  // Maps the signature to the kernel spec. Note that the matching is
  // pattern match based.
  llvm::DenseMap<Signature, KernelSpec, SignatureInfo> all_signatures_;
};

class DeviceTarget {
 public:
  explicit DeviceTarget(MLIRContext* ctx);

  // Retrieves the kernel spec for the quant region op.
  Optional<KernelSpec> Get(quant::QuantizeRegionOp op) const;

 protected:
  // Adds the kernel spec with the custom scale function for the kernel.
  LogicalResult RegisterKernel(llvm::StringRef kernel,
                               const KernelSpecs::Signature& signature,
                               const ScaleFn& fn);

  // Adds the kernel spec with the scale constraint type for the kernel.
  LogicalResult RegisterKernel(llvm::StringRef kernel,
                               const KernelSpecs::Signature& signature,
                               const ScaleConstraintType constraint);

  // Converts a specification to a signature:
  // - UniformQuantizedType -> AnyQuantizedType
  // - AnyQuantizedType (int) -> AnyQuantizedType
  // - Float -> {}
  void AppendToSignature(ArrayAttr specs_attr,
                         KernelSpecs::Signature* signature) const;

  // A set of parameters are required to build the signatures.
  FloatType f32_;
  IntegerType i8_;
  int64_t i8_min_, i8_max_;
  AnyQuantizedType any_, qi8_, qi8n_;

 private:
  // Maps the kernel names to all the available kernels.
  llvm::StringMap<KernelSpecs> specs_;

  // Points to the global MLIRContext.
  MLIRContext* ctx_;
};

}  // namespace quant
}  // namespace mlir
#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_DEVICE_TARGET_H_
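A backend hooks into this machinery by subclassing DeviceTarget and registering, per logical kernel name, either a scale constraint or a custom ScaleFn for each signature it supports, much like the CPU target added later in this change. A minimal sketch, with a made-up kernel name:

// Hypothetical subclass (not part of the commit); "generic.concat" and the
// three-port qi8 signature are illustrative.
class MyDeviceTarget : public mlir::quant::DeviceTarget {
 public:
  explicit MyDeviceTarget(mlir::MLIRContext* ctx) : DeviceTarget(ctx) {
    // qi8_ is the protected "any signed 8-bit quantized type" that the base
    // class constructor builds.
    (void)RegisterKernel(
        "generic.concat", {qi8_, qi8_, qi8_},
        mlir::quant::ScaleConstraintType::OutputInputSameScale);
  }
};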
@ -34,11 +34,11 @@ limitations under the License.
namespace mlir {
namespace lite {

// TODO(fengliuai): check the result for `allow_float` flag.
// TODO(fengliuai): check the result for `fully_quantize` flag.
TfLiteStatus QuantizeModel(
    const tflite::ModelT& input_model, const tflite::TensorType& input_type,
    const tflite::TensorType& output_type,
    const std::unordered_set<std::string>& operator_names, bool allow_float,
    const std::unordered_set<std::string>& operator_names, bool fully_quantize,
    flatbuffers::FlatBufferBuilder* builder,
    tflite::ErrorReporter* error_reporter) {
  // TODO(b/142502494): remove this restriction by improving the `emit_adaptor`

@ -27,11 +27,11 @@ namespace lite {

// Quantize the `input_model` and write the result to a flatbuffer `builder`.
// The `input_type` and `output_type` can be float32/qint8/int8.
// Return partially quantized model if `allow_float` is true.
// Return partially quantized model if `fully_quantize` is false.
TfLiteStatus QuantizeModel(
    const tflite::ModelT& input_model, const tflite::TensorType& input_type,
    const tflite::TensorType& output_type,
    const std::unordered_set<std::string>& operator_names, bool allow_float,
    const std::unordered_set<std::string>& operator_names, bool fully_quantize,
    flatbuffers::FlatBufferBuilder* builder,
    tflite::ErrorReporter* error_reporter);
}  // namespace lite

@ -47,7 +47,7 @@ TfLiteStatus QuantizeAnnotatedModel(llvm::StringRef buffer,
  tflite::StderrReporter error_reporter;

  return mlir::lite::QuantizeModel(
      *model, tflite::TensorType_INT8, tflite::TensorType_INT8, {},
      /*allow_float=*/false, builder, &error_reporter);
      /*fully_quantize=*/true, builder, &error_reporter);
}

}  // namespace
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_context.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Matchers.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/device_target.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
|
||||
#define DEBUG_TYPE "quantization-context"
|
||||
|
||||
namespace mlir {
|
||||
namespace quant {
|
||||
|
||||
QuantizeContext::QuantizeContext(FuncOp func, const DeviceTarget &spec)
|
||||
: func_(func), target_spec_(spec) {
|
||||
llvm::DenseMap<Value, int> value_to_state;
|
||||
func.walk([&](quant::QuantizeRegionOp op) {
|
||||
for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
|
||||
states_manager_.InitializeOperandState(op, i, &value_to_state);
|
||||
}
|
||||
|
||||
for (int res = 0, e = op.getNumResults(); res != e; ++res) {
|
||||
states_manager_.InitializeResultState(op, res, &value_to_state);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
llvm::ArrayRef<quant::QuantizeRegionOp> QuantizeContext::GetAllOps() {
|
||||
llvm::SmallVector<quant::QuantizeRegionOp, 64> all_ops;
|
||||
func_.walk([&](quant::QuantizeRegionOp op) { all_ops.push_back(op); });
|
||||
return all_ops;
|
||||
}
|
||||
|
||||
LogicalResult QuantizeContext::Handle(
|
||||
quant::QuantizeRegionOp op, llvm::SmallVectorImpl<Operation *> *new_items,
|
||||
bool *changed) {
|
||||
auto spec = target_spec_.Get(op);
|
||||
if (!spec.hasValue()) {
|
||||
op.emitWarning(
|
||||
"Couldn't find kernel from the registeration for quantization.");
|
||||
return success();
|
||||
}
|
||||
switch (spec->type) {
|
||||
case ScaleConstraintType::OutputInputFreeScale: {
|
||||
// no propagation.
|
||||
*changed = false;
|
||||
break;
|
||||
}
|
||||
case ScaleConstraintType::CustomScale: {
|
||||
if (failed(spec->scale_fn(this, op, new_items, changed))) {
|
||||
return failure();
|
||||
}
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
llvm_unreachable("no implementation.");
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult QuantizeContext::Finalize() {
|
||||
MLIRContext *context = func_.getContext();
|
||||
func_.walk([&](quant::QuantizeRegionOp op) {
|
||||
llvm::SmallVector<Attribute, 4> input_specs;
|
||||
auto original_input_specs = op.input_specs().getValue();
|
||||
for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
|
||||
auto &state = states_manager_.GetOperandQuantState(op, i);
|
||||
auto &requantize = states_manager_.GetOperandRequantizeState(op, i);
|
||||
if (state.IsEmpty() && requantize.pos == RequantizeState::NO_REQUANTIZE) {
|
||||
input_specs.push_back(original_input_specs[i]);
|
||||
} else if (requantize.pos == RequantizeState::ON_OUTPUT) {
|
||||
input_specs.push_back(TypeAttr::get(requantize.params));
|
||||
} else {
|
||||
input_specs.push_back(TypeAttr::get(state.params));
|
||||
}
|
||||
}
|
||||
op.setAttr("input_specs", ArrayAttr::get(input_specs, context));
|
||||
|
||||
llvm::SmallVector<Attribute, 4> output_specs;
|
||||
auto original_output_specs = op.output_specs().getValue();
|
||||
for (int res = 0, e = op.getNumResults(); res != e; ++res) {
|
||||
auto &state = states_manager_.GetResultQuantState(op, res);
|
||||
auto &requantize = states_manager_.GetResultRequantizeState(op, res);
|
||||
if (state.IsEmpty() && requantize.pos == RequantizeState::NO_REQUANTIZE) {
|
||||
output_specs.push_back(original_output_specs[res]);
|
||||
} else if (requantize.pos == RequantizeState::ON_INPUT) {
|
||||
output_specs.push_back(TypeAttr::get(requantize.params));
|
||||
} else {
|
||||
output_specs.push_back(TypeAttr::get(state.params));
|
||||
}
|
||||
}
|
||||
op.setAttr("output_specs", ArrayAttr::get(output_specs, context));
|
||||
});
|
||||
return success();
|
||||
}
|
||||
|
||||
void QuantizeContext::DumpStates(QuantizeRegionOp current_op) {
|
||||
if (current_op) {
|
||||
llvm::errs() << "\n\n\n" << current_op.logical_kernel() << "\n";
|
||||
}
|
||||
func_.walk([&](QuantizeRegionOp op) {
|
||||
if (current_op == op) llvm::errs() << "===>>>";
|
||||
llvm::errs() << op.logical_kernel() << " : (";
|
||||
for (auto i = 0; i < op.getNumOperands(); ++i) {
|
||||
if (auto params = GetOperandParams(op, i))
|
||||
params.print(llvm::errs());
|
||||
else
|
||||
llvm::errs() << "_";
|
||||
llvm::errs() << ",";
|
||||
}
|
||||
llvm::errs() << ") -> (";
|
||||
for (auto i = 0; i < op.getNumResults(); ++i) {
|
||||
if (auto params = GetResultParams(op, i))
|
||||
params.print(llvm::errs());
|
||||
else
|
||||
llvm::errs() << "_";
|
||||
llvm::errs() << ",";
|
||||
}
|
||||
llvm::errs() << ")\n";
|
||||
});
|
||||
}
|
||||
|
||||
int QuantizeContext::StatesManager::InitializeState(quant::QuantizeRegionOp op,
|
||||
int index, bool as_result) {
|
||||
Attribute params_attr;
|
||||
if (as_result) {
|
||||
params_attr = op.output_specs()[index];
|
||||
} else {
|
||||
params_attr = op.input_specs()[index];
|
||||
}
|
||||
QuantParams params =
|
||||
params_attr.cast<TypeAttr>().getValue().dyn_cast<QuantParams>();
|
||||
bool immutable = !EmptyParams(params);
|
||||
int next_state_index = states_.size();
|
||||
states_.push_back({params, immutable});
|
||||
if (as_result) {
|
||||
result_states_.insert({{op, index}, next_state_index});
|
||||
} else {
|
||||
operand_states_.insert({{op, index}, next_state_index});
|
||||
}
|
||||
return next_state_index;
|
||||
}
|
||||
|
||||
void QuantizeContext::StatesManager::InitializeOperandState(
|
||||
quant::QuantizeRegionOp op, int index, llvm::DenseMap<Value, int> *cache) {
|
||||
Value in = op.getOperand(index);
|
||||
auto cached = cache->insert({in, 0});
|
||||
if (!cached.second) {
|
||||
operand_states_.insert({{op, index}, cached.first->second});
|
||||
return;
|
||||
}
|
||||
cached.first->second = InitializeState(op, index, /*as_result=*/false);
|
||||
}
|
||||
|
||||
void QuantizeContext::StatesManager::InitializeResultState(
|
||||
quant::QuantizeRegionOp op, int index, llvm::DenseMap<Value, int> *cache) {
|
||||
auto res = op.getResult(index);
|
||||
auto cached = cache->insert({res, 0});
|
||||
if (!cached.second) {
|
||||
result_states_.insert({{op, index}, cached.first->second});
|
||||
return;
|
||||
}
|
||||
cached.first->second = InitializeState(op, index, /*as_result=*/true);
|
||||
}
|
||||
|
||||
bool QuantizeContext::StatesManager::SetConstantResultParams(Operation *op) {
|
||||
llvm_unreachable("no implementation.");
|
||||
return false;
|
||||
}
|
||||
|
||||
bool QuantizeContext::StatesManager::SetResultParams(Operation *op,
|
||||
int res_index,
|
||||
QuantParams params) {
|
||||
auto &state = GetResultQuantState(op, res_index);
|
||||
if (state.params == params) {
|
||||
return false;
|
||||
}
|
||||
if (!state.IsEmpty()) {
|
||||
auto &rescale = GetResultRequantizeState(op, res_index);
|
||||
rescale.params = params;
|
||||
rescale.pos = RequantizeState::ON_INPUT;
|
||||
return false;
|
||||
}
|
||||
state.params = params;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool QuantizeContext::StatesManager::SetOperandParams(Operation *op, int index,
|
||||
QuantParams params) {
|
||||
auto &state = GetOperandQuantState(op, index);
|
||||
if (state.params == params) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!state.IsEmpty()) {
|
||||
auto &rescale = GetOperandRequantizeState(op, index);
|
||||
rescale.params = params;
|
||||
rescale.pos = RequantizeState::ON_OUTPUT;
|
||||
return false;
|
||||
}
|
||||
state.params = params;
|
||||
return true;
|
||||
}
|
||||
} // namespace quant
|
||||
} // namespace mlir
|
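One behavior of SetResultParams/SetOperandParams above is worth spelling out: once a state already holds parameters, a conflicting assignment is not overwritten. Instead a requantize entry is recorded (ON_INPUT for results, ON_OUTPUT for operands), which Finalize() later writes back into the region specs. A hedged sketch of the calling pattern, with the op, context, and the two parameter values assumed to exist:

// Sketch only; not from the commit. `qtype_a` and `qtype_b` are two
// different quantization parameters obtained elsewhere.
void PropagationSketch(mlir::quant::QuantizeContext* ctx, mlir::Operation* op,
                       mlir::quant::QuantParams qtype_a,
                       mlir::quant::QuantParams qtype_b) {
  // Empty state: the params are stored and the call returns true, meaning the
  // users of the result should be revisited.
  bool revisit_users = ctx->SetResultParams(op, /*index=*/0, qtype_a);
  // Non-empty state: the stored params are kept, a requantize entry
  // (ON_INPUT) is recorded instead, and the call returns false.
  bool no_revisit = ctx->SetResultParams(op, /*index=*/0, qtype_b);
  (void)revisit_users;
  (void)no_revisit;
}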
@ -0,0 +1,217 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONTEXT_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONTEXT_H_

#include "llvm/ADT/DenseMap.h"
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Function.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/device_target.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"

namespace mlir {
namespace quant {

static bool EmptyParams(QuantParams p) { return p == quant::QuantizedType(); }

// The state for each op result during the quantization parameters propagation.
struct QuantState {
  // Quantization parameters propagated to an op result.
  QuantParams params;
  // A flag that indicates this state (the params) shouldn't be changed after
  // it is initialized. This flag will be set to true if the quantization
  // parameters are from the quantization-aware training.
  const bool immutable;

  bool IsEmpty() { return EmptyParams(params); }
};

// The state for rescaling the propagated quantization parameters. This can be
// on the input side to satisfy the constraint of previous operation, or on the
// output side to satisfy the constraint of the next operation.
struct RequantizeState {
  // Sometimes, we have to "requantize" the quantization result to satisfy all
  // the constraints. The "requantize" can happen either on the input or output
  // of the quantization result.
  enum RequantizePosition {
    NO_REQUANTIZE,
    ON_INPUT,
    ON_OUTPUT
  } pos = NO_REQUANTIZE;

  // Quantization parameters will be used to add the requantize ops.
  QuantParams params;
};

// This class manages all the intermediate quantization states.
class QuantizeContext {
 public:
  QuantizeContext(FuncOp func, const DeviceTarget &spec);

  // Returns all the quant region ops.
  ArrayRef<quant::QuantizeRegionOp> GetAllOps();

  // For each quant region op, propagates its quantization parameters according
  // to the kernel specification and also returns the adjacent quant region ops
  // which get the new quantization parameters propagated.
  LogicalResult Handle(quant::QuantizeRegionOp op,
                       llvm::SmallVectorImpl<Operation *> *new_items,
                       bool *changed);

  // Updates the port quantization specifications of all the quant region ops
  // with the propagation results.
  LogicalResult Finalize();

  // Dumps the states stored in the state manager.
  void DumpStates(QuantizeRegionOp current_op = {});

  // Updates the quantization parameter for a certain result of the op. By this
  // method, the quantization parameter is propagated to all the users of the
  // result as well.
  bool SetResultParams(Operation *op, int index, QuantParams params) {
    return states_manager_.SetResultParams(op, index, params);
  }

  // Updates the quantization parameter for a certain operand of the op. By
  // this method, the quantization parameter is propagated to the defining op
  // of the operand as well.
  bool SetOperandParams(Operation *op, int index, QuantParams params) {
    return states_manager_.SetOperandParams(op, index, params);
  }

  // Returns the quantization parameter of a certain result of the op.
  QuantParams GetResultParams(Operation *op, int index) {
    return states_manager_.GetResultParams(op, index);
  }

  // Returns the quantization parameter of a certain operand of the op.
  QuantParams GetOperandParams(Operation *op, int index) {
    return states_manager_.GetOperandParams(op, index);
  }

 private:
  class StatesManager {
   public:
    // Sets the quantization parameters of the constant result according to its
    // content.
    //
    // Always returns true.
    bool SetConstantResultParams(Operation *op);

    // Sets the quantization parameters of the result to a fixed value. If any
    // quantization parameters have been propagated, a `requantize` will happen
    // on the input of propagated quantization.
    //
    // Returns true, if the users of the result needs to be added to the
    // worklist.
    bool SetResultParams(Operation *op, int index, QuantParams params);

    // Sets the quantization parameters of the operand to a fixed value. If any
    // quantization parameters have been propagated, a `requantize` will happen
    // on the output of propagated quantization.
    //
    // Returns true, if the defining op of the operand needs to be added to the
    // worklist.
    bool SetOperandParams(Operation *op, int index, QuantParams params);

    // Returns the quantization parameters of the index-th result of the op.
    QuantParams GetResultParams(Operation *op, int index) {
      return states_[result_states_[{op, index}]].params;
    }

    // Returns the quantization parameters of the index-th operand of the op.
    QuantParams GetOperandParams(Operation *op, int index) {
      return states_[operand_states_[{op, index}]].params;
    }

   private:
    friend class QuantizeContext;

    // Uses the type of `val` to set the initial state of the index-th result
    // if `as_result` is true or index-th operand if `as_result` is false. The
    // state is immutable if the type is a quantized type. Returns the index of
    // this new state in the state vector.
    int InitializeState(quant::QuantizeRegionOp op, int index, bool as_result);

    // Sets the state of the index-th operand of the op. If this operand is
    // cached, uses the cached result without creating new entry in the state
    // vector. Otherwise, allocate a new entry in the state vector.
    void InitializeOperandState(quant::QuantizeRegionOp op, int index,
                                llvm::DenseMap<Value, int> *cache);

    // Sets the state of the index-th result of the op. If this result is
    // cached, uses the cached result without creating new entry in the state
    // vector. Otherwise, allocate a new entry in the state vector.
    void InitializeResultState(quant::QuantizeRegionOp op, int index,
                               llvm::DenseMap<Value, int> *cache);

    // Returns the quantization state of the index-th operand of the op.
    QuantState &GetOperandQuantState(Operation *op, int index) {
      return states_[operand_states_[{op, index}]];
    }

    // Returns the quantization state of the index-th result of the op.
    QuantState &GetResultQuantState(Operation *op, int index) {
      return states_[result_states_[{op, index}]];
    }

    // Returns the requantize state of the index-th operand of the op.
    RequantizeState &GetOperandRequantizeState(Operation *op, int index) {
      return rescale_states_[operand_states_[{op, index}]];
    }

    // Returns the requantize state of the index-th result of the op.
    RequantizeState &GetResultRequantizeState(Operation *op, int index) {
      return rescale_states_[result_states_[{op, index}]];
    }

   private:
    // This is used to identify an operand or result of an op. The second
    // element of this pair is the index of the operand or result.
    using OpValue = std::pair<mlir::Operation *, int>;

    // The vector contains all the quantization parameters propagated from the
    // defining operations of the value, or from the quantization aware
    // training.
    std::vector<QuantState> states_;

    // The map contains all the quantization parameters which are required to
    // satisfy the same operands and results constraint. The keys of this map
    // are the values from `operand_states_` and `result_state_`.
    std::unordered_map<int, RequantizeState> rescale_states_;

    // Maps of indexes to the propagation state vector from the ops operands,
    // results and arguments.
    llvm::DenseMap<OpValue, int> operand_states_;
    llvm::DenseMap<OpValue, int> result_states_;
  };

  FuncOp func_;

  DeviceTarget target_spec_;

  StatesManager states_manager_;
};

}  // namespace quant
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONTEXT_H_
@ -33,7 +33,9 @@ cc_library(
        "passes.h",
    ],
    deps = [
        ":cpu_device_target",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_config",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_context",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
        "//tensorflow/compiler/mlir/xla:hlo",
        "//tensorflow/compiler/xla/client/lib:quantize",
@ -49,6 +51,23 @@ cc_library(
    alwayslink = 1,
)

cc_library(
    name = "cpu_device_target",
    srcs = [
        "cpu_device_target.cc",
    ],
    hdrs = [
        "cpu_device_target.h",
    ],
    deps = [
        "//tensorflow/compiler/mlir/lite/quantization:device_target",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_context",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:Support",
    ],
)

cc_library(
    name = "quantize",
    srcs = [
@ -0,0 +1,59 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.h"

#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_context.h"

namespace mlir {
namespace xla_hlo {

namespace ph = std::placeholders;

CpuDeviceTarget::CpuDeviceTarget(MLIRContext* ctx) : DeviceTarget(ctx) {
  RegisterKernel("generic.concat", {qi8_, qi8_, qi8_},
                 quant::ScaleConstraintType::OutputInputSameScale);
  RegisterKernel("generic.mul", {qi8_, qi8_, qi8_},
                 quant::ScaleConstraintType::OutputInputFreeScale);
  RegisterKernel("generic.mul_add", {qi8_, qi8n_, any_, qi8_},
                 std::bind(&CpuDeviceTarget::HandleMultiplyAccumulateScale,
                           this, ph::_1, ph::_2, ph::_3, ph::_4));
  RegisterKernel("generic.matmul_add", {qi8_, qi8n_, any_, qi8_},
                 std::bind(&CpuDeviceTarget::HandleMultiplyAccumulateScale,
                           this, ph::_1, ph::_2, ph::_3, ph::_4));
}

LogicalResult CpuDeviceTarget::HandleMultiplyAccumulateScale(
    quant::QuantizeContext* ctx, Operation* op,
    quant::AdjacentOperations* new_items, bool* changed) {
  auto bias_params = ctx->GetOperandParams(op, 2);
  if (!EmptyParams(bias_params)) {
    return success();
  }
  std::vector<quant::QuantParams> op_types{ctx->GetOperandParams(op, 0),
                                           ctx->GetOperandParams(op, 1)};
  auto bias_scale = GetUniformQuantizedTypeForBias(op_types);
  if (bias_scale && ctx->SetOperandParams(op, 2, bias_scale)) {
    *changed = true;
    new_items->push_back(op->getOperand(2).getDefiningOp());
  }
  return success();
}

}  // namespace xla_hlo
}  // namespace mlir
@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_CPU_DEVICE_TARGET_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_CPU_DEVICE_TARGET_H_

#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/device_target.h"

namespace mlir {
namespace xla_hlo {

// Target specs for cpu kernels
class CpuDeviceTarget : public quant::DeviceTarget {
 public:
  explicit CpuDeviceTarget(MLIRContext* ctx);

 private:
  LogicalResult HandleMultiplyAccumulateScale(
      quant::QuantizeContext* ctx, Operation* op,
      quant::AdjacentOperations* new_items, bool* changed);
};

}  // namespace xla_hlo
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_CPU_DEVICE_TARGET_H_
@ -26,7 +26,9 @@ limitations under the License.
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_context.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.h"

// NOLINTNEXTLINE
static llvm::cl::opt<bool> disable_per_channel(
@ -59,9 +61,36 @@ struct PropagateQuantPass : public FunctionPass<PropagateQuantPass> {

void PropagateQuantPass::runOnFunction() {
  FuncOp func = getFunction();
  // TODO(fengliuai): deprecate this old code generation path.
  // XLA only support uint8/uint16 quantization for now.
  ApplyQuantizationParamsPropagation(func, /*is_signed*/ false,
                                     disable_per_channel, GetOpQuantSpec);

  CpuDeviceTarget spec(&getContext());
  quant::QuantizeContext ctx(func, spec);

  std::vector<quant::QuantizeRegionOp> work_list(ctx.GetAllOps());
  bool changed = false;
  while (!work_list.empty()) {
    quant::QuantizeRegionOp op = work_list.back();
    work_list.pop_back();

    llvm::SmallVector<Operation *, 4> new_items;
    if (failed(ctx.Handle(op, &new_items, &changed))) {
      // The IR is still valid, thus we shouldn't fail.
      signalPassFailure();
    }
    for (auto item : new_items) {
      if (auto reg = llvm::dyn_cast_or_null<quant::QuantizeRegionOp>(item))
        work_list.push_back(reg);
    }
  }

  if (!changed) return;

  if (failed(ctx.Finalize())) {
    signalPassFailure();
  }
}

}  // namespace
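The test file below drives this pass through the -xla-hlo-propagate-quant flag. The registration that binds that flag to PropagateQuantPass is not visible in this hunk; it presumably follows the usual MLIR pattern of the time, roughly as sketched here (the description string is made up):

// Assumed registration, not shown in the hunk above; the flag name matches
// the RUN line in the test below.
static mlir::PassRegistration<PropagateQuantPass> pass(
    "xla-hlo-propagate-quant",
    "Propagate quantization parameters across quant.region ops");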
@ -0,0 +1,54 @@
|
||||
// RUN: tf-opt -xla-hlo-propagate-quant %s | FileCheck %s --dump-input-on-failure
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @mul_add_source_no_params
|
||||
func @mul_add_source_no_params(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
|
||||
%region = "quant.region"(%arg0, %arg1, %arg2) ( {
|
||||
^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors
|
||||
%mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
|
||||
%add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
|
||||
"quant.return"(%add) : (tensor<4xf32>) -> ()
|
||||
}) {input_specs = [f32, f32, f32], logical_kernel = "generic.mul_add", output_specs = [f32]} :
|
||||
(tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %region : tensor<4xf32>
|
||||
|
||||
// CHECK: input_specs = [f32, f32, f32]
|
||||
// CHECK-SAME: output_specs = [f32]
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @mul_add_annotated_no_narrow_range
|
||||
func @mul_add_annotated_no_narrow_range(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
|
||||
%region = "quant.region"(%arg0, %arg1, %arg2) ( {
|
||||
^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors
|
||||
%mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
|
||||
%add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
|
||||
"quant.return"(%add) : (tensor<4xf32>) -> ()
|
||||
}) {input_specs = [!quant.uniform<i8:f32, 1.0:-128>, !quant.uniform<i8:f32, 1.0:-128>, f32],
|
||||
logical_kernel = "generic.mul_add", output_specs = [!quant.uniform<i8:f32, 1.0:-128>]} :
|
||||
(tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %region : tensor<4xf32>
|
||||
|
||||
// CHECK: input_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>, !quant.uniform<i8:f32, 1.000000e+00:-128>, f32]
|
||||
// CHECK-SAME: output_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>]
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @mul_add_annotated
|
||||
func @mul_add_annotated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
|
||||
%region = "quant.region"(%arg0, %arg1, %arg2) ( {
|
||||
^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors
|
||||
%mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
|
||||
%add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
|
||||
"quant.return"(%add) : (tensor<4xf32>) -> ()
|
||||
}) {input_specs = [!quant.uniform<i8:f32, 1.0:-128>, !quant.uniform<i8<-127:127>:f32, 1.0:-128>, f32],
|
||||
logical_kernel = "generic.mul_add", output_specs = [!quant.uniform<i8:f32, 1.0:-128>]} :
|
||||
(tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %region : tensor<4xf32>
|
||||
|
||||
// CHECK: input_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>, !quant.uniform<i8<-127:127>:f32, 1.000000e+00:-128>, !quant.uniform<i32:f32, 1.000000e+00>]
|
||||
// CHECK-SAME: output_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>]
|
||||
}
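In the last annotated case below (@mul_add_annotated), the remaining f32 operand is expected to pick up an i32 accumulator spec whose scale, by the usual multiply-accumulate rule, is the product of the input and weight scales. With the annotations used here that works out to (illustrative arithmetic, not part of the test):

    s_acc = s_input * s_weight = 1.0 * 1.0 = 1.0
    =>  third input_spec becomes !quant.uniform<i32:f32, 1.000000e+00>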
|
@ -5,6 +5,11 @@ package(licenses = ["notice"])
|
||||
glob_lit_tests(
|
||||
data = [":test_utilities"],
|
||||
driver = "@llvm-project//mlir:run_lit.sh",
|
||||
tags_override = {
|
||||
"legalize-tf.mlir": ["no_rocm"],
|
||||
"optimize.mlir": ["no_rocm"],
|
||||
"prepare-tf.mlir": ["no_rocm"],
|
||||
},
|
||||
test_file_exts = ["mlir"],
|
||||
)
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
# RUN: not tf_tfl_translate -mlir-pretty-debuginfo -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes=1,224,224,3 -tf-output-arrays=MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm/FusedBatchNorm -tf-debug-info=%s.debug %s -o - 2>&1 | FileCheck %s
|
||||
|
||||
# CHECK: fake/user/code/file_C.py: error: 'tf.Conv2D' op attribute 'data_format' failed to satisfy constraint: 'NHWC' or 'NCHW' convnet data format
|
||||
# CHECK: fake/user/code/file_C.py:27:0: error: 'tf.Conv2D' op attribute 'data_format' failed to satisfy constraint: 'NHWC' or 'NCHW' convnet data format
|
||||
|
||||
node {
|
||||
name: "input"
|
||||
|
@ -1,9 +1,9 @@
|
||||
# RUN: not tf_tfl_translate -mlir-pretty-debuginfo -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes=1,224,224,3 -tf-output-arrays=MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm/FusedBatchNorm -tf-debug-info=%s.debug %s -o - 2>&1 | FileCheck %s
|
||||
|
||||
# CHECK: fake/user/code/file_C.py: error: 'tf.Conv2D' op attribute 'data_format' failed to satisfy constraint: 'NHWC' or 'NCHW' convnet data format
|
||||
# CHECK: fake/user/code/file_D.py: note: called from
|
||||
# CHECK: fake/user/code/file_E.py: note: called from
|
||||
# CHECK: fake/user/code/file_F.py: note: called from
|
||||
# CHECK: fake/user/code/file_C.py:27:0: error: 'tf.Conv2D' op attribute 'data_format' failed to satisfy constraint: 'NHWC' or 'NCHW' convnet data format
|
||||
# CHECK: fake/user/code/file_D.py:28:0: note: called from
|
||||
# CHECK: fake/user/code/file_E.py:29:0: note: called from
|
||||
# CHECK: fake/user/code/file_F.py:30:0: note: called from
|
||||
|
||||
node {
|
||||
name: "input"
|
||||
|
@ -8,6 +8,12 @@ glob_lit_tests(
|
||||
":test_utilities",
|
||||
],
|
||||
driver = "@llvm-project//mlir:run_lit.sh",
|
||||
tags_override = {
|
||||
"add.pbtxt": ["no_rocm"],
|
||||
"conv_2d.pbtxt": ["no_rocm"],
|
||||
"fake_quant_per_channel.pbtxt": ["no_rocm"],
|
||||
"ophint_lstm.pbtxt": ["no_rocm"],
|
||||
},
|
||||
test_file_exts = [
|
||||
"pbtxt",
|
||||
],
|
||||
|
@ -50,8 +50,8 @@ func @while_cond_10_frozen0(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>, %arg2: t
|
||||
// INLINE: ^bb0([[ARGS]]):
|
||||
// INLINE: %cst_2 = constant
|
||||
// INLINE: yield
|
||||
// INLINE: while_body
|
||||
// INLINE: while_cond
|
||||
// INLINE-NOT: while_body
|
||||
// INLINE-NOT: while_cond
|
||||
|
||||
// CANON-LABEL: func @while_main
|
||||
// CANON-SAME: ([[VAL_0:%.*]]: tensor<?x256x256xf32>)
|
||||
|
@ -192,11 +192,11 @@ func @argmin(%arg0: tensor<3xi32>, %arg1: tensor<i32>) -> tensor<i32> {
|
||||
// CHECK: "tfl.arg_min"(%arg0, %arg1) : (tensor<3xi32>, tensor<i32>) -> tensor<i32>
|
||||
}
|
||||
|
||||
func @sigmoid(%arg0: tensor<?x88xf16>) -> tensor<?x88xf16> {
|
||||
%0 = "tf.Sigmoid"(%arg0) : (tensor<?x88xf16>) -> tensor<?x88xf16>
|
||||
return %0 : tensor<?x88xf16>
|
||||
func @sigmoid(%arg0: tensor<?x88xf32>) -> tensor<?x88xf32> {
|
||||
%0 = "tf.Sigmoid"(%arg0) : (tensor<?x88xf32>) -> tensor<?x88xf32>
|
||||
return %0 : tensor<?x88xf32>
|
||||
// CHECK-LABEL: sigmoid
|
||||
// CHECK: "tfl.logistic"(%arg0) : (tensor<?x88xf16>) -> tensor<?x88xf16>
|
||||
// CHECK: "tfl.logistic"(%arg0) : (tensor<?x88xf32>) -> tensor<?x88xf32>
|
||||
}
|
||||
|
||||
func @sqrt(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
|
||||
@ -1316,16 +1316,6 @@ func @assert_remove(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi1>
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
func @reciprocal_f16(%arg0: tensor<8xf16>) -> tensor<8xf16> {
|
||||
%0 = "tf.Reciprocal"(%arg0) : (tensor<8xf16>) -> tensor<8xf16>
|
||||
return %0: tensor<8xf16>
|
||||
|
||||
// CHECK-LABEL: reciprocal_f16
|
||||
// CHECK: %cst = constant dense<1.000000e+00> : tensor<f16>
|
||||
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<f16>, tensor<8xf16>) -> tensor<8xf16>
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
func @reciprocal_f32(%arg0: tensor<8xf32>) -> tensor<8xf32> {
|
||||
%0 = "tf.Reciprocal"(%arg0) : (tensor<8xf32>) -> tensor<8xf32>
|
||||
return %0: tensor<8xf32>
|
||||
@ -1336,16 +1326,6 @@ func @reciprocal_f32(%arg0: tensor<8xf32>) -> tensor<8xf32> {
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
func @reciprocal_complex_f32(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>> {
|
||||
%0 = "tf.Reciprocal"(%arg0) : (tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>>
|
||||
return %0: tensor<8xcomplex<f32>>
|
||||
|
||||
// CHECK-LABEL: reciprocal_complex_f32
|
||||
// CHECK: %cst = constant opaque<"tf", "0x746674656E736F722464747970653A2044545F434F4D504C455836342074656E736F725F7368617065207B2064696D207B2073697A653A2031207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C3230303F5C3030305C3030305C3030305C30303022"> : tensor<complex<f32>>
|
||||
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<complex<f32>>, tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>>
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
func @reciprocal_i32(%arg0: tensor<8xi32>) -> tensor<8xi32> {
|
||||
%0 = "tf.Reciprocal"(%arg0) : (tensor<8xi32>) -> tensor<8xi32>
|
||||
return %0: tensor<8xi32>
|
||||
@ -1356,16 +1336,6 @@ func @reciprocal_i32(%arg0: tensor<8xi32>) -> tensor<8xi32> {
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
func @reciprocal_i64(%arg0: tensor<8xi64>) -> tensor<8xi64> {
|
||||
%0 = "tf.Reciprocal"(%arg0) : (tensor<8xi64>) -> tensor<8xi64>
|
||||
return %0: tensor<8xi64>
|
||||
|
||||
// CHECK-LABEL: reciprocal_i64
|
||||
// CHECK: %cst = constant dense<1> : tensor<i64>
|
||||
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<i64>, tensor<8xi64>) -> tensor<8xi64>
|
||||
// CHECK: return
|
||||
}
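All of the remaining reciprocal tests document the same lowering, tf.Reciprocal(x) -> tfl.div(1, x), with the constant 1 typed to match the operand. As a plain scalar illustration (not part of the test):

    // Scalar sketch of the lowering checked above: reciprocal(x) == div(1, x).
    float ReciprocalF32(float x) { return 1.0f / x; }         // 1 / 4.0f == 0.25f
    long long ReciprocalI64(long long x) { return 1LL / x; }  // 1 / 4 == 0 (integer division)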
|
||||
|
||||
func @random_uniform() -> tensor<2x5xf32> {
|
||||
%0 = "tf.Const"() { value = dense<[2, 5]> : tensor<2xi32> } : () -> tensor<2xi32>
|
||||
%1 = "tf.RandomUniform"(%0) { seed = 1, seed2 = 0} : (tensor<2xi32>) -> tensor<2x5xf32>
|
||||
|
@ -14,7 +14,6 @@
|
||||
// CHECK-NEXT: Tensor 2 val kTfLiteFloat32 kTfLiteMmapRo 4 bytes
|
||||
// CHECK-NEXT: Tensor 3 tfl.while kTfLiteInt32 kTfLiteArenaRw 4 bytes
|
||||
// CHECK-NEXT: Tensor 4 result kTfLiteFloat32 kTfLiteArenaRw 4 bytes
|
||||
// CHECK-NEXT: Tensor 5 tfl.while:2 kTfLiteInt32 kTfLiteArenaRw 4 bytes
|
||||
|
||||
// Verify while was not folded away:
|
||||
// ------------------------------------
|
||||
|
@ -16,7 +16,7 @@ func @testCos(tensor<? x f32>) -> tensor<? x f32> {
|
||||
// test invalid Cos input
|
||||
func @testCosWithWrongInputType(tensor<?xi32>) -> tensor<?xi32> {
|
||||
^bb0(%arg0: tensor<?xi32>):
|
||||
// expected-error @+1 {{tfl.cos' op operand #0 must be tensor of floating-point values}}
|
||||
// expected-error @+1 {{tfl.cos' op operand #0 must be tensor of 32-bit float values}}
|
||||
%0 = "tfl.cos"(%arg0): (tensor<?xi32>) -> tensor<?xi32>
|
||||
return %0#0 : tensor<?xi32>
|
||||
}
|
||||
@ -103,7 +103,7 @@ func @testAddN(tensor<? x f32>, tensor<? x f32>, tensor<? x f32>) -> tensor<? x
|
||||
// test invalid AddN
|
||||
func @testAddNWrongOperandResultType(tensor<? x f16>, tensor<? x f16>, tensor<? x f16>) -> tensor<? x f16> {
|
||||
^bb0(%arg0: tensor<? x f16>, %arg1: tensor<? x f16>, %arg2: tensor<? x f16>):
|
||||
// expected-error @+1 {{'tfl.add_n' op operand #0 must be tensor of 32-bit float or 32-bit signless integer or QI16 type or QUI16 type values}}
|
||||
// expected-error @+1 {{'tfl.add_n' op operand #0 must be tensor of 32-bit float or 32-bit signless integer}}
|
||||
%0 = "tfl.add_n"(%arg0, %arg1, %arg2): (tensor<? x f16>, tensor<? x f16>, tensor<? x f16>) -> tensor<? x f16>
|
||||
return %0 : tensor<? x f16>
|
||||
}
|
||||
@ -147,7 +147,7 @@ func @testSin(tensor<? x f32>) -> tensor<? x f32> {
|
||||
// test invalid Sin input
|
||||
func @testSinWithWrongInputType(tensor<?xi32>) -> tensor<?xi32> {
|
||||
^bb0(%arg0: tensor<?xi32>):
|
||||
// expected-error @+1 {{tfl.sin' op operand #0 must be tensor of floating-point values}}
|
||||
// expected-error @+1 {{tfl.sin' op operand #0 must be tensor of 32-bit float values}}
|
||||
%0 = "tfl.sin"(%arg0): (tensor<?xi32>) -> tensor<?xi32>
|
||||
return %0#0 : tensor<?xi32>
|
||||
}
|
||||
@ -157,7 +157,7 @@ func @testSinWithWrongInputType(tensor<?xi32>) -> tensor<?xi32> {
|
||||
// test invalid Sqrt input
|
||||
func @testSqrtWithWrongInputType(tensor<? x i32>) -> tensor<? x i32> {
|
||||
^bb0(%arg0: tensor<? x i32>):
|
||||
// expected-error @+1 {{tfl.sqrt' op operand #0 must be tensor of floating-point values}}
|
||||
// expected-error @+1 {{tfl.sqrt' op operand #0 must be tensor of 32-bit float values}}
|
||||
%0 = "tfl.sqrt"(%arg0): (tensor<? x i32>) -> tensor<? x i32>
|
||||
return %0#0 : tensor<? x i32>
|
||||
}
|
||||
@ -167,7 +167,7 @@ func @testSqrtWithWrongInputType(tensor<? x i32>) -> tensor<? x i32> {
|
||||
// test invalid Square input
|
||||
func @testSquareWithWrongInputType(tensor<? x i32>) -> tensor<? x i32> {
|
||||
^bb0(%arg0: tensor<? x i32>):
|
||||
// expected-error @+1 {{tfl.square' op operand #0 must be tensor of floating-point values}}
|
||||
// expected-error @+1 {{tfl.square' op operand #0 must be tensor of 32-bit float values}}
|
||||
%0 = "tfl.square"(%arg0): (tensor<? x i32>) -> tensor<? x i32>
|
||||
return %0#0 : tensor<? x i32>
|
||||
}
|
||||
@ -425,7 +425,7 @@ func @testTileF32(%arg0: tensor<4 x 1 x f32>, %arg1: tensor<4 x i32>) -> tensor<
|
||||
// -----
|
||||
|
||||
func @testEluI32(%arg0: tensor<? x i32>) -> tensor<? x i32> {
|
||||
// expected-error @+1 {{operand #0 must be tensor of floating-point values}}
|
||||
// expected-error @+1 {{op operand #0 must be tensor of 32-bit float values}}
|
||||
%0 = "tfl.elu"(%arg0): (tensor<? x i32>) -> tensor<? x i32>
|
||||
return %0#0 : tensor<? x i32>
|
||||
}
|
||||
@ -531,11 +531,11 @@ func @testMaxUnpooling2D(%arg0: tensor<1x8x8x128xf32>, %arg1: tensor<1x8x8x128xf
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testLogistic
|
||||
func @testLogistic(tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16> {
|
||||
^bb0(%arg0: tensor<1x2x3x4x5xbf16>):
|
||||
func @testLogistic(tensor<1x2x3x4x5xf32>) -> tensor<1x2x3x4x5xf32> {
|
||||
^bb0(%arg0: tensor<1x2x3x4x5xf32>):
|
||||
// CHECK: "tfl.logistic"(%arg0)
|
||||
%0 = "tfl.logistic"(%arg0): (tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16>
|
||||
return %0 : tensor<1x2x3x4x5xbf16>
|
||||
%0 = "tfl.logistic"(%arg0): (tensor<1x2x3x4x5xf32>) -> tensor<1x2x3x4x5xf32>
|
||||
return %0 : tensor<1x2x3x4x5xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
@ -543,7 +543,7 @@ func @testLogistic(tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16> {
|
||||
// test invalid Logistic input
|
||||
func @testLogisticWithWrongInputType(tensor<?xi32>) -> tensor<?xi32> {
|
||||
^bb0(%arg0: tensor<?xi32>):
|
||||
// expected-error @+1 {{tfl.logistic' op operand #0 must be tensor of floating-point or QI8 type or QUI8 type or QI16 type or QUI16 type values}}
|
||||
// expected-error @+1 {{tfl.logistic' op operand #0 must be tensor of 32-bit float or QI8 type or QUI8 type or QI16 type or QUI16 type values}}
|
||||
%0 = "tfl.logistic"(%arg0): (tensor<?xi32>) -> tensor<?xi32>
|
||||
return %0#0 : tensor<?xi32>
|
||||
}
|
||||
|
@ -400,8 +400,8 @@ func @FuseFullyConnectedReshapeAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor<
|
||||
// FOLD: return %[[fc]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @NotReorderReshapeAddIfNotBroadcastable
|
||||
func @NotReorderReshapeAddIfNotBroadcastable(%arg0: tensor<40x10x4xf32>) -> tensor<40x40xf32> {
|
||||
// CHECK-LABEL: @NotReorderReshapeAddIfNotBroadcastableAfter
|
||||
func @NotReorderReshapeAddIfNotBroadcastableAfter(%arg0: tensor<40x10x4xf32>) -> tensor<40x40xf32> {
|
||||
%cst = constant dense<2.0> : tensor<40xf32>
|
||||
%shape = constant dense<[40, 40]> : tensor<2xi32>
|
||||
%1 = "tfl.reshape"(%arg0, %shape) : (tensor<40x10x4xf32>, tensor<2xi32>) -> tensor<40x40xf32>
|
||||
@ -413,6 +413,19 @@ func @NotReorderReshapeAddIfNotBroadcastable(%arg0: tensor<40x10x4xf32>) -> tens
|
||||
// CHECK: return %[[rs2]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @NotReorderReshapeAddIfNotTailingDimAfter
|
||||
func @NotReorderReshapeAddIfNotTailingDimAfter(%arg0: tensor<1x30x1x96xf32>) -> tensor<1x30x96xf32> {
|
||||
%cst = constant dense<2.0> : tensor<1x30x96xf32>
|
||||
%shape = constant dense<[1, 30, 96]> : tensor<3xi32>
|
||||
%1 = "tfl.reshape"(%arg0, %shape) : (tensor<1x30x1x96xf32>, tensor<3xi32>) -> tensor<1x30x96xf32>
|
||||
%2 = "tfl.add"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<1x30x96xf32>, tensor<1x30x96xf32>) -> tensor<1x30x96xf32>
|
||||
return %2 : tensor<1x30x96xf32>
|
||||
|
||||
// CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0
|
||||
// CHECK: %[[rs2:.*]] = tfl.add %[[rs1]]
|
||||
// CHECK: return %[[rs2]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @NotReorderReshapeAddIfNotTailingDim
|
||||
func @NotReorderReshapeAddIfNotTailingDim(%arg0: tensor<40x40x1xf32>) -> tensor<40x40xf32> {
|
||||
%cst = constant dense<2.0> : tensor<1x40xf32>
|
||||
|
@ -15,14 +15,14 @@ func @while() -> tensor<1xf32>
|
||||
%0:2 = "tfl.while"(%cst0, %cst1) ( {
|
||||
^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>):
|
||||
// CHECK: call @WhileOp_cond
|
||||
// CHECK-SAME: (tensor<*xi32>, tensor<*xf32>, tensor<i32>)
|
||||
// CHECK-SAME: (tensor<*xi32>, tensor<*xf32>)
|
||||
%cst_0 = constant dense<0> : tensor<i32>
|
||||
%1 = "tfl.greater"(%arg2, %cst_0) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
|
||||
"tfl.yield"(%1) : (tensor<i1>) -> ()
|
||||
}, {
|
||||
^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>):
|
||||
// CHECK: call @WhileOp_body
|
||||
// CHECK-SAME: (tensor<*xi32>, tensor<*xf32>, tensor<i32>)
|
||||
// CHECK-SAME: (tensor<*xi32>, tensor<*xf32>)
|
||||
%1 = "tfl.sub"(%arg2, %cst) {fused_activation_function = "NONE"} :
|
||||
(tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
||||
%2 = tfl.add %arg3, %arg3 {fused_activation_function = "NONE"} : tensor<*xf32>
|
||||
@ -40,8 +40,7 @@ func @while() -> tensor<1xf32>
|
||||
|
||||
// CHECK-LABEL: func @while2
|
||||
// Verify that while body/cond with implicitly captured values results in changing while operands/results.
|
||||
func @while2() -> tensor<1xf32> attributes {tf.entry_function = {outputs = "result"}} {
|
||||
%cst = constant dense<1> : tensor<i32>
|
||||
func @while2(%cst : tensor<i32>) -> tensor<1xf32> attributes {tf.entry_function = {outputs = "result"}} {
|
||||
%cst_0 = constant dense<5> : tensor<i32>
|
||||
%cst_1 = constant dense<3.000000e+00> : tensor<1xf32>
|
||||
// Verifies 3 operands post outlining.
|
||||
@ -148,22 +147,21 @@ func @rnn(%arg0: tensor<4x4x3xf32> {tf.device = "/device:CPU:0"}) -> tensor<4x?x
|
||||
// CHECK: tfl.while
|
||||
// CHECK: tfl.yield
|
||||
// CHECK-SAME: (tensor<i1>) -> ()
|
||||
// CHECK: [[VAL_41:%.*]]:18 =
|
||||
// CHECK: [[VAL_30:%.*]]:7 =
|
||||
// CHECK: call @tfl.while_body
|
||||
// CHECK: tfl.yield
|
||||
// CHECK-SAME: (tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<i32>, tensor<i32>, tensor<4x4x3xf32>, tensor<8x5xf32>, tensor<8xf32>, tensor<f32>, tensor<1xi32>, tensor<i32>, tensor<1xi32>, tensor<1xi32>, tensor<i32>, tensor<1xi32>) -> ()
|
||||
// CHECK-SAME: (tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<4x4x3xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: func @tfl.while_cond(
|
||||
// CHECK-SAME: [[VAL_56:%.*]]: tensor<i32>, [[VAL_57:%.*]]: tensor<i32>, [[VAL_58:%.*]]: tensor<*xf32>, [[VAL_59:%.*]]: tensor<4x2xf32>, [[VAL_60:%.*]]: tensor<4x2xf32>, [[VAL_61:%.*]]: tensor<*xf32>, [[VAL_62:%.*]]: tensor<i32>, [[VAL_63:%.*]]: tensor<i32>, [[VAL_64:%.*]]: tensor<4x4x3xf32>, [[VAL_65:%.*]]: tensor<8x5xf32>, [[VAL_66:%.*]]: tensor<8xf32>, [[VAL_67:%.*]]: tensor<f32>, [[VAL_68:%.*]]: tensor<1xi32>, [[VAL_69:%.*]]: tensor<i32>, [[VAL_70:%.*]]: tensor<1xi32>, [[VAL_71:%.*]]: tensor<1xi32>, [[VAL_72:%.*]]: tensor<i32>, [[VAL_73:%.*]]: tensor<1xi32>) -> tensor<i1> attributes {sym_visibility = "private"} {
|
||||
// CHECK-SAME: [[VAL_35:%.*]]: tensor<i32>, [[VAL_36:%.*]]: tensor<i32>, [[VAL_37:%.*]]: tensor<*xf32>, [[VAL_38:%.*]]: tensor<4x2xf32>, [[VAL_39:%.*]]: tensor<4x2xf32>, [[VAL_40:%.*]]: tensor<*xf32>, [[VAL_41:%.*]]: tensor<4x4x3xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
|
||||
// CHECK: return
|
||||
// CHECK-SAME: tensor<i1>
|
||||
// CHECK: }
|
||||
|
||||
// CHECK-LABEL: func @tfl.while_body(
|
||||
// CHECK-SAME: [[VAL_77:%.*]]: tensor<i32>, [[VAL_78:%.*]]: tensor<i32>, [[VAL_79:%.*]]: tensor<*xf32>, [[VAL_80:%.*]]: tensor<4x2xf32>, [[VAL_81:%.*]]: tensor<4x2xf32>, [[VAL_82:%.*]]: tensor<*xf32>, [[VAL_83:%.*]]: tensor<i32>, [[VAL_84:%.*]]: tensor<i32>, [[VAL_85:%.*]]: tensor<4x4x3xf32>, [[VAL_86:%.*]]: tensor<8x5xf32>, [[VAL_87:%.*]]: tensor<8xf32>, [[VAL_88:%.*]]: tensor<f32>, [[VAL_89:%.*]]: tensor<1xi32>, [[VAL_90:%.*]]: tensor<i32>, [[VAL_91:%.*]]: tensor<1xi32>, [[VAL_92:%.*]]: tensor<1xi32>, [[VAL_93:%.*]]: tensor<i32>, [[VAL_94:%.*]]: tensor<1xi32>) -> (tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<i32>, tensor<i32>, tensor<4x4x3xf32>, tensor<8x5xf32>, tensor<8xf32>, tensor<f32>, tensor<1xi32>, tensor<i32>, tensor<1xi32>, tensor<1xi32>, tensor<i32>, tensor<1xi32>) attributes {sym_visibility = "private"} {
|
||||
// CHECK: [[VAL_123:%.*]] = "tfl.cast"
|
||||
// CHECK-SAME: [[VAL_46:%.*]]: tensor<i32>, [[VAL_47:%.*]]: tensor<i32>, [[VAL_48:%.*]]: tensor<*xf32>, [[VAL_49:%.*]]: tensor<4x2xf32>, [[VAL_50:%.*]]: tensor<4x2xf32>, [[VAL_51:%.*]]: tensor<*xf32>, [[VAL_52:%.*]]: tensor<4x4x3xf32>) -> (tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<4x4x3xf32>) attributes {sym_visibility = "private"} {
|
||||
// CHECK: [[VAL_91:%.*]] = "tfl.cast"
|
||||
// CHECK: return
|
||||
// CHECK-SAME: [[VAL_123]], [[VAL_83]], [[VAL_84]], [[VAL_85]], [[VAL_86]], [[VAL_87]], [[VAL_88]], [[VAL_89]], [[VAL_90]], [[VAL_91]], [[VAL_92]], [[VAL_93]], [[VAL_94]] : tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<i32>, tensor<i32>, tensor<4x4x3xf32>, tensor<8x5xf32>, tensor<8xf32>, tensor<f32>, tensor<1xi32>, tensor<i32>, tensor<1xi32>, tensor<1xi32>, tensor<i32>, tensor<1xi32>
|
||||
// CHECK-SAME: [[VAL_91]], [[VAL_52]] : tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<4x4x3xf32>
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
|
||||
|
@ -79,10 +79,12 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
pass_config.quant_specs.serialized_quant_stats));
}

if (pass_config.lower_tensor_list_ops) {
// TODO(haoliang): Add this pass by default.
pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass());
}
// The conversion pipeline has to follow this order:
// 1) Try to convert ophint nodes first, if present (e.g. ophint lstm).
// 2) Saved-model-related optimizations, like decomposing resource ops.
// 3) Convert composite functions like lstm/rnns, along with proper function
//    inlining & dce.
// 4) Lower static tensor list pass.

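Put together with the reordered additions below, the intended sequence reads roughly as follows (a sketch using the pass-creation helpers named in this file, not the literal code from the diff):

    // 3) Composite-function handling, then inlining and DCE.
    pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());
    pass_manager->addPass(mlir::createInlinerPass());
    pass_manager->addPass(mlir::createSymbolDCEPass());
    // 4) Lower static tensor lists (only when requested by the config).
    if (pass_config.lower_tensor_list_ops) {
      pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass());
    }
    // Then saved-model global-tensor optimization.
    pass_manager->addPass(
        mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());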
// The ophint extractions happen before lots of other passes:
|
||||
// The assumption of ophint-extraction is each ophinted region is a black-box
|
||||
@ -105,29 +107,40 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
pass_manager->addNestedPass<mlir::FuncOp>(
|
||||
mlir::TFDevice::CreateDecomposeResourceOpsPass());
|
||||
|
||||
// This pass does resource analysis of saved model global tensors and marks
|
||||
// those deemed read-only as immutable.
|
||||
pass_manager->addPass(
|
||||
mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());
|
||||
// Note:
// We need to fuse composite ops before the LowerStaticTensorList pass, since
// TensorFlow lists are not yet supported by that pass.
|
||||
// Enable fusing composite ops that can be lowered to built-in TFLite ops.
|
||||
if (pass_config.emit_builtin_tflite_ops) {
|
||||
pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());
|
||||
}
|
||||
|
||||
// This pass marks non-exported functions as having 'private' symbol
// visibility.
|
||||
pass_manager->addPass(
|
||||
mlir::tf_saved_model::
|
||||
CreateMarkFunctionVisibilityUsingSavedModelLinkagePass());
|
||||
|
||||
// Enable fusing composite ops that can be lowered to built-in TFLite ops.
|
||||
if (pass_config.emit_builtin_tflite_ops) {
|
||||
pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());
|
||||
pass_manager->addPass(mlir::createInlinerPass());
|
||||
pass_manager->addPass(mlir::createSymbolDCEPass());
|
||||
|
||||
if (pass_config.lower_tensor_list_ops) {
|
||||
// TODO(haoliang): Add this pass by default.
|
||||
pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass());
|
||||
}
|
||||
|
||||
// This pass does resource analysis of saved model global tensors and marks
|
||||
// those deemed read-only as immutable.
|
||||
pass_manager->addPass(
|
||||
mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());
|
||||
|
||||
// Legalize while early to allow further constant folding.
|
||||
// TODO(jpienaar): This may not actually matter as we do canonicalization
|
||||
// after the legalize below, for now it needs to be below the above passes
|
||||
// that work on TF dialect and before inliner so that the function calls in
|
||||
// body and cond are inlined for optimization.
|
||||
if (pass_config.legalize_tf_while) {
|
||||
pass_manager->addNestedPass<mlir::FuncOp>(
|
||||
mlir::TFL::CreateLegalizeTFWhilePass());
|
||||
pass_manager->addPass(mlir::TFL::CreateLegalizeTFWhilePass());
|
||||
}
|
||||
|
||||
// Add function inlining pass. Both TF and TFLite dialects are opted into
|
||||
|
@ -31,44 +31,52 @@ namespace {
|
||||
|
||||
// Legalize TF While to TFL While with calls to the original functions from the
|
||||
// cond and body regions.
|
||||
struct LegalizeWhile : public FunctionPass<LegalizeWhile> {
|
||||
void runOnFunction() override {
|
||||
auto func = getFunction();
|
||||
// Convert all TF WhileOps inside the function body to TFL While ops.
|
||||
func.getBody().walk([](TF::WhileOp while_op) {
|
||||
Operation* op = while_op.getOperation();
|
||||
// Create new TFL While op that will be used to replace TF While op.
|
||||
auto new_op = OpBuilder(op).create<TFL::WhileOp>(
|
||||
op->getLoc(), op->getResultTypes(), op->getOperands(),
|
||||
while_op.is_stateless());
|
||||
// Insert call to the given function into the 'region'.
|
||||
auto create_region_with_call = [&while_op](FlatSymbolRefAttr symbol,
|
||||
Region& region) {
|
||||
OpBuilder builder(region);
|
||||
auto block = builder.createBlock(®ion);
|
||||
SmallVector<Value, 4> new_operands;
|
||||
auto func = while_op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
|
||||
symbol.getValue());
|
||||
for (Type t : func.getType().getInputs())
|
||||
new_operands.push_back(block->addArgument(t));
|
||||
auto call =
|
||||
builder.create<CallOp>(while_op.getLoc(), symbol,
|
||||
func.getType().getResults(), new_operands);
|
||||
builder.create<YieldOp>(while_op.getLoc(), call.getResults());
|
||||
};
|
||||
create_region_with_call(while_op.condAttr(), new_op.cond());
|
||||
create_region_with_call(while_op.bodyAttr(), new_op.body());
|
||||
struct LegalizeWhile : public ModulePass<LegalizeWhile> {
|
||||
void RunOnFunction(FuncOp func);
|
||||
|
||||
op->replaceAllUsesWith(new_op.getResults());
|
||||
op->erase();
|
||||
});
|
||||
void runOnModule() override {
|
||||
for (auto op : getModule().getOps<FuncOp>()) RunOnFunction(op);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void RunOnWhile(TF::WhileOp while_op) {
|
||||
Operation* op = while_op.getOperation();
|
||||
// Create new TFL While op that will be used to replace TF While op.
|
||||
auto new_op = OpBuilder(op).create<TFL::WhileOp>(
|
||||
op->getLoc(), op->getResultTypes(), op->getOperands(),
|
||||
while_op.is_stateless());
|
||||
// Insert call to the given function into the 'region'.
|
||||
auto create_region_with_call = [&while_op](FlatSymbolRefAttr symbol,
|
||||
Region& region) {
|
||||
OpBuilder builder(region);
|
||||
auto block = builder.createBlock(®ion);
|
||||
SmallVector<Value, 4> new_operands;
|
||||
auto func = while_op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
|
||||
symbol.getValue());
|
||||
for (Type t : func.getType().getInputs())
|
||||
new_operands.push_back(block->addArgument(t));
|
||||
auto call = builder.create<CallOp>(
|
||||
while_op.getLoc(), symbol, func.getType().getResults(), new_operands);
|
||||
builder.create<YieldOp>(while_op.getLoc(), call.getResults());
|
||||
// Mark old function as private so that it can be DCE'd if not called.
|
||||
func.setVisibility(SymbolTable::Visibility::Private);
|
||||
};
|
||||
create_region_with_call(while_op.condAttr(), new_op.cond());
|
||||
create_region_with_call(while_op.bodyAttr(), new_op.body());
|
||||
|
||||
op->replaceAllUsesWith(new_op.getResults());
|
||||
op->erase();
|
||||
}
|
||||
|
||||
void LegalizeWhile::RunOnFunction(FuncOp func) {
|
||||
// Convert all TF WhileOps inside the function body to TFL While ops.
|
||||
func.getBody().walk([](TF::WhileOp while_op) { RunOnWhile(while_op); });
|
||||
}
|
||||
|
||||
// Creates an instance of the TensorFlow While to TFLite While pass.
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFWhilePass() {
|
||||
std::unique_ptr<OpPassBase<ModuleOp>> CreateLegalizeTFWhilePass() {
|
||||
return std::make_unique<LegalizeWhile>();
|
||||
}
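Since the pass is now module-scoped, a caller schedules it on a ModuleOp pass manager; a minimal hedged sketch (assuming an existing MLIRContext `context` and a loaded ModuleOp `module`):

    mlir::PassManager pm(&context);
    pm.addPass(mlir::TFL::CreateLegalizeTFWhilePass());  // now an OpPassBase<ModuleOp> pass
    if (mlir::failed(pm.run(module))) {
      // handle the failure, e.g. emit a diagnostic
    }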
|
||||
|
||||
|
@ -862,6 +862,11 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
|
||||
target.addLegalOp<ConstantOp>();
|
||||
target.addLegalOp<FuncOp>();
|
||||
target.addLegalOp<ReturnOp>();
|
||||
// Register fused LSTM/RNN ops as legal.
|
||||
target.addLegalOp<TFL::LSTMOp>();
|
||||
target.addLegalOp<TFL::UnidirectionalSequenceLSTMOp>();
|
||||
target.addLegalOp<TFL::UnidirectionalSequenceRNNOp>();
|
||||
target.addLegalOp<TFL::BidirectionalSequenceLSTMOp>();
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
populateWithGenerated(context, &patterns);
|
||||
|
@ -339,11 +339,15 @@ foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in {
|
||||
(ConstantOp:$rhs $a), TFL_AF_None),
|
||||
(TFL_ReshapeOp (BinaryOp $input, $rhs, TFL_AF_None), $shape),
|
||||
// The broadcasting of "BinaryOp" only happens in the lower
|
||||
// dimensions, and the higher dimensions are same.
|
||||
// dimensions, and the higher dimensions are the same, so we know the
|
||||
// result and input of the "BinaryOp" in the source pattern have
|
||||
// the same shape, which is defined by `shape`.
|
||||
[(IsTailOfShape $rhs, $lhs),
|
||||
(HasOneUse $lhs),
|
||||
// the two operands of the binary op is broadcastable
|
||||
(AreBroadcastableTypes $rhs, $input)]>;
|
||||
// The result of the new "BinaryOp" will have the same shape as
|
||||
// `input`. In other words, the shape of the `Reshape` op is not
|
||||
// changed after the transformation.
|
||||
(IsTailOfShape $rhs, $input)]>;
|
||||
}
|
||||
|
||||
foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
|
||||
@ -363,11 +367,15 @@ foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
|
||||
(ConstantOp:$rhs $a)),
|
||||
(TFL_ReshapeOp (BinaryOp $input, $rhs), $shape),
|
||||
// The broadcasting of "BinaryOp" only happens in the lower
|
||||
// dimensions, and the higher dimensions are same.
|
||||
// dimensions, and the higher dimensions are the same, so we know the
|
||||
// result and input of the "BinaryOp" in the source pattern have
|
||||
// the same shape, which is defined by `shape`.
|
||||
[(IsTailOfShape $rhs, $lhs),
|
||||
(HasOneUse $lhs),
|
||||
// the two operands of the binary op is broadcastable
|
||||
(AreBroadcastableTypes $rhs, $input)]>;
|
||||
// The result of the new "BinaryOp" will have the same shape as
|
||||
// `input`. In other words, the shape of the `Reshape` op is not
|
||||
// changed after the transformation.
|
||||
(IsTailOfShape $rhs, $input)]>;
|
||||
}
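A concrete instance of why the constraint moved from broadcastability to IsTailOfShape against `$input` (numbers taken from the new @NotReorderReshapeAddIfNotTailingDimAfter test earlier in this diff; assuming IsTailOfShape checks that the first type's shape is a suffix of the second's):

    input : tensor<1x30x1x96xf32>   // operand of the reshape
    rhs   : tensor<1x30x96xf32>     // constant added after the reshape
    rank-3 suffix of input's shape is [30, 1, 96], which != [1, 30, 96]
    => IsTailOfShape(rhs, input) fails, so the add is not reordered before the
       reshape (the reordered add would otherwise change the result shape).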
|
||||
|
||||
// Returns shape of a ranked tensor.
|
||||
|
@ -86,7 +86,7 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateDefaultQuantParamsPass(
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreateDenseToSparsePass();
|
||||
|
||||
// Creates function pass to legalize TF While to TFL While.
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFWhilePass();
|
||||
std::unique_ptr<OpPassBase<ModuleOp>> CreateLegalizeTFWhilePass();
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
|
||||
std::unique_ptr<OpPassBase<ModuleOp>> CreateWhileOutlinePass();
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "mlir/IR/Identifier.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Matchers.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/SymbolTable.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
@ -89,24 +90,20 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
|
||||
llvm::SmallVector<Region*, 2> regions{&while_op.cond(), &while_op.body()};
|
||||
for (auto it : llvm::enumerate(regions)) {
|
||||
llvm::SetVector<Value> region_extern_values;
|
||||
Value const_none = nullptr;
|
||||
getUsedValuesDefinedAbove(*it.value(), region_extern_values);
|
||||
|
||||
// Sink down none type constants into the functions.
|
||||
// Sink down constants into the functions.
|
||||
for (auto extern_value : region_extern_values) {
|
||||
if (!extern_value.getType().isa<NoneType>()) {
|
||||
if (!matchPattern(extern_value, m_Constant())) {
|
||||
extern_values.insert(extern_value);
|
||||
continue;
|
||||
}
|
||||
if (!const_none) {
|
||||
// Add constant at start of region.
|
||||
auto const_builder =
|
||||
OpBuilder(&it.value()->front(), it.value()->front().begin());
|
||||
const_none = const_builder.create<ConstantOp>(
|
||||
while_op.getLoc(), extern_value.getType(),
|
||||
const_builder.getUnitAttr());
|
||||
}
|
||||
replaceAllUsesInRegionWith(extern_value, const_none, *it.value());
|
||||
// Add constant at start of region.
|
||||
auto const_builder =
|
||||
OpBuilder(&it.value()->front(), it.value()->front().begin());
|
||||
auto const_value = const_builder.clone(*extern_value.getDefiningOp());
|
||||
replaceAllUsesInRegionWith(extern_value, const_value->getResult(0),
|
||||
*it.value());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -21,6 +21,10 @@ limitations under the License.
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
||||
@ -86,6 +90,17 @@ MlirOptimizationPassRegistry& MlirOptimizationPassRegistry::Global() {
|
||||
return *global;
|
||||
}
|
||||
|
||||
static void RegisterDialects() {
|
||||
static bool init_once = []() {
|
||||
mlir::registerDialect<mlir::StandardOpsDialect>();
|
||||
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
|
||||
mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
|
||||
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
|
||||
return true;
|
||||
}();
|
||||
(void)init_once;
|
||||
}
|
||||
|
||||
Status MlirFunctionOptimizationPass::Run(
|
||||
const DeviceSet& device_set, const ConfigProto& config_proto,
|
||||
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
|
||||
@ -107,6 +122,7 @@ Status MlirFunctionOptimizationPass::Run(
|
||||
<< "(registered " << registry_->passes().size() << " passes)";
|
||||
|
||||
GraphDebugInfo debug_info;
|
||||
RegisterDialects();
|
||||
mlir::MLIRContext context;
|
||||
GraphImportConfig import_config;
|
||||
import_config.graph_as_function = true;
|
||||
@ -178,9 +194,12 @@ Status MlirV1CompatGraphOptimizationPass::Run(
|
||||
<< "(registered" << registry_->passes().size() << " passes)";
|
||||
|
||||
GraphDebugInfo debug_info;
|
||||
RegisterDialects();
|
||||
mlir::MLIRContext context;
|
||||
GraphImportConfig import_config;
|
||||
import_config.upgrade_legacy = true;
|
||||
// TODO(b/150959075): Running functionalization before TPU cluster formation
|
||||
// is not semantics preserving and should be disabled for now.
|
||||
import_config.upgrade_legacy = false;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto module_ref,
|
||||
ConvertGraphToMlir(**options.graph, debug_info, *options.flib_def,
|
||||
|
@ -10,6 +10,7 @@ package_group(
|
||||
name = "friends",
|
||||
includes = ["//third_party/mlir:subpackages"],
|
||||
packages = [
|
||||
"//learning/brain/experimental/tfrt/...",
|
||||
"//learning/pathways/data_parallel/tf2xla/...",
|
||||
"//tensorflow/compiler/...",
|
||||
"//tensorflow/lite/experimental/tf_runtime/...",
|
||||
@ -986,7 +987,6 @@ cc_library(
|
||||
tf_cc_test(
|
||||
name = "error_util_test",
|
||||
srcs = ["utils/error_util_test.cc"],
|
||||
tags = ["no_rocm"],
|
||||
deps = [
|
||||
":error_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
@ -1065,6 +1065,7 @@ COMPILE_MLIR_UTIL_DEPS = [
|
||||
":tensorflow_dialect_registration",
|
||||
":tensorflow_passes",
|
||||
":translate_utils",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Parser",
|
||||
@ -1083,6 +1084,8 @@ COMPILE_MLIR_UTIL_DEPS = [
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
]
|
||||
|
||||
# Prefer to link 'compile_mlir_util' library that also links necessary
|
||||
|
@ -1143,6 +1143,29 @@ that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = "Clips tensor values to a specified min and max.";
|
||||
|
||||
let description = [{
|
||||
Given a tensor `t`, this operation returns a tensor of the same type and
|
||||
shape as `t` with its values clipped to `clip_value_min` and `clip_value_max`.
|
||||
Any values less than `clip_value_min` are set to `clip_value_min`. Any values
|
||||
greater than `clip_value_max` are set to `clip_value_max`.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$t,
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$clip_value_min,
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$clip_value_max
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
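The description above pins down the element-wise semantics; a self-contained scalar illustration (not taken from the TensorFlow sources):

    #include <algorithm>

    // output[i] = min(max(t[i], clip_value_min[i]), clip_value_max[i])
    float ClipByValue(float t, float clip_value_min, float clip_value_max) {
      return std::min(std::max(t, clip_value_min), clip_value_max);
    }
    // e.g. ClipByValue(7.0f, 0.0f, 5.0f) == 5.0f and ClipByValue(-2.0f, 0.0f, 5.0f) == 0.0f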
|
||||
|
||||
def TF_ComplexOp : TF_Op<"Complex", [NoSideEffect, ResultsBroadcastableShape]> {
|
||||
let summary = "Converts two real numbers to a complex number.";
|
||||
|
||||
@ -2592,7 +2615,7 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors.
|
||||
TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<3>;
|
||||
}
|
||||
|
||||
def TF_FusedBatchNormGradV3Op : TF_Op<"FusedBatchNormGradV3", [NoSideEffect]> {
|
||||
def TF_FusedBatchNormGradV3Op : TF_Op<"FusedBatchNormGradV3", [NoSideEffect, TF_LayoutSensitiveInterface]> {
|
||||
let summary = "Gradient for batch normalization.";
|
||||
|
||||
let description = [{
|
||||
@ -2623,9 +2646,17 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors.
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<3>;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// TF_LayoutSensitiveInterface:
|
||||
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0, 1}; }
|
||||
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
|
||||
StringRef GetOptimalLayout(const RuntimeDevices& devices);
|
||||
LogicalResult UpdateDataFormat(StringRef data_format);
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_FusedBatchNormV3Op : TF_Op<"FusedBatchNormV3", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
|
||||
def TF_FusedBatchNormV3Op : TF_Op<"FusedBatchNormV3", [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> {
|
||||
let summary = "Batch normalization.";
|
||||
|
||||
let description = [{
|
||||
@ -2662,6 +2693,10 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors.
|
||||
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
|
||||
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
|
||||
LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
|
||||
|
||||
// TF_LayoutSensitiveInterface:
|
||||
StringRef GetOptimalLayout(const RuntimeDevices& devices);
|
||||
LogicalResult UpdateDataFormat(StringRef data_format);
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -1520,6 +1520,34 @@ static LogicalResult Verify(FillOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FusedBatchNormGradOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// TODO(b/150954845): Add benchmarks to verify that layout preference didn't
|
||||
// change in the latest GPU generations.
|
||||
|
||||
LogicalResult FusedBatchNormGradV3Op::UpdateDataFormat(StringRef data_format) {
|
||||
return ::mlir::TF::UpdateDataFormat(data_format, this);
|
||||
}
|
||||
|
||||
StringRef FusedBatchNormGradV3Op::GetOptimalLayout(
|
||||
const RuntimeDevices &devices) {
|
||||
// Keep the current data format if no GPUs are available or if explicit
// placement does not allow using the GPU for this operation.
|
||||
if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
|
||||
return data_format();
|
||||
|
||||
// For the f16 data type on devices with Tensor Core support, the NHWC data
// format is up to ~2x faster.
|
||||
auto x_ty = x().getType().cast<TensorType>();
|
||||
const bool is_f16 = x_ty.getElementType().isF16();
|
||||
if (is_f16 && CanUseTensorCores(devices)) return "NHWC";
|
||||
|
||||
// For all other data types prefer NCHW.
|
||||
return "NCHW";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FusedBatchNormOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1547,9 +1575,36 @@ static LogicalResult Verify(FusedBatchNormOp op) {
|
||||
|
||||
LogicalResult FusedBatchNormV3Op::FoldOperandsPermutation(
|
||||
ArrayRef<int64_t> permutation) {
|
||||
// FusedBatchNorm in training mode is a layout sensitive operation, and should
// have already been assigned an optimal data format.
|
||||
if (is_training()) return failure();
|
||||
|
||||
return ::mlir::TF::FoldOperandsPermutation(permutation, this);
|
||||
}
|
||||
|
||||
LogicalResult FusedBatchNormV3Op::UpdateDataFormat(StringRef data_format) {
|
||||
return ::mlir::TF::UpdateDataFormat(data_format, this);
|
||||
}
|
||||
|
||||
StringRef FusedBatchNormV3Op::GetOptimalLayout(const RuntimeDevices &devices) {
|
||||
// In inference mode FusedBatchNorm is not sensitive to data layout.
|
||||
if (!is_training()) return data_format();
|
||||
|
||||
// Keep the current data format if no GPUs are available or if explicit
// placement does not allow using the GPU for this operation.
|
||||
if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
|
||||
return data_format();
|
||||
|
||||
// For the f16 data type on devices with Tensor Core support, the NHWC data
// format is up to ~2x faster.
|
||||
auto x_ty = x().getType().cast<TensorType>();
|
||||
const bool is_f16 = x_ty.getElementType().isF16();
|
||||
if (is_f16 && CanUseTensorCores(devices)) return "NHWC";
|
||||
|
||||
// For all other data types prefer NCHW.
|
||||
return "NCHW";
|
||||
}
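Condensing the decision implemented just above (FusedBatchNormV3Op::GetOptimalLayout) into one self-contained helper for readability (an illustration, not the TensorFlow implementation):

    #include <string>

    std::string ChooseLayout(bool is_training, bool can_use_gpu, bool is_f16,
                             bool has_tensor_cores, std::string current) {
      if (!is_training) return current;   // inference is not layout sensitive here
      if (!can_use_gpu) return current;   // no GPU available or allowed
      if (is_f16 && has_tensor_cores) return "NHWC";  // Tensor Cores prefer NHWC
      return "NCHW";                      // all other data types prefer NCHW
    }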
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GatherV2Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -5,6 +5,10 @@ package(licenses = ["notice"])
|
||||
glob_lit_tests(
|
||||
data = [":test_utilities"],
|
||||
driver = "@llvm-project//mlir:run_lit.sh",
|
||||
tags_override = {
|
||||
"optimize.mlir": ["no_rocm"],
|
||||
"tf_optimize.mlir": ["no_rocm"],
|
||||
},
|
||||
test_file_exts = ["mlir"],
|
||||
)
|
||||
|
||||
|
@ -7,6 +7,37 @@ func @einsum_basic(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor
|
||||
// CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
|
||||
}
|
||||
|
||||
func @einsum_broadcast(%arg0: tensor<3x4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<3x4x6xf32> {
|
||||
%0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ijk,km->ijm"}: (tensor<3x4x5xf32>, tensor<5x6xf32>) -> tensor<3x4x6xf32>
|
||||
return %0 : tensor<3x4x6xf32>
|
||||
// CHECK-LABEL: einsum_broadcast
|
||||
// CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<3x4x5xf32>, tensor<5x6xf32>) -> tensor<3x4x6xf32>
|
||||
}
|
||||
|
||||
func @einsum_reducesum(%arg0: tensor<2x5x7xf32>, %arg1: tensor<5x2xf32>) -> tensor<5x7xf32> {
|
||||
%0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "lbh,bl->bh"}: (tensor<2x5x7xf32>, tensor<5x2xf32>) -> tensor<5x7xf32>
|
||||
return %0 : tensor<5x7xf32>
|
||||
// CHECK-LABEL: einsum_reducesum
|
||||
// CHECK: %[[cst:.*]] = constant dense<[1, 2, 0]> : tensor<3xi32>
|
||||
// CHECK: %[[cst_1:.*]] = constant dense<[5, 1, 2]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_2:.*]] = constant dense<2> : tensor<1xi32>
|
||||
// CHECK: %[[v0:.*]] = "tf.Transpose"(%arg0, %[[cst]]) : (tensor<2x5x7xf32>, tensor<3xi32>) -> tensor<5x7x2xf32>
|
||||
// CHECK: %[[v1:.*]] = "tf.Reshape"(%arg1, %[[cst_1]]) : (tensor<5x2xf32>, tensor<3xi64>) -> tensor<5x1x2xf32>
|
||||
// CHECK: %[[v2:.*]] = "tf.Mul"(%[[v0]], %[[v1]]) : (tensor<5x7x2xf32>, tensor<5x1x2xf32>) -> tensor<5x7x2xf32>
|
||||
// CHECK: "tf.Sum"(%[[v2]], %[[cst_2]]) {keep_dims = false} : (tensor<5x7x2xf32>, tensor<1xi32>) -> tensor<5x7xf32>
|
||||
}
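The CHECK lines above spell out the lowering of the contraction lbh,bl->bh (here b=5, h=7, l=2): transpose the lhs so the contracted axis is last, reshape the rhs so it broadcasts, multiply, then sum out l. As a shape-level summary (illustration only):

    lhs  (l,b,h) = (2,5,7)  --transpose [1,2,0]-->  (b,h,l) = (5,7,2)
    rhs  (b,l)   = (5,2)    --reshape [5,1,2]---->  (b,1,l) = (5,1,2)
    mul  (5,7,2) * (5,1,2)  --broadcast---------->  (5,7,2)
    sum over l (axis 2)     --------------------->  (b,h)   = (5,7)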
|
||||
func @einsum_transpose_matmul(%arg0: tensor<2x5x7xf32>, %arg1: tensor<5x3x2xf32>) -> tensor<5x3x7xf32> {
|
||||
%0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "lbh,bkl->bkh"}: (tensor<2x5x7xf32>, tensor<5x3x2xf32>) -> tensor<5x3x7xf32>
|
||||
return %0 : tensor<5x3x7xf32>
|
||||
// CHECK-LABEL: einsum_transpose_matmul
|
||||
// CHECK: %[[cst:.*]] = constant dense<[1, 2, 0]> : tensor<3xi32>
|
||||
// CHECK: %[[cst_0:.*]] = constant dense<[0, 2, 1]> : tensor<3xi32>
|
||||
// CHECK: %[[v0:.*]] = "tf.Transpose"(%arg0, %[[cst]]) : (tensor<2x5x7xf32>, tensor<3xi32>) -> tensor<5x7x2xf32>
|
||||
// CHECK: %[[v1:.*]] = "tf.Transpose"(%arg1, %[[cst_0]]) : (tensor<5x3x2xf32>, tensor<3xi32>) -> tensor<5x2x3xf32>
|
||||
// CHECK: %[[v2:.*]] = "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) {adj_x = false, adj_y = false} : (tensor<5x7x2xf32>, tensor<5x2x3xf32>) -> tensor<5x7x3xf32>
|
||||
// CHECK: %[[v3:.*]] = "tf.Transpose"(%[[v2]], %[[cst_0]]) : (tensor<5x7x3xf32>, tensor<3xi32>) -> tensor<5x3x7xf32>
|
||||
}
|
||||
|
||||
func @einsum_4D(%arg0: tensor<2x5x7x3xf32>, %arg1: tensor<2x4x7x3xf32>) -> tensor<2x7x5x4xf32> {
|
||||
%0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "bfnh,btnh->bnft"}: (tensor<2x5x7x3xf32>, tensor<2x4x7x3xf32>) -> tensor<2x7x5x4xf32>
|
||||
return %0 : tensor<2x7x5x4xf32>
|
||||
|
@ -38,6 +38,24 @@ func @test_single_branch_direct_t() -> tensor<i32> {
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: test_single_branch_direct_arg_f
|
||||
// CHECK: Switch
|
||||
// CHECK: tf.AddV2
|
||||
func @test_single_branch_direct_arg_f(%pred : tensor<i1>) -> tensor<i32> {
|
||||
%cst_0 = constant dense<10> : tensor<i32>
|
||||
%cst_1 = constant dense<1> : tensor<i32>
|
||||
%0 = tf_executor.graph {
|
||||
%7:3 = tf_executor.Switch %cst_0, %pred : tensor<i32>
|
||||
%8:2 = tf_executor.island {
|
||||
%12 = "tf.AddV2"(%7#1, %cst_1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
tf_executor.yield %12 : tensor<i32>
|
||||
}
|
||||
%11:3 = tf_executor.Merge %7#0, %8#0 : tensor<i32> {N = 2 : i64}
|
||||
tf_executor.fetch %11#0 : tensor<i32>
|
||||
}
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
|
||||
// pred ? x + 1 : x - 1
|
||||
// CHECK-LABEL: ControlFlowTest.testCond_1f
|
||||
// CHECK-NOT: Switch
|
||||
|
@ -1,14 +1,15 @@
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s --dump-input-on-failure
|
||||
|
||||
# Verify that the data_format attribute is pulled from the default value in the
# registry when not present in the GraphDef.
|
||||
# CHECK: tf.Conv2D
|
||||
# CHECK-SAME: data_format = "NHWC"
|
||||
|
||||
# Verify that we can also pull some attributes that are needed to be able to
|
||||
# create a Graph in memory, like `T`.
|
||||
# Verify that we don't import derived attributes as these will be added only on
|
||||
# export.
|
||||
# CHECK: tf.MaxPool
|
||||
# CHECK-SAME: T = f32
|
||||
# CHECK-NOT: T = f32
|
||||
# CHECK-SAME: : (tensor<?x112x112x32xf32>) -> tensor<?x56x56x32xf32>
|
||||
|
||||
node {
|
||||
name: "input"
|
||||
|
@ -3,6 +3,21 @@
|
||||
node {
|
||||
name: "unnamed"
|
||||
op: "foo"
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_disable_call_shape_inference"
|
||||
value {
|
||||
b: true
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
}
|
||||
}
|
||||
@ -39,7 +54,7 @@ versions {
|
||||
# Verify that functions from the library are properly imported.
|
||||
|
||||
# CHECK-LABEL: func @main() {
|
||||
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @foo0}
|
||||
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = true, f = @foo0}
|
||||
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @bar0}
|
||||
|
||||
# CHECK-LABEL: func @foo0() {
|
||||
|
@ -105,7 +105,6 @@ versions {
|
||||
|
||||
# CHECK: func @main
|
||||
# CHECK: "tf.PartitionedCall"()
|
||||
# CHECK-SAME: Tout = ["tfdtype$DT_UINT8"]
|
||||
# CHECK-SAME: f = @[[FUNCTION:[A-Za-z0-9_]*]]
|
||||
# CHECK: func @[[FUNCTION]]() -> tensor<ui8>
|
||||
# CHECK: return {{.*}} : tensor<ui8>
|
||||
|
@ -0,0 +1,22 @@
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s
|
||||
|
||||
node {
|
||||
name: "input0"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
versions {
|
||||
producer: 29
|
||||
min_consumer: 12
|
||||
}
|
||||
|
||||
# Verify that functions from the library are properly imported.
|
||||
|
||||
# CHECK-LABEL: func @main() {
|
||||
# CHECK: "tf.Placeholder"
|
||||
# CHECK-NOT: _output_shapes
|
@ -1,7 +1,6 @@
|
||||
# RUN: tf-mlir-translate -graphdef-to-splatted-mlir %s -o - | FileCheck %s
|
||||
|
||||
# CHECK: tf.Const
|
||||
# CHECK-SAME: _output_shapes = ["tfshape$dim { size: 3 }"]
|
||||
# CHECK-SAME: value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B2073697A653A2033207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C30303022"> : tensor<3x!tf.string>
|
||||
|
||||
node {
|
||||
|
@ -62,4 +62,106 @@ func @transposeConv2DBackpropInput_f16(
|
||||
return %0 : tensor<1x28x28x64xf16>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @transposeFusedBatchNormV3_f32
|
||||
func @transposeFusedBatchNormV3_f32(
|
||||
%arg0: tensor<1x28x28x64xf32>,
|
||||
%arg1: tensor<64xf32>
|
||||
) -> tensor<1x28x28x64xf32> {
|
||||
|
||||
// CHECK: "tf.FusedBatchNormV3"
|
||||
// CHECK-SAME: (%[[X_TRANSPOSE:[0-9]*]], %arg1, %arg1, %arg1, %arg1)
|
||||
// CHECK-SAME: data_format = "NCHW"
|
||||
%y, %batch_mean, %batch_var, %reserve_1, %reserve_2, %reserve_3
|
||||
= "tf.FusedBatchNormV3"(%arg0, %arg1, %arg1, %arg1, %arg1)
|
||||
{
|
||||
data_format = "NHWC",
|
||||
epsilon = 1.001 : f32,
|
||||
exponential_avg_factor = 1.0 : f32,
|
||||
is_training = true
|
||||
}
|
||||
: (tensor<1x28x28x64xf32>, tensor<64xf32>, tensor<64xf32>,
|
||||
tensor<64xf32>, tensor<64xf32>)
|
||||
      -> (tensor<1x28x28x64xf32>, tensor<64xf32>, tensor<64xf32>,
          tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)

  return %y : tensor<1x28x28x64xf32>
}

// CHECK-LABEL: func @transposeFusedBatchNormV3_f16
func @transposeFusedBatchNormV3_f16(
  %arg0: tensor<1x28x28x64xf16>,
  %arg1: tensor<64xf32>
) -> tensor<1x28x28x64xf16> {

  // CHECK: "tf.FusedBatchNormV3"
  // CHECK-SAME: (%[[X_TRANSPOSE:[0-9]*]], %arg1, %arg1, %arg1, %arg1)
  // CHECK-SAME: data_format = "NCHW"
  %y, %batch_mean, %batch_var, %reserve_1, %reserve_2, %reserve_3
    = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg1, %arg1, %arg1)
      {
        data_format = "NHWC",
        epsilon = 1.001 : f32,
        exponential_avg_factor = 1.0 : f32,
        is_training = true
      }
       : (tensor<1x28x28x64xf16>, tensor<64xf32>, tensor<64xf32>,
          tensor<64xf32>, tensor<64xf32>)
      -> (tensor<1x28x28x64xf16>, tensor<64xf32>, tensor<64xf32>,
          tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)

  return %y : tensor<1x28x28x64xf16>
}

// CHECK-LABEL: func @transposeFusedBatchNormGradV3_f32
func @transposeFusedBatchNormGradV3_f32(
  %arg0: tensor<1x28x28x64xf32>,
  %arg1: tensor<1x28x28x64xf32>,
  %arg2: tensor<64xf32>
) -> tensor<1x28x28x64xf32> {

  // CHECK: "tf.FusedBatchNormGradV3"
  // CHECK-SAME: (%[[X_TRANSPOSE:[0-9]*]], %[[Y_TRANSPOSE:[0-9]*]],
  // CHECK-SAME: data_format = "NCHW"
  %x_backprop, %scale_backprop, %offset_backprop, %reserve_1, %reserve_2
    = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg2, %arg2, %arg2)
      {
        data_format = "NHWC",
        epsilon = 1.001 : f32,
        exponential_avg_factor = 1.0 : f32,
        is_training = true
      }
       : (tensor<1x28x28x64xf32>, tensor<1x28x28x64xf32>,
          tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
      -> (tensor<1x28x28x64xf32>,
          tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)

  return %x_backprop : tensor<1x28x28x64xf32>
}

// CHECK-LABEL: func @transposeFusedBatchNormGradV3_f16
func @transposeFusedBatchNormGradV3_f16(
  %arg0: tensor<1x28x28x64xf16>,
  %arg1: tensor<1x28x28x64xf16>,
  %arg2: tensor<64xf32>
) -> tensor<1x28x28x64xf16> {

  // CHECK: "tf.FusedBatchNormGradV3"
  // CHECK-SAME: (%[[X_TRANSPOSE:[0-9]*]], %[[Y_TRANSPOSE:[0-9]*]],
  // CHECK-SAME: data_format = "NCHW"
  %x_backprop, %scale_backprop, %offset_backprop, %reserve_1, %reserve_2
    = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg2, %arg2, %arg2)
      {
        data_format = "NHWC",
        epsilon = 1.001 : f32,
        exponential_avg_factor = 1.0 : f32,
        is_training = true
      }
       : (tensor<1x28x28x64xf16>, tensor<1x28x28x64xf16>,
          tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
      -> (tensor<1x28x28x64xf16>,
          tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)

  return %x_backprop : tensor<1x28x28x64xf16>
}

}

@ -143,4 +143,106 @@ func @transposeConv2DBackpropInput_f16(
  return %0 : tensor<1x64x28x28xf16>
}

// CHECK-LABEL: func @transposeFusedBatchNormV3_f32
func @transposeFusedBatchNormV3_f32(
  %arg0: tensor<1x28x28x64xf32>,
  %arg1: tensor<64xf32>
) -> tensor<1x28x28x64xf32> {

  // CHECK: "tf.FusedBatchNormV3"
  // CHECK-SAME: (%[[X_TRANSPOSE:[0-9]*]], %arg1, %arg1, %arg1, %arg1)
  // CHECK-SAME: data_format = "NCHW"
  %y, %batch_mean, %batch_var, %reserve_1, %reserve_2, %reserve_3
    = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg1, %arg1, %arg1)
      {
        data_format = "NHWC",
        epsilon = 1.001 : f32,
        exponential_avg_factor = 1.0 : f32,
        is_training = true
      }
       : (tensor<1x28x28x64xf32>, tensor<64xf32>, tensor<64xf32>,
          tensor<64xf32>, tensor<64xf32>)
      -> (tensor<1x28x28x64xf32>, tensor<64xf32>, tensor<64xf32>,
          tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)

  return %y : tensor<1x28x28x64xf32>
}

// CHECK-LABEL: func @transposeFusedBatchNormV3_f16
func @transposeFusedBatchNormV3_f16(
  %arg0: tensor<1x64x28x28xf16>,
  %arg1: tensor<64xf32>
) -> tensor<1x64x28x28xf16> {

  // CHECK: "tf.FusedBatchNormV3"
  // CHECK-SAME: (%[[X_TRANSPOSE:[0-9]*]], %arg1, %arg1, %arg1, %arg1)
  // CHECK-SAME: data_format = "NHWC"
  %y, %batch_mean, %batch_var, %reserve_1, %reserve_2, %reserve_3
    = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg1, %arg1, %arg1)
      {
        data_format = "NCHW",
        epsilon = 1.001 : f32,
        exponential_avg_factor = 1.0 : f32,
        is_training = true
      }
       : (tensor<1x64x28x28xf16>, tensor<64xf32>, tensor<64xf32>,
          tensor<64xf32>, tensor<64xf32>)
      -> (tensor<1x64x28x28xf16>, tensor<64xf32>, tensor<64xf32>,
          tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)

  return %y : tensor<1x64x28x28xf16>
}

// CHECK-LABEL: func @transposeFusedBatchNormGradV3_f32
func @transposeFusedBatchNormGradV3_f32(
  %arg0: tensor<1x28x28x64xf32>,
  %arg1: tensor<1x28x28x64xf32>,
  %arg2: tensor<64xf32>
) -> tensor<1x28x28x64xf32> {

  // CHECK: "tf.FusedBatchNormGradV3"
  // CHECK-SAME: (%[[X_TRANSPOSE:[0-9]*]], %[[Y_TRANSPOSE:[0-9]*]],
  // CHECK-SAME: data_format = "NCHW"
  %x_backprop, %scale_backprop, %offset_backprop, %reserve_1, %reserve_2
    = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg2, %arg2, %arg2)
      {
        data_format = "NHWC",
        epsilon = 1.001 : f32,
        exponential_avg_factor = 1.0 : f32,
        is_training = true
      }
       : (tensor<1x28x28x64xf32>, tensor<1x28x28x64xf32>,
          tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
      -> (tensor<1x28x28x64xf32>,
          tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)

  return %x_backprop : tensor<1x28x28x64xf32>
}

// CHECK-LABEL: func @transposeFusedBatchNormGradV3_f16
func @transposeFusedBatchNormGradV3_f16(
  %arg0: tensor<1x64x28x28xf16>,
  %arg1: tensor<1x64x28x28xf16>,
  %arg2: tensor<64xf32>
) -> tensor<1x64x28x28xf16> {

  // CHECK: "tf.FusedBatchNormGradV3"
  // CHECK-SAME: (%[[X_TRANSPOSE:[0-9]*]], %[[Y_TRANSPOSE:[0-9]*]],
  // CHECK-SAME: data_format = "NHWC"
  %x_backprop, %scale_backprop, %offset_backprop, %reserve_1, %reserve_2
    = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg2, %arg2, %arg2)
      {
        data_format = "NCHW",
        epsilon = 1.001 : f32,
        exponential_avg_factor = 1.0 : f32,
        is_training = true
      }
       : (tensor<1x64x28x28xf16>, tensor<1x64x28x28xf16>,
          tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
      -> (tensor<1x64x28x28xf16>,
          tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)

  return %x_backprop : tensor<1x64x28x28xf16>
}

}

@ -146,3 +146,82 @@ func @transposeConv2DBackpropInput(

  return %0 : tensor<1x32x32x3xf32>
}

// CHECK-LABEL: func @transposeFusedBatchNormV3
func @transposeFusedBatchNormV3(
  %arg0: tensor<1x28x28x64xf32>,
  %arg1: tensor<64xf32>
) -> tensor<1x28x28x64xf32> {

  // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"()
  // CHECK-SAME: {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
  // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])

  // CHECK: "tf.FusedBatchNormV3"
  // CHECK-SAME: (%[[ARG_TRANSPOSE]], %arg1, %arg1, %arg1, %arg1)
  // CHECK-SAME: data_format = "NCHW"
  // CHECK-SAME: (tensor<1x64x28x28xf32>, tensor<64xf32>,
  // CHECK-SAME: -> (tensor<1x64x28x28xf32>, tensor<64xf32>,

  // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"()
  // CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi32>}
  // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%y, %[[RES_PERM]])
  // CHECK: return %[[RES_TRANSPOSE]]

  %y, %batch_mean, %batch_var, %reserve_1, %reserve_2, %reserve_3
    = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg1, %arg1, %arg1)
      {
        data_format = "NHWC",
        epsilon = 1.001 : f32,
        exponential_avg_factor = 1.0 : f32,
        is_training = true
      }
       : (tensor<1x28x28x64xf32>, tensor<64xf32>, tensor<64xf32>,
          tensor<64xf32>, tensor<64xf32>)
      -> (tensor<1x28x28x64xf32>, tensor<64xf32>, tensor<64xf32>,
          tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)

  return %y : tensor<1x28x28x64xf32>
}

// CHECK-LABEL: func @transposeFusedBatchNormGradV3
func @transposeFusedBatchNormGradV3(
  %arg0: tensor<1x28x28x64xf32>,
  %arg1: tensor<1x28x28x64xf32>,
  %arg2: tensor<64xf32>
) -> tensor<1x28x28x64xf32> {

  // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"()
  // CHECK-SAME: {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}

  // CHECK: %[[ARG0_TPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
  // CHECK: %[[ARG1_TPOSE:[0-9]*]] = "tf.Transpose"(%arg1, %[[ARG_PERM]])

  // CHECK: "tf.FusedBatchNormGradV3"
  // CHECK-SAME: (%[[ARG0_TPOSE]], %[[ARG1_TPOSE]], %arg2, %arg2, %arg2, %arg2)
  // CHECK-SAME: data_format = "NCHW"
  // CHECK-SAME: (tensor<1x64x28x28xf32>, tensor<1x64x28x28xf32>,
  // CHECK-SAME: -> (tensor<1x64x28x28xf32>,

  // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"()
  // CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi32>}

  // CHECK: %[[RES_TPOSE:[0-9]*]] = "tf.Transpose"
  // CHECK-SAME: (%x_backprop, %[[RES_PERM]])
  // CHECK: return %[[RES_TPOSE]]

  %x_backprop, %scale_backprop, %offset_backprop, %reserve_1, %reserve_2
    = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg2, %arg2, %arg2)
      {
        data_format = "NHWC",
        epsilon = 1.001 : f32,
        exponential_avg_factor = 1.0 : f32,
        is_training = true
      }
       : (tensor<1x28x28x64xf32>, tensor<1x28x28x64xf32>,
          tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
      -> (tensor<1x28x28x64xf32>,
          tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)

  return %x_backprop : tensor<1x28x28x64xf32>
}

@ -33,3 +33,40 @@ func @transposeConv2D(%input: tensor<1x3x32x32xf32>, %filter: tensor<1x1x3x8xf32

  return %0 : tensor<1x8x32x32xf32>
}

// CHECK-LABEL: func @transposeFusedBatchNormV3
func @transposeFusedBatchNormV3(
  %arg0: tensor<1x64x28x28xf32>,
  %arg1: tensor<64xf32>
) -> tensor<1x64x28x28xf32> {

  // CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"()
  // CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi32>}
  // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])

  // CHECK: "tf.FusedBatchNormV3"
  // CHECK-SAME: (%[[ARG_TRANSPOSE]], %arg1, %arg1, %arg1, %arg1)
  // CHECK-SAME: data_format = "NHWC"
  // CHECK-SAME: (tensor<1x28x28x64xf32>, tensor<64xf32>,
  // CHECK-SAME: -> (tensor<1x28x28x64xf32>, tensor<64xf32>,

  // CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"()
  // CHECK-SAME: {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
  // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%y, %[[RES_PERM]])
  // CHECK: return %[[RES_TRANSPOSE]]

  %y, %batch_mean, %batch_var, %reserve_1, %reserve_2, %reserve_3
    = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg1, %arg1, %arg1)
      {
        data_format = "NCHW",
        epsilon = 1.001 : f32,
        exponential_avg_factor = 1.0 : f32,
        is_training = true
      }
       : (tensor<1x64x28x28xf32>, tensor<64xf32>, tensor<64xf32>,
          tensor<64xf32>, tensor<64xf32>)
      -> (tensor<1x64x28x28xf32>, tensor<64xf32>, tensor<64xf32>,
          tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)

  return %y : tensor<1x64x28x28xf32>
}
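Editor's note: the permutations asserted throughout these layout tests are the standard NHWC/NCHW conversions, [0, 3, 1, 2] (NHWC to NCHW) and [0, 2, 3, 1] (NCHW to NHWC). A minimal NumPy sketch, illustrative only and not part of this change, showing the two permutations are inverses:

import numpy as np

x_nhwc = np.random.rand(1, 28, 28, 64).astype(np.float32)
# NHWC -> NCHW, the ARG_PERM constant checked above.
x_nchw = np.transpose(x_nhwc, (0, 3, 1, 2))   # shape (1, 64, 28, 28)
# NCHW -> NHWC, the RES_PERM constant checked above.
x_back = np.transpose(x_nchw, (0, 2, 3, 1))   # shape (1, 28, 28, 64)
assert np.array_equal(x_nhwc, x_back)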
@ -6,7 +6,7 @@ func @main() {
// CHECK-NEXT: name: "predicate"
// CHECK-NEXT: op: "Const"
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "dtype"
// CHECK: key: "dtype"
// CHECK-NEXT: value {
// CHECK-NEXT: type: DT_INT32
// CHECK-NEXT: }

@ -22,7 +22,7 @@ func @main() {
// CHECK-NEXT: name: "tf.Const"
// CHECK-NEXT: op: "Const"
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "dtype"
// CHECK: key: "dtype"
// CHECK-NEXT: value {
// CHECK-NEXT: type: DT_INT32
// CHECK-NEXT: }
@ -43,8 +43,8 @@ func @main() {
// CHECK-NEXT: name: "tf.Empty"
// CHECK-NEXT: op: "Empty"
// CHECK-NEXT: input: "tf.Const:output:0"
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "dtype"
// CHECK-NEXT: attr {
// CHECK: key: "dtype"
// CHECK-NEXT: value {
// CHECK-NEXT: type: DT_FLOAT
// CHECK-NEXT: }
@ -0,0 +1,64 @@
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s

func @main(%arg0: tensor<10xi32>) -> tensor<10xi32>
attributes {tf.entry_function = {inputs = "input0", outputs = "Placeholder"}} {
  %graph = tf_executor.graph {
    %0:2 = tf_executor.island wraps "tf.Placeholder.input"(%arg0) {device = "", dtype = "tfdtype$DT_INT32", shape = "tfshape$dim { size: 10 }"} : (tensor<10xi32>) -> tensor<10xi32>
    tf_executor.fetch %0 : tensor<10xi32>
  }
  return %graph : tensor<10xi32>
}

// CHECK: node {
// CHECK-NEXT: name: "Placeholder"
// CHECK-NEXT: op: "Placeholder"
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "_output_shapes"
// CHECK-NEXT: value {
// CHECK-NEXT: list {
// CHECK-NEXT: shape {
// CHECK-NEXT: dim {
// CHECK-NEXT: size: 10
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "dtype"
// CHECK-NEXT: value {
// CHECK-NEXT: type: DT_INT32
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "shape"
// CHECK-NEXT: value {
// CHECK-NEXT: shape {
// CHECK-NEXT: dim {
// CHECK-NEXT: size: 10
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: experimental_debug_info {
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: node {
// CHECK-NEXT: name: "main"
// CHECK-NEXT: op: "_Retval"
// CHECK-NEXT: input: "Placeholder"
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "T"
// CHECK-NEXT: value {
// CHECK-NEXT: type: DT_INT32
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "index"
// CHECK-NEXT: value {
// CHECK-NEXT: i: 0
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: library {
// CHECK-NEXT: }
@ -34,7 +34,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "dense_shapes"
// CHECK: key: "dense_shapes"
// CHECK-NEXT: value {
// CHECK-NEXT: list {
// CHECK-NEXT: shape {

@ -2,7 +2,7 @@

// CHECK: attr {
// CHECK-NEXT: key: "dtypes"
// CHECK: key: "dtypes"
// CHECK-NEXT: value {
// CHECK-NEXT: list {
// CHECK-NEXT: type: DT_INT32

@ -5,7 +5,7 @@ func @main() {
// CHECK-NEXT: name: "Empty/shape"
// CHECK-NEXT: op: "Const"
// CHECK: attr {
// CHECK-NEXT: key: "dtype"
// CHECK: key: "dtype"
// CHECK-NEXT: value {
// CHECK-NEXT: type: DT_INT32
// CHECK-NEXT: }

@ -4,8 +4,8 @@ func @main() {
^bb0:
// CHECK: name: "node_name"
// CHECK-NEXT: op: "Const"
// CHECK: attr {
// CHECK-NEXT: key: "dtype"
// CHECK-NEXT: attr {
// CHECK: key: "dtype"
// CHECK-NEXT: value {
// CHECK-NEXT: type: DT_INT32
// CHECK-NEXT: }

@ -6,7 +6,7 @@ func @main() {
// CHECK-NEXT: name: "Const"
// CHECK-NEXT: op: "Const"
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "dtype"
// CHECK: key: "dtype"
// CHECK-NEXT: value {
// CHECK-NEXT: type: DT_FLOAT
// CHECK-NEXT: }
@ -31,7 +31,7 @@ func @main() {
// CHECK-NEXT: name: "foo"
// CHECK-NEXT: op: "foo"
// CHECK-NEXT: input: "Const"
// CHECK-NEXT: experimental_debug_info {
// CHECK: experimental_debug_info {
// CHECK-NEXT: }
// CHECK-NEXT: }
%1:2 = tf_executor.island wraps "tf.foo"(%0#0) {device = ""} : (tensor<f32>) -> tensor<*xf32> loc("foo")

@ -18,6 +18,11 @@ func @foo0(%arg0: tensor<*xi32>) -> tensor<*xi32> {
// CHECK: node {
// CHECK: name: "tf.LegacyCall"
// CHECK-NEXT: op: "foo0"
// CHECK: attr {
// CHECK-NEXT: key: "_output_shapes"
// CHECK-NEXT: value {
// CHECK-NEXT: list {
// CHECK-NEXT: shape {

// CHECK: library {
// CHECK-NEXT: function {

@ -14,8 +14,8 @@ attributes {tf.entry_function = {inputs = "input0,input1", outputs = "Add"}} {
// CHECK: node {
// CHECK-NEXT: name: "input0"
// CHECK-NEXT: op: "Placeholder"
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "dtype"
// CHECK-NEXT: attr {
// CHECK: key: "dtype"
// CHECK-NEXT: value {
// CHECK-NEXT: type: DT_INT32
// CHECK-NEXT: }
@ -36,8 +36,8 @@ attributes {tf.entry_function = {inputs = "input0,input1", outputs = "Add"}} {
// CHECK-NEXT: node {
// CHECK-NEXT: name: "input1"
// CHECK-NEXT: op: "Placeholder"
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "dtype"
// CHECK-NEXT: attr {
// CHECK: key: "dtype"
// CHECK-NEXT: value {
// CHECK-NEXT: type: DT_INT32
// CHECK-NEXT: }
@ -66,7 +66,7 @@ attributes {tf.entry_function = {inputs = "input0,input1", outputs = "Add"}} {
// CHECK-NEXT: type: DT_INT32
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: experimental_debug_info {
// CHECK: experimental_debug_info {
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: node {
@ -124,3 +124,36 @@ func @nested_ops(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) {
// CHECK-NEXT: tf_device.return %[[OP_C]]
// CHECK-NEXT: device = "c"
// CHECK-NEXT: tf_device.return %[[SHAPE]], %[[LAUNCH_A]], %[[LAUNCH_B]], %[[LAUNCH_C]]


// CHECK-LABEL: func @do_not_hoist_ops_with_virtual_device
// CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf32>, [[VAL_1:%.*]]: tensor<*xf32>)
func @do_not_hoist_ops_with_virtual_device(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) {
  %0:8 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<*xf32>) {devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2: i32} {
    %1 = "tf.Shape"(%ri) {device = "", T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor<*xf32>) -> tensor<?xi32>
    %2 = "tf.opA"(%1) {device = "TPU_REPLICATED_CORE_0"} : (tensor<?xi32>) -> tensor<*xi32>
    %3 = "tf_device.launch"() ( {
      %b = "tf.opB"(%1) : (tensor<?xi32>) -> tensor<*xi32>
      tf_device.return %b : tensor<*xi32>
    }) {device = "TPU_REPLICATED_CORE_0"} : () -> tensor<*xi32>
    %4 = "tf_device.launch"() ( {
      %c = "tf.opC"(%1) {device = "TPU_REPLICATED_CORE_0"} : (tensor<?xi32>) -> tensor<*xi32>
      tf_device.return %c : tensor<*xi32>
    }) {device = "c"} : () -> tensor<*xi32>
    tf_device.return %1, %2, %3, %4 : tensor<?xi32>, tensor<*xi32>, tensor<*xi32>, tensor<*xi32>
  }
  return
}

// CHECK: [[SHAPE:%.*]] = "tf.Shape"([[VAL_0]])
// CHECK: tf_device.replicate({{\[}}[[VAL_0]], [[VAL_1]]] as [[VAL_4:%.*]]: tensor<*xf32>) {devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
// CHECK: [[OP_A:%.*]] = "tf.opA"([[SHAPE]]) {device = "TPU_REPLICATED_CORE_0"} : (tensor<?xi32>) -> tensor<*xi32>
// CHECK: [[LAUNCH_B:%.*]] = "tf_device.launch"() ( {
// CHECK: [[OP_B:%.*]] = "tf.opB"([[SHAPE]]) : (tensor<?xi32>) -> tensor<*xi32>
// CHECK: tf_device.return [[OP_B]] : tensor<*xi32>
// CHECK: }) {device = "TPU_REPLICATED_CORE_0"} : () -> tensor<*xi32>
// CHECK: [[LAUNCH_C:%.*]] = "tf_device.launch"() ( {
// CHECK: [[OP_C:%.*]] = "tf.opC"([[SHAPE]]) {device = "TPU_REPLICATED_CORE_0"} : (tensor<?xi32>) -> tensor<*xi32>
// CHECK: tf_device.return [[OP_C]] : tensor<*xi32>
// CHECK: }) {device = "c"} : () -> tensor<*xi32>
// CHECK: tf_device.return [[SHAPE]], [[OP_A]], [[LAUNCH_B]], [[LAUNCH_C]]
@ -183,7 +183,7 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {

// CHECK-LABEL: func @reused_if_then_branch
// CHECK-SAME: (%arg0: tensor<*xf32>) -> tensor<*xf32>
// expected-error @+1 {{expected control flow function reused_if_then_branch to have exactly 1 use}}
// expected-warning @+1 {{expected control flow function reused_if_then_branch to have exactly 1 use}}
func @reused_if_then_branch(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK: return
  // CHECK-SAME: tensor<*xf32>
@ -192,7 +192,7 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {

// CHECK-LABEL: func @reused_if_else_branch
// CHECK-SAME: (%arg0: tensor<*xf32>) -> tensor<*xf32>
// expected-error @+1 {{expected control flow function reused_if_else_branch to have exactly 1 use}}
// expected-warning @+1 {{expected control flow function reused_if_else_branch to have exactly 1 use}}
func @reused_if_else_branch(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK: "tf.Identity"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "tf.Identity"(%arg0) : (tensor<*xf32>) -> (tensor<*xf32>)
@ -278,4 +278,23 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
func @variant_body_func(%arg0: tensor<!tf.variant<tensor<16x1xf32>>>) -> tensor<!tf.variant<tensor<16x1xf32>>> {
  return %arg0 : tensor<!tf.variant<tensor<16x1xf32>>>
}

// Test propagation from called functions to the call site.
// CHECK-LABEL: func @stateful_partitioned_call(
// CHECK-SAME: -> tensor<20xi32>
func @stateful_partitioned_call(%arg0: tensor<20xi32>) -> tensor<*xi32> {
  // CHECK: tf.PartitionedCall
  // CHECK-SAME: (tensor<20xi32>) -> tensor<20xi32>
  %0 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @a_called_func} : (tensor<20xi32>) -> (tensor<*xi32>)
  // CHECK: tf.StatefulPartitionedCall
  // CHECK-SAME: (tensor<20xi32>) -> tensor<20xi32>
  %1 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @stateful_partitioned_call_func} : (tensor<20xi32>) -> (tensor<*xi32>)
  return %0 : tensor<*xi32>
}
func @a_called_func(%arg0: tensor<?xi32>) -> (tensor<?xi32>) {
  return %arg0 : tensor<?xi32>
}
func @stateful_partitioned_call_func(%arg0: tensor<?xi32>) -> (tensor<?xi32>) {
  return %arg0 : tensor<?xi32>
}
}
@ -187,3 +187,199 @@ func @main() {
  %write3 = "tf.TensorArrayWriteV3"(%grad3#0, %index, %value, %grad3#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
  return
}

// -----

// Tests while loop with access to the tensor array defined outside and its
// gradient defined inside. The gradient creation should be moved outside.

// CHECK-LABEL: func @main
func @main() -> () {
  // CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
  %size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
  %index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
  // CHECK: %[[GVAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
  %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
  // CHECK: "tf.While"(%[[VAR]], %[[SIZE]], %[[GVAR]])
  %1:2 = "tf.While"(%ta#0, %size) {
    body = @while_body, cond = @while_cond, device = "", is_stateless = false}
      : (tensor<!tf.resource>, tensor<i32>) -> (tensor<!tf.resource>, tensor<i32>)
  // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
  // CHECK: "tf.Slice"(%[[READ]],
  %read = "tf.TensorArrayReadV3"(%1#0, %index, %ta#1) : (tensor<!tf.resource>, tensor<i32>, tensor<f32>) -> tensor<3xf32>
  return
}
// CHECK: func @while_body(%[[BARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[BARG1:.*]]: tensor<i32>, %[[BARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
func @while_body(%arg0: tensor<!tf.resource>, %arg1: tensor<i32>) -> (tensor<!tf.resource>, tensor<i32>) {
  // CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  %const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[BARG1]], %[[CONST1]])
  %sub = "tf.Sub"(%arg1, %const1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
  %elem = "tf._SomeOp"() : () -> tensor<3xf32>
  %flow = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
  // CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[BARG0]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
  // CHECK: %[[UPDATE1:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ1]],
  // CHECK: "tf.AssignVariableOp"(%[[BARG0]], %[[UPDATE1]])
  %write = "tf.TensorArrayWriteV3"(%arg0, %sub, %elem, %flow) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
  %grad:2 = "tf.TensorArrayGradV3"(%arg0, %write) {source = "a"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
  // CHECK: %[[READ2:.*]] = "tf.ReadVariableOp"(%[[BARG2]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
  // CHECK: %[[UPDATE2:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ2]],
  // CHECK: "tf.AssignVariableOp"(%[[BARG2]], %[[UPDATE2]])
  %gwrite = "tf.TensorArrayWriteV3"(%grad#0, %sub, %elem, %grad#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
  // CHECK: return %[[BARG0]], %[[SUB]], %[[BARG2]]
  return %arg0, %sub : tensor<!tf.resource>, tensor<i32>
}
// CHECK: func @while_cond(%[[CARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG1:.*]]: tensor<i32>, %[[CARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
func @while_cond(%arg0: tensor<!tf.resource>, %arg1: tensor<i32>) -> tensor<i32> {
  // CHECK-NEXT: return %[[CARG1]]
  return %arg1 : tensor<i32>
}

// -----

// Tests If op with access to the tensor array defined outside and its gradient
// defined inside. The gradient creation should be moved outside.

// CHECK-LABEL: func @main
func @main() -> () {
  // CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
  %size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
  %index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
  %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
  // CHECK: %[[COND:.*]] = "tf._SomeOp"() : () -> tensor<i1>
  %cond = "tf._SomeOp"() : () -> tensor<i1>
  // CHECK: %[[GVAR1:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
  // CHECK: %[[GVAR2:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
  // CHECK: "tf.If"(%[[COND]], %[[VAR]], %[[GVAR1]], %[[GVAR2]])
  %1 = "tf.If"(%cond, %ta#0) {
    then_branch = @then_branch, else_branch = @else_branch, device = "", is_stateless = false}
      : (tensor<i1>, tensor<!tf.resource>) -> tensor<!tf.resource>
  // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
  // CHECK: "tf.Slice"(%[[READ]],
  %read = "tf.TensorArrayReadV3"(%1, %index, %ta#1) : (tensor<!tf.resource>, tensor<i32>, tensor<f32>) -> tensor<3xf32>
  return
}
// CHECK: func @then_branch(%[[TARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[TARG1:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[TARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
func @then_branch(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
  %const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  %elem = "tf._SomeOp"() : () -> tensor<3xf32>
  %flow = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
  // CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[TARG1]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
  // CHECK: %[[UPDATE1:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ1]],
  // CHECK: "tf.AssignVariableOp"(%[[TARG1]], %[[UPDATE1]])
  %grad:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "a"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
  %gwrite = "tf.TensorArrayWriteV3"(%grad#0, %const1, %elem, %grad#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
  // CHECK: return %[[TARG0]]
  return %arg0 : tensor<!tf.resource>
}
// CHECK: func @else_branch(%[[EARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[EARG1:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[EARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
func @else_branch(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
  // CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  %const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  %elem = "tf._SomeOp"() : () -> tensor<3xf32>
  %flow = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
  // CHECK: %[[READ2:.*]] = "tf.ReadVariableOp"(%[[EARG2]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
  // CHECK: %[[UPDATE2:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ2]],
  // CHECK: "tf.AssignVariableOp"(%[[EARG2]], %[[UPDATE2]])
  %grad:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "b"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
  %gwrite = "tf.TensorArrayWriteV3"(%grad#0, %const1, %elem, %grad#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
  // CHECK: return %[[EARG0]]
  return %arg0 : tensor<!tf.resource>
}

// -----

// Tests (Stateful)PartitionedCall op with access to the tensor array defined
// outside and its gradient defined inside. The gradient creation should be
// moved outside.

// CHECK-LABEL: func @main
func @main() -> () {
  // CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
  %size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
  %index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
  %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
  // CHECK: %[[COND:.*]] = "tf._SomeOp"() : () -> tensor<i1>
  %cond = "tf._SomeOp"() : () -> tensor<i1>
  // CHECK: %[[GVAR1:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
  %grad:2 = "tf.TensorArrayGradV3"(%ta#0, %ta#1) {source = "a"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
  // CHECK: %[[GVAR2:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
  // CHECK: "tf.StatefulPartitionedCall"(%[[VAR]], %[[GVAR1]], %[[GVAR2]])
  // CHECK-SAME: f = @callee_tensorarray_decomposed
  %call = "tf.StatefulPartitionedCall"(%ta#0) {f = @callee, config = "", config_proto = "", executor_type = ""}
    : (tensor<!tf.resource>) -> tensor<!tf.resource>
  // CHECK: "tf.PartitionedCall"(%[[VAR]], %[[GVAR1]], %[[GVAR2]])
  // CHECK-SAME: f = @callee_tensorarray_decomposed
  %call2 = "tf.PartitionedCall"(%call) {f = @callee, config = "", config_proto = "", executor_type = ""}
    : (tensor<!tf.resource>) -> tensor<!tf.resource>
  // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
  // CHECK: "tf.Slice"(%[[READ]],
  %read = "tf.TensorArrayReadV3"(%call2, %index, %ta#1) : (tensor<!tf.resource>, tensor<i32>, tensor<f32>) -> tensor<3xf32>
  return
}
// CHECK-LABEL: func @callee
// CHECK-SAME: (%[[OCARG0:.*]]: tensor<!tf.resource>) -> tensor<!tf.resource>
func @callee(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
  %const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  %elem = "tf._SomeOp"() : () -> tensor<3xf32>
  %flow = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
  %grad:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "a"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
  %gwrite = "tf.TensorArrayWriteV3"(%grad#0, %const1, %elem, %grad#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
  %grad2:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "b"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
  %gwrite2 = "tf.TensorArrayWriteV3"(%grad2#0, %const1, %elem, %grad2#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
  return %arg0 : tensor<!tf.resource>
}
// CHECK: func @callee_tensorarray_decomposed(%[[CARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG1:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
// CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[CARG1]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
// CHECK: %[[UPDATE1:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ1]],
// CHECK: "tf.AssignVariableOp"(%[[CARG1]], %[[UPDATE1]])
// CHECK: %[[READ2:.*]] = "tf.ReadVariableOp"(%[[CARG2]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
// CHECK: %[[UPDATE2:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ2]],
// CHECK: "tf.AssignVariableOp"(%[[CARG2]], %[[UPDATE2]])
// CHECK: return %[[CARG0]]

// -----

// Test the pass reports failure on unknown size.

func @main(%arg0: tensor<i32>) -> () {
  // expected-error @+1 {{unknown max element count}}
  %ta:2 = "tf.TensorArrayV3"(%arg0) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
  return
}

// -----

// Test the pass reports failure on unknown shape.

func @main(%arg0: tensor<i32>) -> () {
  %size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
  // expected-error @+1 {{unknown element shape}}
  %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$unknown_rank: true", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
  return
}

// -----

// Tests that the pass reports error on ambiguous tensor array.

func @main(%arg0: tensor<i1>) -> () {
  %size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
  %ta0:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
  %ta1:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
  %if_op = "tf.If"(%arg0, %ta0#0, %ta1#0) {then_branch = @if_then, else_branch = @if_else, is_stateless = false}
    : (tensor<i1>, tensor<!tf.resource>, tensor<!tf.resource>) -> tensor<!tf.resource>
  %index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  // expected-error @+1 {{unknown tensor array}}
  %read = "tf.TensorArrayReadV3"(%if_op, %index, %ta0#1) : (tensor<!tf.resource>, tensor<i32>, tensor<f32>) -> tensor<3xf32>
  return
}
func @if_then(%arg0: tensor<!tf.resource>, %arg1: tensor<!tf.resource>) -> tensor<!tf.resource> {
  return %arg0 : tensor<!tf.resource>
}
func @if_else(%arg0: tensor<!tf.resource>, %arg1: tensor<!tf.resource>) -> tensor<!tf.resource> {
  return %arg1 : tensor<!tf.resource>
}
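Editor's note: the decomposition these tests exercise replaces a statically sized TensorArray with a local variable of shape [size] + element_shape; reads become slices and writes become dynamic-update-slices on that buffer. A rough NumPy analogue, illustrative only (the helper names here are not part of the pass):

import numpy as np

size, elem_shape = 5, (3,)
ta = np.zeros((size,) + elem_shape, np.float32)   # buffer behind tf.MlirLocalVarOp, tensor<5x3xf32>

def ta_write(buf, index, value):
    # analogue of read-variable + tf.XlaDynamicUpdateSlice + assign-variable
    out = buf.copy()
    out[index] = value
    return out

def ta_read(buf, index):
    # analogue of read-variable + tf.Slice
    return buf[index]

ta = ta_write(ta, 1, np.ones(elem_shape, np.float32))
print(ta_read(ta, 1))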
@ -15,6 +15,7 @@ limitations under the License.

#include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h"

#include <algorithm>
#include <climits>
#include <cstdint>

@ -49,6 +50,9 @@ enum EinsumEquation {
  FourDMatrixDotProd,
  ThreeDReshapeTail,
  FourDBatchMatMul,
  BroadcastMatMul,
  ReduceSum,
  TransposeMatMul,
  UnsupportedEquation
};

@ -121,6 +125,18 @@ EinsumEquation parseEquation(const std::vector<EquationToken>& eqn) {
  if (is_equal(eqn, {A, B, C, COMMA, C, D, E, ARROW, A, B, D, E})) {
    return EinsumEquation::ThreeDReshapeTail;
  }
  // BFH,HO->BFO
  if (is_equal(eqn, {A, B, C, COMMA, C, D, ARROW, A, B, D})) {
    return EinsumEquation::BroadcastMatMul;
  }
  // LBH,BL->BH
  if (is_equal(eqn, {A, B, C, COMMA, B, A, ARROW, B, C})) {
    return EinsumEquation::ReduceSum;
  }
  // LBH,BKL->BKH
  if (is_equal(eqn, {A, B, C, COMMA, B, D, A, ARROW, B, D, C})) {
    return EinsumEquation::TransposeMatMul;
  }
  return EinsumEquation::UnsupportedEquation;
}

@ -151,6 +167,28 @@ TF::TransposeOp createTransposeOp(Value value, Location loc,
                                    perm_op);
}

TF::SumOp createSumOp(Value value, Location loc,
                      llvm::ArrayRef<int32_t> redux_axes,
                      PatternRewriter* rewriter) {
  auto value_type = value.getType().cast<RankedTensorType>();
  auto shape = value_type.getShape();
  auto redux_type = RankedTensorType::get(
      {static_cast<int32_t>(redux_axes.size())}, rewriter->getIntegerType(32));
  auto redux_attr = DenseElementsAttr::get(redux_type, redux_axes);
  auto redux_op = rewriter->create<ConstantOp>(loc, redux_type, redux_attr);
  std::vector<int64_t> sum_shape(shape.size() - redux_axes.size());
  int count = 0;
  for (int i = 0; i < shape.size(); ++i) {
    if (std::find(redux_axes.begin(), redux_axes.end(), i) ==
        redux_axes.end()) {
      sum_shape[count] = shape[i];
      count++;
    }
  }
  auto sum_type = RankedTensorType::get(sum_shape, value_type.getElementType());
  return rewriter->create<TF::SumOp>(loc, sum_type, value, redux_op);
}

TF::ReshapeOp createReshapeOp(Value value, ArrayRef<int64_t> shape,
                              Type element_type, Location loc,
                              PatternRewriter* rewriter) {
@ -173,7 +211,6 @@ LogicalResult ConvertTFEinsumOp::matchAndRewrite(
  Value lhs = op.getOperand(0);
  Value rhs = op.getOperand(1);
  Location loc = op.getLoc();

  if (!lhs.getType().isa<RankedTensorType>()) {
    // LHS must be a ranked tensor type
    return failure();
@ -193,10 +230,10 @@ LogicalResult ConvertTFEinsumOp::matchAndRewrite(
    return failure();
  }

  // Currently support use cases of LHS, RHS dims = 3 or 4
  // Currently support use cases of LHS dims \in {3,4} RHS dims \in {2, 3, 4}
  const int dims_lhs = lhs_shape.size();
  const int dims_rhs = rhs_shape.size();
  if (dims_rhs < 3 || dims_rhs > 4 || dims_lhs < 3 || dims_lhs > 4) {
  if (dims_lhs < 3 || dims_lhs > 4 || dims_rhs < 2 || dims_lhs > 4) {
    return failure();
  }

@ -209,6 +246,46 @@ LogicalResult ConvertTFEinsumOp::matchAndRewrite(
    rewriter.replaceOp(op, bmm_op.getResult());
    return success();
  }
  if (einsum_eqn == EinsumEquation::BroadcastMatMul) {
    // Case "BFH,HO->BFO"
    auto bmm_op = rewriter.create<TF::BatchMatMulV2Op>(
        loc, ArrayRef<Type>{output_type}, lhs, rhs, rewriter.getBoolAttr(false),
        rewriter.getBoolAttr(false));
    rewriter.replaceOp(op, bmm_op.getResult());
    return success();
  }
  if (einsum_eqn == EinsumEquation::ReduceSum) {
    // Case "LBH,BL->BH"
    // Transpose LHS
    lhs = createTransposeOp(lhs, loc, {1, 2, 0}, &rewriter);
    // Reshape RHS
    auto rhs_element_type = rhs_type.getElementType();
    const int rhs_dim0 = rhs_shape[0];
    const int rhs_dim1 = rhs_shape[1];
    auto reshaped_rhs = createReshapeOp(rhs, {rhs_dim0, 1, rhs_dim1},
                                        rhs_element_type, loc, &rewriter);
    auto mul_op = rewriter.create<TF::MulOp>(loc, lhs, reshaped_rhs);

    auto sum_op = createSumOp(mul_op, loc, {2}, &rewriter);
    rewriter.replaceOp(op, {sum_op.getResult()});
    return success();
  }
  if (einsum_eqn == EinsumEquation::TransposeMatMul) {
    // Case "LBH,BKL->BKH"
    // Transpose LHS
    lhs = createTransposeOp(lhs, loc, {1, 2, 0}, &rewriter);
    // Transpose RHS
    rhs = createTransposeOp(rhs, loc, {0, 2, 1}, &rewriter);
    std::vector<int64_t> bmm_shape = {lhs_shape[1], lhs_shape[2], rhs_shape[1]};
    auto bmm_type = RankedTensorType::get(bmm_shape, rhs_type.getElementType());
    auto bmm_op = rewriter.create<TF::BatchMatMulV2Op>(
        loc, ArrayRef<Type>{bmm_type}, lhs, rhs, rewriter.getBoolAttr(false),
        rewriter.getBoolAttr(false));

    auto trans_bmm = createTransposeOp(bmm_op, loc, {0, 2, 1}, &rewriter);
    rewriter.replaceOp(op, {trans_bmm.getResult()});
    return success();
  }
  if (einsum_eqn == EinsumEquation::ThreeDReshapeTail) {
    // Case "BFD,DNH->BFNH"
    auto lhs_type = lhs.getType().cast<RankedTensorType>();
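Editor's note: the three new einsum cases lower to transpose/reshape plus BatchMatMulV2 or Mul+Sum. A NumPy sketch, illustrative only, of the algebraic equivalences the lowering relies on (dimension names and sizes are arbitrary):

import numpy as np

B, F, H, O, L, K = 2, 3, 4, 5, 6, 7
lhs_bfh = np.random.rand(B, F, H)
rhs_ho = np.random.rand(H, O)
lhs_lbh = np.random.rand(L, B, H)
rhs_bl = np.random.rand(B, L)
rhs_bkl = np.random.rand(B, K, L)

# BFH,HO->BFO: a plain broadcasting batch matmul.
assert np.allclose(np.einsum('bfh,ho->bfo', lhs_bfh, rhs_ho), lhs_bfh @ rhs_ho)

# LBH,BL->BH: transpose LHS to BHL, reshape RHS to Bx1xL, multiply, sum over L.
lhs_bhl = lhs_lbh.transpose(1, 2, 0)
assert np.allclose(np.einsum('lbh,bl->bh', lhs_lbh, rhs_bl),
                   (lhs_bhl * rhs_bl[:, None, :]).sum(axis=2))

# LBH,BKL->BKH: transpose LHS to BHL and RHS to BLK, batch matmul, transpose the result.
rhs_blk = rhs_bkl.transpose(0, 2, 1)
assert np.allclose(np.einsum('lbh,bkl->bkh', lhs_lbh, rhs_bkl),
                   (lhs_bhl @ rhs_blk).transpose(0, 2, 1))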
@ -67,7 +67,7 @@ class SwitchFoldPass : public mlir::FunctionPass<SwitchFoldPass> {
// Returns the defining op for a value looking through islands.
static Operation* GetDefiningOp(Value val) {
  Operation* op = val.getDefiningOp();
  auto island_op = dyn_cast<tf_executor::IslandOp>(op);
  auto island_op = dyn_cast_or_null<tf_executor::IslandOp>(op);
  if (!island_op) return op;
  auto yield_op = island_op.GetYield();
  auto index = val.cast<mlir::OpResult>().getResultNumber();
@ -84,7 +84,8 @@ static Operation* GetDefiningOp(Value val) {
static Value LookThroughIdentityOp(Value pred_val) {
  if (!pred_val) return pred_val;
  auto op = GetDefiningOp(pred_val);
  if (auto id_op = dyn_cast<TF::IdentityOp>(op)) pred_val = id_op.input();
  if (auto id_op = dyn_cast_or_null<TF::IdentityOp>(op))
    pred_val = id_op.input();
  return pred_val;
}