Merge branch 'master' into google_upstream_rocblas_complex
This commit is contained in:
commit
b64dde60e8
2
.bazelrc
2
.bazelrc
@ -279,7 +279,6 @@ build:windows --host_linkopt=/OPT:REF
|
|||||||
build:windows --linkopt=/OPT:ICF
|
build:windows --linkopt=/OPT:ICF
|
||||||
build:windows --host_linkopt=/OPT:ICF
|
build:windows --host_linkopt=/OPT:ICF
|
||||||
build:windows --experimental_strict_action_env=true
|
build:windows --experimental_strict_action_env=true
|
||||||
build:windows --incompatible_windows_native_test_wrapper
|
|
||||||
|
|
||||||
# Verbose failure logs when something goes wrong
|
# Verbose failure logs when something goes wrong
|
||||||
build:windows --verbose_failures
|
build:windows --verbose_failures
|
||||||
@ -344,6 +343,7 @@ build:rbe_linux --config=avx_linux
|
|||||||
build:rbe_linux --config=short_logs
|
build:rbe_linux --config=short_logs
|
||||||
# TODO(gunan): Check why we need this specified in rbe, but not in other builds.
|
# TODO(gunan): Check why we need this specified in rbe, but not in other builds.
|
||||||
build:rbe_linux --linkopt=-lrt
|
build:rbe_linux --linkopt=-lrt
|
||||||
|
build:rbe_linux --linkopt=-lm
|
||||||
|
|
||||||
build:rbe_cpu_linux --config=rbe_linux
|
build:rbe_cpu_linux --config=rbe_linux
|
||||||
build:rbe_cpu_linux --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain"
|
build:rbe_cpu_linux --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain"
|
||||||
|
@ -1 +1 @@
|
|||||||
1.1.0
|
1.2.1
|
||||||
|
178
RELEASE.md
178
RELEASE.md
File diff suppressed because one or more lines are too long
38
WORKSPACE
38
WORKSPACE
@ -1,11 +1,13 @@
|
|||||||
workspace(name = "org_tensorflow")
|
workspace(name = "org_tensorflow")
|
||||||
|
|
||||||
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_file")
|
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
||||||
|
load("//third_party:repo.bzl", "tf_http_archive")
|
||||||
|
|
||||||
http_archive(
|
tf_http_archive(
|
||||||
name = "io_bazel_rules_closure",
|
name = "io_bazel_rules_closure",
|
||||||
sha256 = "5b00383d08dd71f28503736db0500b6fb4dda47489ff5fc6bed42557c07c6ba9",
|
sha256 = "5b00383d08dd71f28503736db0500b6fb4dda47489ff5fc6bed42557c07c6ba9",
|
||||||
strip_prefix = "rules_closure-308b05b2419edb5c8ee0471b67a40403df940149",
|
strip_prefix = "rules_closure-308b05b2419edb5c8ee0471b67a40403df940149",
|
||||||
|
patch_file = "@org_tensorflow//third_party:rules_closure.patch",
|
||||||
urls = [
|
urls = [
|
||||||
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz",
|
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz",
|
||||||
"https://github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz", # 2019-06-13
|
"https://github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz", # 2019-06-13
|
||||||
@ -48,38 +50,6 @@ load("//third_party/toolchains/preconfig/generate:workspace.bzl",
|
|||||||
|
|
||||||
remote_config_workspace()
|
remote_config_workspace()
|
||||||
|
|
||||||
# Apple and Swift rules.
|
|
||||||
http_archive(
|
|
||||||
name = "build_bazel_rules_apple",
|
|
||||||
sha256 = "a045a436b642c70fb0c10ca84ff0fd2dcbd59cc89100d597a61e8374afafb366",
|
|
||||||
urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.18.0/rules_apple.0.18.0.tar.gz"],
|
|
||||||
) # https://github.com/bazelbuild/rules_apple/releases
|
|
||||||
http_archive(
|
|
||||||
name = "build_bazel_rules_swift",
|
|
||||||
sha256 = "18cd4df4e410b0439a4935f9ca035bd979993d42372ba79e7f2d4fafe9596ef0",
|
|
||||||
urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.12.1/rules_swift.0.12.1.tar.gz"],
|
|
||||||
) # https://github.com/bazelbuild/rules_swift/releases
|
|
||||||
http_archive(
|
|
||||||
name = "build_bazel_apple_support",
|
|
||||||
sha256 = "122ebf7fe7d1c8e938af6aeaee0efe788a3a2449ece5a8d6a428cb18d6f88033",
|
|
||||||
urls = ["https://github.com/bazelbuild/apple_support/releases/download/0.7.1/apple_support.0.7.1.tar.gz"],
|
|
||||||
) # https://github.com/bazelbuild/apple_support/releases
|
|
||||||
http_archive(
|
|
||||||
name = "bazel_skylib",
|
|
||||||
sha256 = "1dde365491125a3db70731e25658dfdd3bc5dbdfd11b840b3e987ecf043c7ca0",
|
|
||||||
urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/0.9.0/bazel-skylib.0.9.0.tar.gz"],
|
|
||||||
) # https://github.com/bazelbuild/bazel-skylib/releases
|
|
||||||
http_archive(
|
|
||||||
name = "com_github_apple_swift_swift_protobuf",
|
|
||||||
type = "zip",
|
|
||||||
strip_prefix = "swift-protobuf-1.6.0/",
|
|
||||||
urls = ["https://github.com/apple/swift-protobuf/archive/1.6.0.zip"],
|
|
||||||
) # https://github.com/apple/swift-protobuf/releases
|
|
||||||
http_file(
|
|
||||||
name = "xctestrunner",
|
|
||||||
executable = 1,
|
|
||||||
urls = ["https://github.com/google/xctestrunner/releases/download/0.2.9/ios_test_runner.par"],
|
|
||||||
) # https://github.com/google/xctestrunner/releases
|
|
||||||
# Use `swift_rules_dependencies` to fetch the toolchains. With the
|
# Use `swift_rules_dependencies` to fetch the toolchains. With the
|
||||||
# `git_repository` rules above, the following call will skip redefining them.
|
# `git_repository` rules above, the following call will skip redefining them.
|
||||||
load("@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies")
|
load("@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies")
|
||||||
|
@ -49,8 +49,8 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
|
|||||||
_TF_WORKSPACE_ROOT = ''
|
_TF_WORKSPACE_ROOT = ''
|
||||||
_TF_BAZELRC = ''
|
_TF_BAZELRC = ''
|
||||||
_TF_CURRENT_BAZEL_VERSION = None
|
_TF_CURRENT_BAZEL_VERSION = None
|
||||||
_TF_MIN_BAZEL_VERSION = '1.0.0'
|
_TF_MIN_BAZEL_VERSION = '1.2.1'
|
||||||
_TF_MAX_BAZEL_VERSION = '1.1.0'
|
_TF_MAX_BAZEL_VERSION = '1.2.1'
|
||||||
|
|
||||||
NCCL_LIB_PATHS = [
|
NCCL_LIB_PATHS = [
|
||||||
'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''
|
'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
# TensorFlow is a computational framework, primarily for use in machine
|
# TensorFlow is a computational framework, primarily for use in machine
|
||||||
# learning applications.
|
# learning applications.
|
||||||
|
|
||||||
|
load("@bazel_skylib//lib:selects.bzl", "selects")
|
||||||
load("//tensorflow:tensorflow.bzl", "VERSION", "tf_cc_shared_object", "tf_custom_op_library_additional_deps_impl", "tf_native_cc_binary")
|
load("//tensorflow:tensorflow.bzl", "VERSION", "tf_cc_shared_object", "tf_custom_op_library_additional_deps_impl", "tf_native_cc_binary")
|
||||||
load(
|
load(
|
||||||
"//tensorflow/core/platform:build_config.bzl",
|
"//tensorflow/core/platform:build_config.bzl",
|
||||||
@ -478,6 +479,7 @@ bzl_library(
|
|||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/core/platform:build_config_root_bzl",
|
"//tensorflow/core/platform:build_config_root_bzl",
|
||||||
|
"//tensorflow/core/platform:rules_cc_bzl",
|
||||||
"//tensorflow/core/platform/default:cuda_build_defs_bzl",
|
"//tensorflow/core/platform/default:cuda_build_defs_bzl",
|
||||||
"//third_party/mkl:build_defs_bzl",
|
"//third_party/mkl:build_defs_bzl",
|
||||||
"//third_party/mkl_dnn:build_defs_bzl",
|
"//third_party/mkl_dnn:build_defs_bzl",
|
||||||
|
@ -23,10 +23,6 @@ from __future__ import print_function
|
|||||||
# pylint: disable=g-bad-import-order
|
# pylint: disable=g-bad-import-order
|
||||||
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
|
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
|
||||||
|
|
||||||
from tensorflow.python.util.lazy_loader import LazyLoader
|
|
||||||
contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')
|
|
||||||
del LazyLoader
|
|
||||||
|
|
||||||
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
|
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
|
||||||
from tensorflow.python.platform import app # pylint: disable=g-import-not-at-top
|
from tensorflow.python.platform import app # pylint: disable=g-import-not-at-top
|
||||||
app.flags = flags
|
app.flags = flags
|
||||||
|
@ -302,6 +302,7 @@ tf_cuda_library(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/common_runtime/eager:attr_builder",
|
"//tensorflow/core/common_runtime/eager:attr_builder",
|
||||||
|
"//tensorflow/core/common_runtime/eager:context",
|
||||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||||
"//tensorflow/core/platform",
|
"//tensorflow/core/platform",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
|
@ -458,7 +458,7 @@ static void TF_Run_Helper(
|
|||||||
EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
|
EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
c_outputs[i] = TF_TensorFromTensor(src, status);
|
c_outputs[i] = TF_TensorFromTensor(src, &status->status);
|
||||||
if (!status->status.ok()) return;
|
if (!status->status.ok()) return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1493,7 +1493,7 @@ void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
|
|||||||
Tensor t;
|
Tensor t;
|
||||||
status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
|
status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
|
||||||
if (!status->status.ok()) return;
|
if (!status->status.ok()) return;
|
||||||
*value = TF_TensorFromTensor(t, status);
|
*value = TF_TensorFromTensor(t, &status->status);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
|
void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
|
||||||
@ -1504,7 +1504,7 @@ void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
|
|||||||
if (!status->status.ok()) return;
|
if (!status->status.ok()) return;
|
||||||
const auto len = std::min(max_values, static_cast<int>(ts.size()));
|
const auto len = std::min(max_values, static_cast<int>(ts.size()));
|
||||||
for (int i = 0; i < len; ++i) {
|
for (int i = 0; i < len; ++i) {
|
||||||
values[i] = TF_TensorFromTensor(ts[i], status);
|
values[i] = TF_TensorFromTensor(ts[i], &status->status);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2398,7 +2398,7 @@ unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output,
|
|||||||
graph->graph.versions().producer(), &evaluated, &result_tensor);
|
graph->graph.versions().producer(), &evaluated, &result_tensor);
|
||||||
if (evaluated) {
|
if (evaluated) {
|
||||||
DCHECK(status->status.ok());
|
DCHECK(status->status.ok());
|
||||||
*result = TF_TensorFromTensor(result_tensor, status);
|
*result = TF_TensorFromTensor(result_tensor, &status->status);
|
||||||
if (!status->status.ok()) evaluated = false;
|
if (!status->status.ok()) evaluated = false;
|
||||||
}
|
}
|
||||||
return evaluated;
|
return evaluated;
|
||||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/eager/c_api_internal.h"
|
#include "tensorflow/c/eager/c_api_internal.h"
|
||||||
#include "tensorflow/compiler/jit/flags.h"
|
#include "tensorflow/compiler/jit/flags.h"
|
||||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||||
|
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
#include "tensorflow/core/framework/shape_inference.h"
|
#include "tensorflow/core/framework/shape_inference.h"
|
||||||
@ -549,7 +550,7 @@ TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(TFE_Op* op,
|
|||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
TFE_ExecuteOpNotification* n = new TFE_ExecuteOpNotification;
|
TFE_ExecuteOpNotification* n = new TFE_ExecuteOpNotification;
|
||||||
|
|
||||||
n->thread.reset(op->operation.EagerContext()->TFEnv()->StartThread(
|
n->thread.reset(op->operation.EagerContext().TFEnv()->StartThread(
|
||||||
tensorflow::ThreadOptions(), "ExecuteOpThread",
|
tensorflow::ThreadOptions(), "ExecuteOpThread",
|
||||||
[op, retvals, num_retvals, n]() {
|
[op, retvals, num_retvals, n]() {
|
||||||
TFE_Execute(op, retvals, num_retvals, n->status.get());
|
TFE_Execute(op, retvals, num_retvals, n->status.get());
|
||||||
@ -634,7 +635,7 @@ TF_Tensor* TF_CheckpointReaderGetTensor(TF_CheckpointReader* reader,
|
|||||||
std::unique_ptr<tensorflow::Tensor> tensor;
|
std::unique_ptr<tensorflow::Tensor> tensor;
|
||||||
reader->GetTensor(name, &tensor, status);
|
reader->GetTensor(name, &tensor, status);
|
||||||
if (!status->status.ok()) return nullptr;
|
if (!status->status.ok()) return nullptr;
|
||||||
return tensorflow::TF_TensorFromTensor(*tensor, status);
|
return tensorflow::TF_TensorFromTensor(*tensor, &status->status);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TF_CheckpointReaderGetVariableShape(TF_CheckpointReader* reader,
|
void TF_CheckpointReaderGetVariableShape(TF_CheckpointReader* reader,
|
||||||
@ -767,8 +768,9 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
|
|||||||
} while (0);
|
} while (0);
|
||||||
|
|
||||||
// New server created for new server_def. Unused if updating server_def.
|
// New server created for new server_def. Unused if updating server_def.
|
||||||
|
tensorflow::EagerContext* context = ctx->context;
|
||||||
tensorflow::GrpcServer* grpc_server =
|
tensorflow::GrpcServer* grpc_server =
|
||||||
dynamic_cast<tensorflow::GrpcServer*>(ctx->context->GetServer());
|
dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
|
||||||
if (grpc_server == nullptr) {
|
if (grpc_server == nullptr) {
|
||||||
std::unique_ptr<tensorflow::ServerInterface> new_server;
|
std::unique_ptr<tensorflow::ServerInterface> new_server;
|
||||||
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
|
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
|
||||||
@ -779,12 +781,12 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
|
|||||||
}
|
}
|
||||||
LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
|
LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
|
||||||
|
|
||||||
LOG_AND_RETURN_IF_ERROR(ctx->context->StoreCollectiveOpsServer(
|
LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
|
||||||
std::move(new_server), grpc_server->worker_env()->device_mgr,
|
std::move(new_server), grpc_server->worker_env()->device_mgr,
|
||||||
grpc_server->worker_env()->collective_executor_mgr));
|
grpc_server->worker_env()->collective_executor_mgr));
|
||||||
} else {
|
} else {
|
||||||
LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
|
LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
|
||||||
LOG_AND_RETURN_IF_ERROR(ctx->context->StoreCollectiveOpsServer(
|
LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
|
||||||
/*new_server=*/nullptr, grpc_server->worker_env()->device_mgr,
|
/*new_server=*/nullptr, grpc_server->worker_env()->device_mgr,
|
||||||
grpc_server->worker_env()->collective_executor_mgr));
|
grpc_server->worker_env()->collective_executor_mgr));
|
||||||
}
|
}
|
||||||
|
@ -1260,11 +1260,10 @@ TEST_F(CApiFunctionTest, GraphToFunctionDefWithPlaceholderAttr) {
|
|||||||
NodeWithPlaceholderAttrHelper(func_graph.get(), s.get(), "node3", "v2",
|
NodeWithPlaceholderAttrHelper(func_graph.get(), s.get(), "node3", "v2",
|
||||||
&node3);
|
&node3);
|
||||||
|
|
||||||
TF_Output inputs[] = {};
|
|
||||||
TF_Output outputs[] = {{node1, 0}, {node2, 0}, {node3, 0}};
|
TF_Output outputs[] = {{node1, 0}, {node2, 0}, {node3, 0}};
|
||||||
func_ = TF_GraphToFunction(
|
func_ = TF_GraphToFunction(
|
||||||
func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
|
func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
|
||||||
/*opers=*/nullptr, 0, inputs, 3, outputs,
|
/*opers=*/nullptr, 0, nullptr, 3, outputs,
|
||||||
/*output_names=*/nullptr,
|
/*output_names=*/nullptr,
|
||||||
/*opts=*/nullptr, /*description=*/nullptr, s.get());
|
/*opts=*/nullptr, /*description=*/nullptr, s.get());
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
|
ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
|
||||||
@ -1300,10 +1299,9 @@ TEST_F(CApiFunctionTest, GraphToFunctionDefWithArgAttr) {
|
|||||||
&node);
|
&node);
|
||||||
|
|
||||||
TF_Output inputs[] = {{node, 0}};
|
TF_Output inputs[] = {{node, 0}};
|
||||||
TF_Output outputs[] = {};
|
|
||||||
func_ = TF_GraphToFunction(
|
func_ = TF_GraphToFunction(
|
||||||
func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
|
func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
|
||||||
/*opers=*/nullptr, 1, inputs, 0, outputs,
|
/*opers=*/nullptr, 1, inputs, 0, nullptr,
|
||||||
/*output_names=*/nullptr,
|
/*output_names=*/nullptr,
|
||||||
/*opts=*/nullptr, /*description=*/nullptr, s.get());
|
/*opts=*/nullptr, /*description=*/nullptr, s.get());
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
|
ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
|
||||||
@ -1603,11 +1601,10 @@ void DefineStatefulFunction(const char* name, TF_Function** func) {
|
|||||||
TF_Operation* random =
|
TF_Operation* random =
|
||||||
RandomUniform(shape, TF_FLOAT, func_graph.get(), s.get());
|
RandomUniform(shape, TF_FLOAT, func_graph.get(), s.get());
|
||||||
|
|
||||||
TF_Output inputs[] = {};
|
|
||||||
TF_Output outputs[] = {{random, 0}};
|
TF_Output outputs[] = {{random, 0}};
|
||||||
*func = TF_GraphToFunction(func_graph.get(), name,
|
*func = TF_GraphToFunction(func_graph.get(), name,
|
||||||
/*append_hash_to_fn_name=*/false, -1,
|
/*append_hash_to_fn_name=*/false, -1,
|
||||||
/*opers=*/nullptr, 0, inputs, 1, outputs,
|
/*opers=*/nullptr, 0, nullptr, 1, outputs,
|
||||||
/*output_names=*/nullptr,
|
/*output_names=*/nullptr,
|
||||||
/*opts=*/nullptr, "", s.get());
|
/*opts=*/nullptr, "", s.get());
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
|
ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
|
||||||
|
@ -188,7 +188,7 @@ namespace tensorflow {
|
|||||||
|
|
||||||
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
|
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
|
||||||
|
|
||||||
TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
|
TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
|
||||||
|
|
||||||
Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
|
Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
|
||||||
TF_Buffer* out);
|
TF_Buffer* out);
|
||||||
|
@ -51,7 +51,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/util/equal_graph_def.h"
|
#include "tensorflow/core/util/equal_graph_def.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
|
TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
|
||||||
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
|
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
@ -227,7 +227,7 @@ TEST(CAPI, LibraryLoadFunctions) {
|
|||||||
|
|
||||||
void TestEncodeDecode(int line, const std::vector<string>& data) {
|
void TestEncodeDecode(int line, const std::vector<string>& data) {
|
||||||
const tensorflow::int64 n = data.size();
|
const tensorflow::int64 n = data.size();
|
||||||
TF_Status* status = TF_NewStatus();
|
Status status;
|
||||||
for (const std::vector<tensorflow::int64>& dims :
|
for (const std::vector<tensorflow::int64>& dims :
|
||||||
std::vector<std::vector<tensorflow::int64>>{
|
std::vector<std::vector<tensorflow::int64>>{
|
||||||
{n}, {1, n}, {n, 1}, {n / 2, 2}}) {
|
{n}, {1, n}, {n, 1}, {n / 2, 2}}) {
|
||||||
@ -236,8 +236,8 @@ void TestEncodeDecode(int line, const std::vector<string>& data) {
|
|||||||
for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) {
|
for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) {
|
||||||
src.flat<tstring>()(i) = data[i];
|
src.flat<tstring>()(i) = data[i];
|
||||||
}
|
}
|
||||||
TF_Tensor* dst = TF_TensorFromTensor(src, status);
|
TF_Tensor* dst = TF_TensorFromTensor(src, &status);
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
ASSERT_TRUE(status.ok()) << status.error_message();
|
||||||
|
|
||||||
// Convert back to a C++ Tensor and ensure we get expected output.
|
// Convert back to a C++ Tensor and ensure we get expected output.
|
||||||
Tensor output;
|
Tensor output;
|
||||||
@ -249,7 +249,6 @@ void TestEncodeDecode(int line, const std::vector<string>& data) {
|
|||||||
|
|
||||||
TF_DeleteTensor(dst);
|
TF_DeleteTensor(dst);
|
||||||
}
|
}
|
||||||
TF_DeleteStatus(status);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CAPI, TensorEncodeDecodeStrings) {
|
TEST(CAPI, TensorEncodeDecodeStrings) {
|
||||||
@ -1394,8 +1393,9 @@ TEST(CAPI, SavedModel) {
|
|||||||
TF_Operation* input_op =
|
TF_Operation* input_op =
|
||||||
TF_GraphOperationByName(graph, input_op_name.c_str());
|
TF_GraphOperationByName(graph, input_op_name.c_str());
|
||||||
ASSERT_TRUE(input_op != nullptr);
|
ASSERT_TRUE(input_op != nullptr);
|
||||||
csession.SetInputs({{input_op, TF_TensorFromTensor(input, s)}});
|
Status status;
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
csession.SetInputs({{input_op, TF_TensorFromTensor(input, &status)}});
|
||||||
|
ASSERT_TRUE(status.ok()) << status.error_message();
|
||||||
|
|
||||||
const tensorflow::string output_op_name(
|
const tensorflow::string output_op_name(
|
||||||
tensorflow::ParseTensorName(output_name).first);
|
tensorflow::ParseTensorName(output_name).first);
|
||||||
@ -2522,12 +2522,11 @@ TEST(CAPI, TestTensorIsNotAligned) {
|
|||||||
|
|
||||||
// Take an unaligned slice.
|
// Take an unaligned slice.
|
||||||
Tensor y = x.Slice(1, 13);
|
Tensor y = x.Slice(1, 13);
|
||||||
TF_Status* status = TF_NewStatus();
|
Status status;
|
||||||
TF_Tensor* a = TF_TensorFromTensor(y, status);
|
TF_Tensor* a = TF_TensorFromTensor(y, &status);
|
||||||
if (EIGEN_MAX_ALIGN_BYTES > 0) {
|
if (EIGEN_MAX_ALIGN_BYTES > 0) {
|
||||||
EXPECT_FALSE(TF_TensorIsAligned(a));
|
EXPECT_FALSE(TF_TensorIsAligned(a));
|
||||||
}
|
}
|
||||||
TF_DeleteStatus(status);
|
|
||||||
TF_DeleteTensor(a);
|
TF_DeleteTensor(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
|||||||
#include <memory.h>
|
#include <memory.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
#include <sys/time.h>
|
#include <time.h>
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
|
|
||||||
#include "tensorflow/c/c_api.h"
|
#include "tensorflow/c/c_api.h"
|
||||||
@ -58,12 +58,8 @@ int main(int argc, char** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
char file_name[100];
|
char file_name[100];
|
||||||
struct timeval t;
|
time_t t = time(NULL);
|
||||||
if (gettimeofday(&t, NULL)) {
|
snprintf(file_name, sizeof(file_name), "test-%d-%ld.txt", getpid(), t);
|
||||||
perror("gettimeofday failed");
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
snprintf(file_name, sizeof(file_name), "test-%d-%ld.txt", getpid(), t.tv_sec);
|
|
||||||
|
|
||||||
size_t length = 2 + strlen(path) + strlen(file_name);
|
size_t length = 2 + strlen(path) + strlen(file_name);
|
||||||
char* full_path = malloc(length);
|
char* full_path = malloc(length);
|
||||||
|
@ -26,8 +26,8 @@ tf_cuda_library(
|
|||||||
"c_api.cc",
|
"c_api.cc",
|
||||||
"c_api_debug.cc",
|
"c_api_debug.cc",
|
||||||
"c_api_experimental.h",
|
"c_api_experimental.h",
|
||||||
"c_api_internal.cc",
|
|
||||||
"c_api_internal.h",
|
"c_api_internal.h",
|
||||||
|
"tensor_handle_interface.h",
|
||||||
],
|
],
|
||||||
hdrs = ["c_api.h"],
|
hdrs = ["c_api.h"],
|
||||||
copts = tf_copts() + tfe_xla_copts(),
|
copts = tf_copts() + tfe_xla_copts(),
|
||||||
@ -93,6 +93,7 @@ filegroup(
|
|||||||
srcs = [
|
srcs = [
|
||||||
"c_api_experimental.h",
|
"c_api_experimental.h",
|
||||||
"c_api_internal.h",
|
"c_api_internal.h",
|
||||||
|
"tensor_handle_interface.h",
|
||||||
],
|
],
|
||||||
visibility = [
|
visibility = [
|
||||||
"//tensorflow/core:__pkg__",
|
"//tensorflow/core:__pkg__",
|
||||||
@ -102,7 +103,10 @@ filegroup(
|
|||||||
|
|
||||||
tf_cuda_library(
|
tf_cuda_library(
|
||||||
name = "c_api_internal",
|
name = "c_api_internal",
|
||||||
srcs = ["c_api_experimental.h"],
|
srcs = [
|
||||||
|
"c_api_experimental.h",
|
||||||
|
"tensor_handle_interface.h",
|
||||||
|
],
|
||||||
hdrs = ["c_api_internal.h"],
|
hdrs = ["c_api_internal.h"],
|
||||||
visibility = [
|
visibility = [
|
||||||
"//learning/deepmind/courier:__subpackages__",
|
"//learning/deepmind/courier:__subpackages__",
|
||||||
|
@ -31,6 +31,7 @@ limitations under the License.
|
|||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "tensorflow/c/c_api.h"
|
#include "tensorflow/c/c_api.h"
|
||||||
#include "tensorflow/c/c_api_internal.h"
|
#include "tensorflow/c/c_api_internal.h"
|
||||||
|
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||||
#include "tensorflow/c/tf_tensor_internal.h"
|
#include "tensorflow/c/tf_tensor_internal.h"
|
||||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||||
#include "tensorflow/c/eager/c_api_internal.h"
|
#include "tensorflow/c/eager/c_api_internal.h"
|
||||||
@ -81,6 +82,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||||
|
|
||||||
#include "tensorflow/core/lib/random/random.h"
|
#include "tensorflow/core/lib/random/random.h"
|
||||||
|
#include "tensorflow/core/platform/casts.h"
|
||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
#include "tensorflow/core/platform/thread_annotations.h"
|
#include "tensorflow/core/platform/thread_annotations.h"
|
||||||
@ -93,10 +95,8 @@ using tensorflow::string;
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
const tensorflow::OpDef* GetOpDef(TFE_Op* op, TF_Status* status) {
|
const tensorflow::OpDef* GetOpDef(TFE_Op* op, TF_Status* status) {
|
||||||
if (op->inference_ctx) {
|
const tensorflow::OpDef* op_def = op->operation.OpDef();
|
||||||
return op->inference_ctx->op_def;
|
if (op_def) return op_def;
|
||||||
}
|
|
||||||
const tensorflow::OpDef* op_def;
|
|
||||||
status->status =
|
status->status =
|
||||||
tensorflow::OpDefForOp(op->operation.Name().c_str(), &op_def);
|
tensorflow::OpDefForOp(op->operation.Name().c_str(), &op_def);
|
||||||
return op_def;
|
return op_def;
|
||||||
@ -409,6 +409,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
|||||||
|
|
||||||
// New server created for new server_def. Unused if updating server_def.
|
// New server created for new server_def. Unused if updating server_def.
|
||||||
std::unique_ptr<tensorflow::ServerInterface> new_server;
|
std::unique_ptr<tensorflow::ServerInterface> new_server;
|
||||||
|
tensorflow::EagerContext* context = ctx->context;
|
||||||
tensorflow::GrpcServer* grpc_server;
|
tensorflow::GrpcServer* grpc_server;
|
||||||
if (reset_context) {
|
if (reset_context) {
|
||||||
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
|
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
|
||||||
@ -416,26 +417,25 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
|||||||
LOG_AND_RETURN_IF_ERROR(
|
LOG_AND_RETURN_IF_ERROR(
|
||||||
ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
|
ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
|
||||||
} else {
|
} else {
|
||||||
LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(
|
LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name,
|
||||||
ctx->context->GetServer(), worker_name, &curr_remote_workers));
|
&curr_remote_workers));
|
||||||
// No need to check the cast here, since `ListRemoteWorkers` already checks
|
// No need to check the cast here, since `ListRemoteWorkers` already checks
|
||||||
// if the server is a GRPC server or not.
|
// if the server is a GRPC server or not.
|
||||||
grpc_server =
|
grpc_server = dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
|
||||||
dynamic_cast<tensorflow::GrpcServer*>(ctx->context->GetServer());
|
|
||||||
LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
|
LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
|
||||||
LOG_AND_RETURN_IF_ERROR(
|
LOG_AND_RETURN_IF_ERROR(
|
||||||
ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
|
ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
|
||||||
}
|
}
|
||||||
|
|
||||||
tensorflow::uint64 context_id = ctx->context->GetContextId();
|
tensorflow::uint64 context_id = context->GetContextId();
|
||||||
tensorflow::uint64 context_view_id = ctx->context->GetContextViewId();
|
tensorflow::uint64 context_view_id = context->GetContextViewId();
|
||||||
if (reset_context) {
|
if (reset_context) {
|
||||||
context_id = tensorflow::EagerContext::NewContextId();
|
context_id = tensorflow::EagerContext::NewContextId();
|
||||||
context_view_id = 0;
|
context_view_id = 0;
|
||||||
// Make master eager context accessible by local eager service, which might
|
// Make master eager context accessible by local eager service, which might
|
||||||
// receive send tensor requests from remote workers.
|
// receive send tensor requests from remote workers.
|
||||||
LOG_AND_RETURN_IF_ERROR(grpc_server->AddMasterEagerContextToEagerService(
|
LOG_AND_RETURN_IF_ERROR(
|
||||||
context_id, ctx->context));
|
grpc_server->AddMasterEagerContextToEagerService(context_id, context));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
|
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
|
||||||
@ -464,11 +464,11 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
|||||||
&new_remote_device_mgr));
|
&new_remote_device_mgr));
|
||||||
remote_device_mgr = new_remote_device_mgr.get();
|
remote_device_mgr = new_remote_device_mgr.get();
|
||||||
} else {
|
} else {
|
||||||
ctx->context->ClearCachesAndDefaultExecutor();
|
context->ClearCachesAndDefaultExecutor();
|
||||||
// TODO(b/143914772): Potential memory leak if rendezvous has pending
|
// TODO(b/143914772): Potential memory leak if rendezvous has pending
|
||||||
// tensors for removed / replaced workers.
|
// tensors for removed / replaced workers.
|
||||||
|
|
||||||
remote_device_mgr = ctx->context->GetOwnedRemoteDeviceMgr();
|
remote_device_mgr = context->GetOwnedRemoteDeviceMgr();
|
||||||
if (remote_device_mgr == nullptr) {
|
if (remote_device_mgr == nullptr) {
|
||||||
LOG_AND_RETURN_IF_ERROR(tensorflow::errors::InvalidArgument(
|
LOG_AND_RETURN_IF_ERROR(tensorflow::errors::InvalidArgument(
|
||||||
"Updating context with an invalid set of remote devices."));
|
"Updating context with an invalid set of remote devices."));
|
||||||
@ -479,8 +479,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
|||||||
&added_workers, &removed_workers,
|
&added_workers, &removed_workers,
|
||||||
&existing_workers);
|
&existing_workers);
|
||||||
LOG_AND_RETURN_IF_ERROR(GetReplacedFromExistingWorkers(
|
LOG_AND_RETURN_IF_ERROR(GetReplacedFromExistingWorkers(
|
||||||
&existing_workers, context_id, ctx->context->GetContextViewId(),
|
&existing_workers, context_id, context->GetContextViewId(), server_def,
|
||||||
server_def, remote_eager_workers.get(), &replaced_workers));
|
remote_eager_workers.get(), &replaced_workers));
|
||||||
if (VLOG_IS_ON(1)) {
|
if (VLOG_IS_ON(1)) {
|
||||||
VLOG(1) << "Updating cluster with following changes";
|
VLOG(1) << "Updating cluster with following changes";
|
||||||
for (const string& w : added_workers) VLOG(1) << " Added worker " << w;
|
for (const string& w : added_workers) VLOG(1) << " Added worker " << w;
|
||||||
@ -516,7 +516,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
|||||||
grpc_server->worker_env()->device_mgr->ListDeviceAttributes(
|
grpc_server->worker_env()->device_mgr->ListDeviceAttributes(
|
||||||
&local_device_attributes);
|
&local_device_attributes);
|
||||||
|
|
||||||
// This request make sure that we can create Rendevzous properly between
|
// This request make sure that we can create Rendezvous properly between
|
||||||
// Local and Remote context.
|
// Local and Remote context.
|
||||||
tensorflow::eager::CreateContextRequest base_request;
|
tensorflow::eager::CreateContextRequest base_request;
|
||||||
for (const auto& da : cluster_device_attributes) {
|
for (const auto& da : cluster_device_attributes) {
|
||||||
@ -534,9 +534,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
|||||||
if (reset_context) {
|
if (reset_context) {
|
||||||
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
|
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
|
||||||
remote_workers, context_id, context_view_id, keep_alive_secs,
|
remote_workers, context_id, context_view_id, keep_alive_secs,
|
||||||
server_def, remote_eager_workers.get(),
|
server_def, remote_eager_workers.get(), context->Executor().Async(),
|
||||||
ctx->context->Executor().Async(),
|
context->LazyCopyFunctionRemoteInputs(), base_request));
|
||||||
ctx->context->LazyCopyFunctionRemoteInputs(), base_request));
|
|
||||||
} else {
|
} else {
|
||||||
// The master's context_view_id will be incremented by one
|
// The master's context_view_id will be incremented by one
|
||||||
// the UpdateRemoteMaster call later. We want all new workers and
|
// the UpdateRemoteMaster call later. We want all new workers and
|
||||||
@ -545,9 +544,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
|||||||
// context_view_id + 1.
|
// context_view_id + 1.
|
||||||
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
|
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
|
||||||
added_workers, context_id, context_view_id + 1, keep_alive_secs,
|
added_workers, context_id, context_view_id + 1, keep_alive_secs,
|
||||||
server_def, remote_eager_workers.get(),
|
server_def, remote_eager_workers.get(), context->Executor().Async(),
|
||||||
ctx->context->Executor().Async(),
|
context->LazyCopyFunctionRemoteInputs(), base_request));
|
||||||
ctx->context->LazyCopyFunctionRemoteInputs(), base_request));
|
|
||||||
if (!existing_workers.empty()) {
|
if (!existing_workers.empty()) {
|
||||||
if (VLOG_IS_ON(1)) {
|
if (VLOG_IS_ON(1)) {
|
||||||
for (const string& w : existing_workers) {
|
for (const string& w : existing_workers) {
|
||||||
@ -578,12 +576,12 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
|||||||
TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
|
TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
|
||||||
|
|
||||||
tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
|
tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
|
||||||
tensorflow::eager::CreateClusterFLR(context_id, ctx->context,
|
tensorflow::eager::CreateClusterFLR(context_id, context,
|
||||||
worker_session.get());
|
worker_session.get());
|
||||||
auto remote_mgr = absl::make_unique<tensorflow::eager::RemoteMgr>(
|
auto remote_mgr = absl::make_unique<tensorflow::eager::RemoteMgr>(
|
||||||
/*is_master=*/true, ctx->context);
|
/*is_master=*/true, context);
|
||||||
|
|
||||||
LOG_AND_RETURN_IF_ERROR(ctx->context->InitializeRemoteMaster(
|
LOG_AND_RETURN_IF_ERROR(context->InitializeRemoteMaster(
|
||||||
std::move(new_server), grpc_server->worker_env(), worker_session,
|
std::move(new_server), grpc_server->worker_env(), worker_session,
|
||||||
std::move(remote_eager_workers), std::move(new_remote_device_mgr),
|
std::move(remote_eager_workers), std::move(new_remote_device_mgr),
|
||||||
remote_workers, context_id, r, device_mgr, keep_alive_secs, cluster_flr,
|
remote_workers, context_id, r, device_mgr, keep_alive_secs, cluster_flr,
|
||||||
@ -601,9 +599,9 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
|||||||
grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
|
grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
|
||||||
session_name, &worker_session));
|
session_name, &worker_session));
|
||||||
tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
|
tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
|
||||||
tensorflow::eager::CreateClusterFLR(context_id, ctx->context,
|
tensorflow::eager::CreateClusterFLR(context_id, context,
|
||||||
worker_session.get());
|
worker_session.get());
|
||||||
LOG_AND_RETURN_IF_ERROR(ctx->context->UpdateRemoteMaster(
|
LOG_AND_RETURN_IF_ERROR(context->UpdateRemoteMaster(
|
||||||
grpc_server->worker_env(), std::move(remote_eager_workers),
|
grpc_server->worker_env(), std::move(remote_eager_workers),
|
||||||
added_workers, removed_workers, context_id, r, device_mgr,
|
added_workers, removed_workers, context_id, r, device_mgr,
|
||||||
keep_alive_secs, cluster_flr));
|
keep_alive_secs, cluster_flr));
|
||||||
@ -614,77 +612,6 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
|||||||
}
|
}
|
||||||
#endif // !IS_MOBILE_PLATFORM
|
#endif // !IS_MOBILE_PLATFORM
|
||||||
|
|
||||||
tensorflow::Status OpInferSingleInputAttrs(TFE_Op* op,
|
|
||||||
TFE_TensorHandle* input) {
|
|
||||||
TFE_OpInferenceContext* ictx = op->inference_ctx.get();
|
|
||||||
const auto& input_def = ictx->op_def->input_arg(ictx->input_arg_idx++);
|
|
||||||
if (!input_def.number_attr().empty() || !input_def.type_list_attr().empty()) {
|
|
||||||
// Some clients that are still setting their input attributes manually are
|
|
||||||
// adding input list to their op by calling `TFE_OpAddInput` for each of
|
|
||||||
// its elements instead of calling `TFE_OpAddInputList`. When this happens,
|
|
||||||
// we cannot detect the end of such list, thus lose track of the input
|
|
||||||
// arguments in the op definition. To guarantee backward compatibility with
|
|
||||||
// those clients, disable automatic inference in this case.
|
|
||||||
op->inference_ctx.reset(nullptr);
|
|
||||||
return tensorflow::Status::OK();
|
|
||||||
}
|
|
||||||
const std::string& type_attr = input_def.type_attr();
|
|
||||||
if (!type_attr.empty() && ictx->attrs.find(type_attr) == ictx->attrs.end()) {
|
|
||||||
op->operation.MutableAttrs()->Set(type_attr, input->handle->dtype);
|
|
||||||
ictx->attrs.insert(type_attr);
|
|
||||||
}
|
|
||||||
return tensorflow::Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
void OpInferSingleTypeInputListAttrs(TFE_Op* op,
|
|
||||||
const tensorflow::OpDef::ArgDef& input_def,
|
|
||||||
TFE_TensorHandle** inputs,
|
|
||||||
int num_inputs) {
|
|
||||||
TFE_OpInferenceContext* ictx = op->inference_ctx.get();
|
|
||||||
if (ictx->attrs.find(input_def.number_attr()) == ictx->attrs.end()) {
|
|
||||||
op->operation.MutableAttrs()->Set(input_def.number_attr(), num_inputs);
|
|
||||||
ictx->attrs.insert(input_def.number_attr());
|
|
||||||
}
|
|
||||||
if (ictx->attrs.find(input_def.type_attr()) == ictx->attrs.end()) {
|
|
||||||
op->operation.MutableAttrs()->Set(input_def.type_attr(),
|
|
||||||
inputs[0]->handle->dtype);
|
|
||||||
ictx->attrs.insert(input_def.type_attr());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void OpInferMixedTypeInputListAttrs(TFE_Op* op,
|
|
||||||
const tensorflow::OpDef::ArgDef& input_def,
|
|
||||||
TFE_TensorHandle** inputs, int num_inputs) {
|
|
||||||
TFE_OpInferenceContext* ictx = op->inference_ctx.get();
|
|
||||||
if (ictx->attrs.find(input_def.type_list_attr()) == ictx->attrs.end()) {
|
|
||||||
std::unique_ptr<tensorflow::DataType[]> dtypes(
|
|
||||||
new tensorflow::DataType[num_inputs]);
|
|
||||||
for (int i = 0; i < num_inputs; ++i) {
|
|
||||||
dtypes[i] = inputs[i]->handle->dtype;
|
|
||||||
}
|
|
||||||
op->operation.MutableAttrs()->Set(
|
|
||||||
input_def.type_list_attr(),
|
|
||||||
tensorflow::gtl::ArraySlice<const tensorflow::DataType>(dtypes.get(),
|
|
||||||
num_inputs));
|
|
||||||
ictx->attrs.insert(input_def.type_list_attr());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
tensorflow::Status OpInferInputListAttrs(TFE_Op* op, TFE_TensorHandle** inputs,
|
|
||||||
int num_inputs) {
|
|
||||||
TFE_OpInferenceContext* ictx = op->inference_ctx.get();
|
|
||||||
const auto& input_def = ictx->op_def->input_arg(ictx->input_arg_idx++);
|
|
||||||
if (!input_def.type_list_attr().empty()) {
|
|
||||||
OpInferMixedTypeInputListAttrs(op, input_def, inputs, num_inputs);
|
|
||||||
} else if (!input_def.type_attr().empty() &&
|
|
||||||
!input_def.number_attr().empty()) {
|
|
||||||
OpInferSingleTypeInputListAttrs(op, input_def, inputs, num_inputs);
|
|
||||||
} else {
|
|
||||||
return tensorflow::errors::InvalidArgument("Invalid input list definition");
|
|
||||||
}
|
|
||||||
return tensorflow::Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
extern "C" {
|
extern "C" {
|
||||||
@ -720,12 +647,14 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
|||||||
tensorflow::Rendezvous* r =
|
tensorflow::Rendezvous* r =
|
||||||
new tensorflow::IntraProcessRendezvous(device_mgr.get());
|
new tensorflow::IntraProcessRendezvous(device_mgr.get());
|
||||||
|
|
||||||
return new TFE_Context(opts->session_options.options,
|
return new TFE_Context{new tensorflow::EagerContext(
|
||||||
opts->device_placement_policy, opts->mirroring_policy,
|
opts->session_options.options,
|
||||||
opts->async, opts->lazy_remote_inputs_copy,
|
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
||||||
device_mgr.release(),
|
opts->device_placement_policy),
|
||||||
/*device_mgr_owned*/ true, r,
|
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
|
||||||
tensorflow::GetDefaultCustomKernelCreator());
|
opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
|
||||||
|
/*device_mgr_owned*/ true, r,
|
||||||
|
tensorflow::GetDefaultCustomKernelCreator())};
|
||||||
}
|
}
|
||||||
|
|
||||||
TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
|
TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
|
||||||
@ -736,22 +665,28 @@ TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
|
|||||||
tensorflow::Rendezvous* r =
|
tensorflow::Rendezvous* r =
|
||||||
new tensorflow::IntraProcessRendezvous(device_mgr);
|
new tensorflow::IntraProcessRendezvous(device_mgr);
|
||||||
|
|
||||||
return new TFE_Context(opts->session_options.options,
|
return new TFE_Context{new tensorflow::EagerContext(
|
||||||
opts->device_placement_policy, opts->mirroring_policy,
|
opts->session_options.options,
|
||||||
opts->async, opts->lazy_remote_inputs_copy, device_mgr,
|
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
||||||
/*device_mgr_owned*/ false, r,
|
opts->device_placement_policy),
|
||||||
tensorflow::GetDefaultCustomKernelCreator());
|
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
|
||||||
|
opts->async, opts->lazy_remote_inputs_copy, device_mgr,
|
||||||
|
/*device_mgr_owned*/ false, r,
|
||||||
|
tensorflow::GetDefaultCustomKernelCreator())};
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; }
|
void TFE_DeleteContext(TFE_Context* ctx) {
|
||||||
|
// context->RefCountIsOne() should be true here.
|
||||||
|
// TODO(iga): Remove EagerContext refcounting.
|
||||||
|
ctx->context->Unref();
|
||||||
|
|
||||||
|
delete ctx;
|
||||||
|
}
|
||||||
|
|
||||||
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
|
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
|
||||||
TF_DeviceList* list = new TF_DeviceList;
|
TF_DeviceList* l = new TF_DeviceList;
|
||||||
ctx->context->local_device_mgr()->ListDeviceAttributes(&list->response);
|
ctx->context->ListDevices(&l->response);
|
||||||
if (ctx->context->remote_device_mgr()) {
|
return l;
|
||||||
ctx->context->remote_device_mgr()->ListDeviceAttributes(&list->response);
|
|
||||||
}
|
|
||||||
return list;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_ContextClearCaches(TFE_Context* ctx) {
|
void TFE_ContextClearCaches(TFE_Context* ctx) {
|
||||||
@ -812,8 +747,9 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
|
|||||||
"TFE_ContextSetServerDef not supported on mobile");
|
"TFE_ContextSetServerDef not supported on mobile");
|
||||||
return false;
|
return false;
|
||||||
#else // !defined(IS_MOBILE_PLATFORM)
|
#else // !defined(IS_MOBILE_PLATFORM)
|
||||||
|
tensorflow::EagerContext* context = ctx->context;
|
||||||
tensorflow::GrpcServer* grpc_server =
|
tensorflow::GrpcServer* grpc_server =
|
||||||
static_cast<tensorflow::GrpcServer*>(ctx->context->GetServer());
|
static_cast<tensorflow::GrpcServer*>(context->GetServer());
|
||||||
|
|
||||||
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
|
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
|
||||||
status->status = grpc_server->master_env()->worker_cache->GetEagerClientCache(
|
status->status = grpc_server->master_env()->worker_cache->GetEagerClientCache(
|
||||||
@ -832,7 +768,7 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
|
|||||||
|
|
||||||
// Send a rpc request to the worker to check aliveness.
|
// Send a rpc request to the worker to check aliveness.
|
||||||
tensorflow::eager::KeepAliveRequest request;
|
tensorflow::eager::KeepAliveRequest request;
|
||||||
request.set_context_id(ctx->context->GetContextId());
|
request.set_context_id(context->GetContextId());
|
||||||
tensorflow::eager::KeepAliveResponse response;
|
tensorflow::eager::KeepAliveResponse response;
|
||||||
|
|
||||||
tensorflow::Status keep_alive_status;
|
tensorflow::Status keep_alive_status;
|
||||||
@ -887,108 +823,180 @@ void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
|
|||||||
if (h == nullptr) return;
|
if (h == nullptr) return;
|
||||||
tensorflow::profiler::TraceMe activity(
|
tensorflow::profiler::TraceMe activity(
|
||||||
"TFE_DeleteTensorHandle", tensorflow::profiler::TraceMeLevel::kInfo);
|
"TFE_DeleteTensorHandle", tensorflow::profiler::TraceMeLevel::kInfo);
|
||||||
VLOG(1) << "Deleting tensor handle " << h << " with internal handle "
|
|
||||||
<< h->handle;
|
|
||||||
if (h->handle) {
|
|
||||||
h->handle->Unref();
|
|
||||||
}
|
|
||||||
delete h;
|
delete h;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tensorflow::TensorHandleInterface::~TensorHandleInterface() {
|
||||||
|
VLOG(1) << "Deleting tensor handle " << this << " with internal handle "
|
||||||
|
<< handle_;
|
||||||
|
if (handle_) {
|
||||||
|
handle_->Unref();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool tensorflow::TensorHandleInterface::IsValid(Status* status) const {
|
||||||
|
if (handle_ == nullptr) {
|
||||||
|
*status = tensorflow::errors::InvalidArgument(
|
||||||
|
"The passed in handle is a nullptr");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
|
TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
|
||||||
return static_cast<TF_DataType>(h->handle->dtype);
|
return h->handle->DataType();
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_DataType tensorflow::TensorHandleInterface::DataType() const {
|
||||||
|
return static_cast<TF_DataType>(handle_->dtype);
|
||||||
}
|
}
|
||||||
|
|
||||||
int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
|
int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
|
||||||
if (h == nullptr || h->handle == nullptr) {
|
if (h == nullptr) {
|
||||||
status->status = tensorflow::errors::InvalidArgument(
|
status->status = tensorflow::errors::InvalidArgument(
|
||||||
"The passed in handle is a nullptr");
|
"The passed in handle is a nullptr");
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return h->handle->NumDims(&status->status);
|
||||||
|
}
|
||||||
|
|
||||||
|
int tensorflow::TensorHandleInterface::NumDims(Status* status) const {
|
||||||
|
if (!IsValid(status)) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
int result;
|
int result;
|
||||||
status->status = h->handle->NumDims(&result);
|
*status = handle_->NumDims(&result);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
|
int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
|
||||||
if (h == nullptr || h->handle == nullptr) {
|
if (h == nullptr) {
|
||||||
status->status = tensorflow::errors::InvalidArgument(
|
status->status = tensorflow::errors::InvalidArgument(
|
||||||
"The passed in handle is a nullptr");
|
"The passed in handle is a nullptr");
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return h->handle->NumElements(&status->status);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t tensorflow::TensorHandleInterface::NumElements(Status* status) const {
|
||||||
|
if (!IsValid(status)) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
tensorflow::int64 result;
|
tensorflow::int64 result;
|
||||||
status->status = h->handle->NumElements(&result);
|
*status = handle_->NumElements(&result);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
|
int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
if (h == nullptr || h->handle == nullptr) {
|
if (h == nullptr) {
|
||||||
status->status = tensorflow::errors::InvalidArgument(
|
status->status = tensorflow::errors::InvalidArgument(
|
||||||
"The passed in handle is a nullptr");
|
"The passed in handle is a nullptr");
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return h->handle->Dim(dim_index, &status->status);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t tensorflow::TensorHandleInterface::Dim(int dim_index,
|
||||||
|
Status* status) const {
|
||||||
|
if (!IsValid(status)) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
tensorflow::int64 result;
|
tensorflow::int64 result;
|
||||||
status->status = h->handle->Dim(dim_index, &result);
|
*status = handle_->Dim(dim_index, &result);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
|
const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
|
||||||
if (h == nullptr || h->handle == nullptr) {
|
if (h == nullptr) {
|
||||||
status->status = tensorflow::errors::InvalidArgument(
|
status->status = tensorflow::errors::InvalidArgument(
|
||||||
"The passed in handle is a nullptr");
|
"The passed in handle is a nullptr");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
tensorflow::Device* d = h->handle->op_device();
|
return h->handle->DeviceName(&status->status);
|
||||||
|
}
|
||||||
|
|
||||||
|
const char* tensorflow::TensorHandleInterface::DeviceName(
|
||||||
|
Status* status) const {
|
||||||
|
if (!IsValid(status)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
tensorflow::Device* d = handle_->op_device();
|
||||||
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
|
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
|
||||||
: d->name().c_str();
|
: d->name().c_str();
|
||||||
}
|
}
|
||||||
|
|
||||||
const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
|
const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
if (h == nullptr || h->handle == nullptr) {
|
if (h == nullptr) {
|
||||||
status->status = tensorflow::errors::InvalidArgument(
|
status->status = tensorflow::errors::InvalidArgument(
|
||||||
"The passed in handle is a nullptr");
|
"The passed in handle is a nullptr");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
tensorflow::Device* d = h->handle->device();
|
return h->handle->BackingDeviceName(&status->status);
|
||||||
|
}
|
||||||
|
|
||||||
|
const char* tensorflow::TensorHandleInterface::BackingDeviceName(
|
||||||
|
Status* status) const {
|
||||||
|
if (!IsValid(status)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
tensorflow::Device* d = handle_->device();
|
||||||
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
|
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
|
||||||
: d->name().c_str();
|
: d->name().c_str();
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
|
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
|
||||||
TFE_TensorHandle* h, TF_Status* status) {
|
TFE_TensorHandle* h, TF_Status* status) {
|
||||||
if (h == nullptr || h->handle == nullptr) {
|
if (h == nullptr || !h->handle->IsValid(&status->status)) {
|
||||||
status->status = tensorflow::errors::InvalidArgument(
|
status->status = tensorflow::errors::InvalidArgument(
|
||||||
"The passed in handle is a nullptr");
|
"The passed in handle is a nullptr");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
h->handle->Ref();
|
return new TFE_TensorHandle{
|
||||||
|
std::unique_ptr<AbstractTensorHandleInterface>(h->handle->Copy())};
|
||||||
|
}
|
||||||
|
|
||||||
return new TFE_TensorHandle(h->handle);
|
AbstractTensorHandleInterface* tensorflow::TensorHandleInterface::Copy() {
|
||||||
|
handle_->Ref();
|
||||||
|
return new TensorHandleInterface(handle_);
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
|
TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
|
||||||
if (h == nullptr || h->handle == nullptr) {
|
if (h == nullptr) {
|
||||||
status->status = tensorflow::errors::InvalidArgument(
|
status->status = tensorflow::errors::InvalidArgument(
|
||||||
"The passed in handle is a nullptr");
|
"The passed in handle is a nullptr");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
tensorflow::TensorHandle* handle = h->handle;
|
|
||||||
|
return h->handle->Resolve(&status->status);
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
|
||||||
|
if (!IsValid(status)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
// TODO(agarwal): move this implementation inside TFE_TensorHandle.
|
// TODO(agarwal): move this implementation inside TFE_TensorHandle.
|
||||||
if (handle->IsRemote()) {
|
if (handle_->IsRemote()) {
|
||||||
const tensorflow::Tensor* t = nullptr;
|
const tensorflow::Tensor* t = nullptr;
|
||||||
tensorflow::TensorHandle* h_cpu = nullptr;
|
tensorflow::TensorHandle* h_cpu = nullptr;
|
||||||
status->status = EagerCopyToDevice(
|
*status = EagerCopyToDevice(handle_, handle_->Context(),
|
||||||
handle, handle->Context(), &handle->Context()->Executor(),
|
&handle_->Context()->Executor(),
|
||||||
handle->Context()->HostCPU(), false, &h_cpu);
|
handle_->Context()->HostCPU(), false, &h_cpu);
|
||||||
if (!status->status.ok()) {
|
if (!status->ok()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
status->status = h_cpu->Tensor(&t);
|
*status = h_cpu->Tensor(&t);
|
||||||
if (!status->status.ok()) {
|
if (!status->ok()) {
|
||||||
h_cpu->Unref();
|
h_cpu->Unref();
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
@ -997,28 +1005,30 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
|
|||||||
return retval;
|
return retval;
|
||||||
} else {
|
} else {
|
||||||
tensorflow::Tensor tensor;
|
tensorflow::Tensor tensor;
|
||||||
if (IsCPU(handle->device())) {
|
if (IsCPU(handle_->device())) {
|
||||||
const tensorflow::Tensor* src = nullptr;
|
const tensorflow::Tensor* src = nullptr;
|
||||||
status->status = handle->Tensor(&src);
|
*status = handle_->Tensor(&src);
|
||||||
if (!status->status.ok()) return nullptr;
|
if (!status->ok()) return nullptr;
|
||||||
tensor = *src;
|
tensor = *src;
|
||||||
} else {
|
} else {
|
||||||
tensorflow::EagerContext* ctx = handle->Context();
|
tensorflow::EagerContext* ctx = handle_->Context();
|
||||||
CHECK_NE(ctx, nullptr);
|
CHECK_NE(ctx, nullptr);
|
||||||
status->status = h->handle->CopyToDevice(ctx, ctx->HostCPU(), &tensor);
|
*status = handle_->CopyToDevice(*ctx, ctx->HostCPU(), &tensor);
|
||||||
if (!status->status.ok()) return nullptr;
|
if (!status->ok()) return nullptr;
|
||||||
}
|
}
|
||||||
return tensorflow::TF_TensorFromTensor(tensor, status);
|
return tensorflow::TF_TensorFromTensor(tensor, status);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
|
void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
|
||||||
if (h == nullptr || h->handle == nullptr) {
|
if (h == nullptr || !h->handle->IsValid(&status->status)) {
|
||||||
status->status = tensorflow::errors::InvalidArgument(
|
status->status = tensorflow::errors::InvalidArgument(
|
||||||
"The passed in handle is a nullptr");
|
"The passed in handle is a nullptr");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
tensorflow::TensorHandle* handle = h->handle;
|
tensorflow::TensorHandle* handle =
|
||||||
|
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
|
||||||
|
->Handle();
|
||||||
|
|
||||||
if (handle->IsRemote()) {
|
if (handle->IsRemote()) {
|
||||||
status->status = tensorflow::errors::InvalidArgument(
|
status->status = tensorflow::errors::InvalidArgument(
|
||||||
@ -1047,7 +1057,8 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
|||||||
void (*deallocator)(void* data, size_t len, void* arg),
|
void (*deallocator)(void* data, size_t len, void* arg),
|
||||||
void* deallocator_arg, TF_Status* status) {
|
void* deallocator_arg, TF_Status* status) {
|
||||||
tensorflow::Device* device;
|
tensorflow::Device* device;
|
||||||
status->status = ctx->context->FindDeviceFromName(device_name, &device);
|
tensorflow::EagerContext* context = ctx->context;
|
||||||
|
status->status = context->FindDeviceFromName(device_name, &device);
|
||||||
if (!status->status.ok()) {
|
if (!status->status.ok()) {
|
||||||
deallocator(data, len, deallocator_arg);
|
deallocator(data, len, deallocator_arg);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@ -1075,11 +1086,12 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
|||||||
buf->Unref();
|
buf->Unref();
|
||||||
tensorflow::TensorHandle* ret_handle;
|
tensorflow::TensorHandle* ret_handle;
|
||||||
status->status = tensorflow::TensorHandle::CreateLocalHandle(
|
status->status = tensorflow::TensorHandle::CreateLocalHandle(
|
||||||
t, device, ctx->context, &ret_handle);
|
t, device, context, &ret_handle);
|
||||||
if (!status->status.ok()) {
|
if (!status->status.ok()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
return new TFE_TensorHandle(ret_handle);
|
return new TFE_TensorHandle{
|
||||||
|
std::make_unique<tensorflow::TensorHandleInterface>(ret_handle)};
|
||||||
}
|
}
|
||||||
|
|
||||||
// This function will block till the operation that produces `h` has
|
// This function will block till the operation that produces `h` has
|
||||||
@ -1087,12 +1099,14 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
|||||||
// bytes of the memory pointed to by the device pointer returned above.
|
// bytes of the memory pointed to by the device pointer returned above.
|
||||||
size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
|
size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
if (h == nullptr || h->handle == nullptr) {
|
if (h == nullptr || !h->handle->IsValid(&status->status)) {
|
||||||
status->status = tensorflow::errors::InvalidArgument(
|
status->status = tensorflow::errors::InvalidArgument(
|
||||||
"The passed in handle is a nullptr");
|
"The passed in handle is a nullptr");
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
tensorflow::TensorHandle* handle = h->handle;
|
tensorflow::TensorHandle* handle =
|
||||||
|
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
|
||||||
|
->Handle();
|
||||||
|
|
||||||
if (handle->IsRemote()) {
|
if (handle->IsRemote()) {
|
||||||
status->status = tensorflow::errors::InvalidArgument(
|
status->status = tensorflow::errors::InvalidArgument(
|
||||||
@ -1110,8 +1124,14 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
|
|||||||
|
|
||||||
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
|
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
return NewOrResetOp(ctx, op_or_function_name, nullptr, status,
|
std::unique_ptr<TFE_Op> new_op(
|
||||||
/* op_to_reset= */ nullptr);
|
new TFE_Op{tensorflow::EagerOperation(ctx->context)});
|
||||||
|
status->status =
|
||||||
|
new_op->operation.Reset(op_or_function_name, nullptr, false, nullptr);
|
||||||
|
if (!status->status.ok()) {
|
||||||
|
new_op.reset();
|
||||||
|
}
|
||||||
|
return new_op.release();
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_DeleteOp(TFE_Op* op) { delete op; }
|
void TFE_DeleteOp(TFE_Op* op) { delete op; }
|
||||||
@ -1122,7 +1142,7 @@ void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
|
|||||||
|
|
||||||
const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
|
const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
|
||||||
tensorflow::Device* device = (op->operation.Device() == nullptr)
|
tensorflow::Device* device = (op->operation.Device() == nullptr)
|
||||||
? op->operation.EagerContext()->HostCPU()
|
? op->operation.EagerContext().HostCPU()
|
||||||
: op->operation.Device();
|
: op->operation.Device();
|
||||||
return device->name().c_str();
|
return device->name().c_str();
|
||||||
}
|
}
|
||||||
@ -1136,20 +1156,23 @@ void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
|
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
|
||||||
op->operation.AddInput(input->handle);
|
tensorflow::TensorHandle* h =
|
||||||
if (op->inference_ctx) {
|
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||||
status->status = OpInferSingleInputAttrs(op, input);
|
input->handle.get())
|
||||||
}
|
->Handle();
|
||||||
|
op->operation.AddInput(h);
|
||||||
|
status->status = op->operation.MaybeInferSingleInputAttrs(h);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
|
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
for (int i = 0; i < num_inputs; ++i) {
|
for (int i = 0; i < num_inputs; ++i) {
|
||||||
op->operation.AddInput(inputs[i]->handle);
|
op->operation.AddInput(
|
||||||
}
|
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||||
if (op->inference_ctx) {
|
inputs[i]->handle.get())
|
||||||
status->status = OpInferInputListAttrs(op, inputs, num_inputs);
|
->Handle());
|
||||||
}
|
}
|
||||||
|
status->status = op->operation.InferInputListAttrs(num_inputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
|
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
|
||||||
@ -1382,15 +1405,16 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
|
|||||||
|
|
||||||
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
|
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
VLOG(1) << "Calling TFE_Execute() on op " << op;
|
|
||||||
absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
|
absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
|
||||||
|
VLOG(1) << "Calling TFE_Execute() on op " << op;
|
||||||
status->status = tensorflow::EagerExecute(&op->operation,
|
status->status = tensorflow::EagerExecute(&op->operation,
|
||||||
handle_retvals.data(), num_retvals);
|
handle_retvals.data(), num_retvals);
|
||||||
if (!status->status.ok()) {
|
if (!status->status.ok()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
for (int i = 0; i < *num_retvals; ++i) {
|
for (int i = 0; i < *num_retvals; ++i) {
|
||||||
retvals[i] = new TFE_TensorHandle(handle_retvals[i]);
|
retvals[i] = new TFE_TensorHandle{
|
||||||
|
std::make_unique<tensorflow::TensorHandleInterface>(handle_retvals[i])};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1400,15 +1424,18 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
|
|||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
tensorflow::TensorHandle* handle = nullptr;
|
tensorflow::TensorHandle* handle = nullptr;
|
||||||
tensorflow::Device* device;
|
tensorflow::Device* device;
|
||||||
status->status = ctx->context->FindDeviceFromName(device_name, &device);
|
tensorflow::EagerContext* context = ctx->context;
|
||||||
|
status->status = context->FindDeviceFromName(device_name, &device);
|
||||||
if (!status->status.ok()) {
|
if (!status->status.ok()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context,
|
status->status = tensorflow::EagerCopyToDevice(
|
||||||
&ctx->context->Executor(),
|
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
|
||||||
device, false, &handle);
|
->Handle(),
|
||||||
|
context, &context->Executor(), device, false, &handle);
|
||||||
if (status->status.ok()) {
|
if (status->status.ok()) {
|
||||||
return new TFE_TensorHandle(handle);
|
return new TFE_TensorHandle{
|
||||||
|
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
@ -1456,11 +1483,12 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
|
|||||||
|
|
||||||
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
|
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
status->status = ctx->context->Executor().WaitForAllPendingNodes();
|
tensorflow::EagerContext* context = ctx->context;
|
||||||
|
status->status = context->Executor().WaitForAllPendingNodes();
|
||||||
if (!status->status.ok()) return;
|
if (!status->status.ok()) return;
|
||||||
tensorflow::mutex_lock ml(*ctx->context->MetadataMu());
|
tensorflow::mutex_lock ml(*context->MetadataMu());
|
||||||
status->status = MessageToBuffer(*ctx->context->RunMetadataProto(), buf);
|
status->status = MessageToBuffer(*context->RunMetadataProto(), buf);
|
||||||
ctx->context->ClearRunMetadata();
|
context->ClearRunMetadata();
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -206,14 +206,14 @@ typedef struct TFE_TensorDebugInfo TFE_TensorDebugInfo;
|
|||||||
// error and nullptr is returned. This function can block till the operation
|
// error and nullptr is returned. This function can block till the operation
|
||||||
// that produces `handle` has completed.
|
// that produces `handle` has completed.
|
||||||
TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
|
TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
|
||||||
TFE_TensorHandle* handle, TF_Status* status);
|
TFE_TensorHandle* h, TF_Status* status);
|
||||||
|
|
||||||
// Deletes `debug_info`.
|
// Deletes `debug_info`.
|
||||||
TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo(
|
TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo(
|
||||||
TFE_TensorDebugInfo* debug_info);
|
TFE_TensorDebugInfo* debug_info);
|
||||||
|
|
||||||
// Returns the number of dimensions used to represent the tensor on its device.
|
// Returns the number of dimensions used to represent the tensor on its device.
|
||||||
// The number of dimensions used to reprensent the tensor on device can be
|
// The number of dimensions used to represent the tensor on device can be
|
||||||
// different from the number returned by TFE_TensorHandleNumDims.
|
// different from the number returned by TFE_TensorHandleNumDims.
|
||||||
// The return value was current at the time of TFE_TensorDebugInfo creation.
|
// The return value was current at the time of TFE_TensorDebugInfo creation.
|
||||||
TF_CAPI_EXPORT extern int TFE_TensorDebugInfoOnDeviceNumDims(
|
TF_CAPI_EXPORT extern int TFE_TensorDebugInfoOnDeviceNumDims(
|
||||||
|
@ -28,19 +28,22 @@ using tensorflow::string;
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
std::vector<int64> TensorShapeAsVector(TFE_TensorHandle* handle,
|
std::vector<int64> TensorShapeAsVector(const tensorflow::TensorHandle& handle,
|
||||||
TF_Status* status) {
|
tensorflow::Status* status) {
|
||||||
std::vector<int64> shape;
|
std::vector<int64> shape;
|
||||||
int rank = TFE_TensorHandleNumDims(handle, status);
|
int rank = -1;
|
||||||
if (TF_GetCode(status) != TF_OK) {
|
*status = handle.NumDims(&rank);
|
||||||
|
if (!status->ok()) {
|
||||||
return shape;
|
return shape;
|
||||||
}
|
}
|
||||||
shape.reserve(rank);
|
shape.reserve(rank);
|
||||||
for (int i = 0; i < rank; ++i) {
|
for (int i = 0; i < rank; ++i) {
|
||||||
shape.push_back(TFE_TensorHandleDim(handle, i, status));
|
tensorflow::int64 dim;
|
||||||
if (TF_GetCode(status) != TF_OK) {
|
*status = handle.Dim(i, &dim);
|
||||||
|
if (!status->ok()) {
|
||||||
return shape;
|
return shape;
|
||||||
}
|
}
|
||||||
|
shape.push_back(dim);
|
||||||
}
|
}
|
||||||
return shape;
|
return shape;
|
||||||
}
|
}
|
||||||
@ -50,15 +53,20 @@ std::vector<int64> TensorShapeAsVector(TFE_TensorHandle* handle,
|
|||||||
extern "C" {
|
extern "C" {
|
||||||
|
|
||||||
TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
|
TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
|
||||||
TFE_TensorHandle* handle, TF_Status* status) {
|
TFE_TensorHandle* h, TF_Status* status) {
|
||||||
|
return h->handle->TensorDebugInfo(&status->status);
|
||||||
|
}
|
||||||
|
|
||||||
|
TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
|
||||||
|
Status* status) {
|
||||||
const tensorflow::Tensor* tensor;
|
const tensorflow::Tensor* tensor;
|
||||||
status->status = handle->handle->Tensor(&tensor);
|
*status = handle_->Tensor(&tensor);
|
||||||
if (TF_GetCode(status) != TF_OK) {
|
if (!status->ok()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef TENSORFLOW_EAGER_USE_XLA
|
#ifdef TENSORFLOW_EAGER_USE_XLA
|
||||||
tensorflow::Device* device = handle->handle->device();
|
tensorflow::Device* device = handle_->device();
|
||||||
|
|
||||||
// If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
|
// If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
|
||||||
tensorflow::XlaDevice* xla_device =
|
tensorflow::XlaDevice* xla_device =
|
||||||
@ -67,15 +75,15 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
|
|||||||
tensorflow::XlaDevice::PaddedShapeFn shape_fn =
|
tensorflow::XlaDevice::PaddedShapeFn shape_fn =
|
||||||
xla_device->metadata().padded_shape_fn();
|
xla_device->metadata().padded_shape_fn();
|
||||||
xla::Shape padded_shape;
|
xla::Shape padded_shape;
|
||||||
status->status = shape_fn(*tensor, &padded_shape);
|
*status = shape_fn(*tensor, &padded_shape);
|
||||||
if (!status->status.ok()) {
|
if (!status->ok()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
if (VLOG_IS_ON(3)) {
|
if (VLOG_IS_ON(3)) {
|
||||||
std::vector<int64> shape_to_log = TensorShapeAsVector(handle, status);
|
std::vector<int64> shape_to_log = TensorShapeAsVector(*handle_, status);
|
||||||
if (!status->status.ok()) {
|
if (!status->ok()) {
|
||||||
// Ignore the status here as we are simply logging.
|
// Ignore the status here as we are simply logging.
|
||||||
status->status = tensorflow::Status::OK();
|
*status = tensorflow::Status::OK();
|
||||||
} else {
|
} else {
|
||||||
VLOG(3) << "Fully padded shape of ["
|
VLOG(3) << "Fully padded shape of ["
|
||||||
<< absl::StrJoin(shape_to_log, ", ") << "] is "
|
<< absl::StrJoin(shape_to_log, ", ") << "] is "
|
||||||
@ -88,7 +96,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
|
|||||||
// Currently, the only case of XlaTensor containing a tuple shape is to
|
// Currently, the only case of XlaTensor containing a tuple shape is to
|
||||||
// represent 64 bit ints, doubles, and complex numbers (we don't support
|
// represent 64 bit ints, doubles, and complex numbers (we don't support
|
||||||
// 64bit complex numbers).
|
// 64bit complex numbers).
|
||||||
status->status = tensorflow::errors::InvalidArgument(
|
*status = tensorflow::errors::InvalidArgument(
|
||||||
"XlaTensors should only contain tuples of size 2. Shape: ",
|
"XlaTensors should only contain tuples of size 2. Shape: ",
|
||||||
padded_shape.DebugString());
|
padded_shape.DebugString());
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@ -100,13 +108,13 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
|
|||||||
const xla::Shape& shape1 =
|
const xla::Shape& shape1 =
|
||||||
xla::ShapeUtil::GetTupleElementShape(padded_shape, 1);
|
xla::ShapeUtil::GetTupleElementShape(padded_shape, 1);
|
||||||
if (shape0.IsTuple() || shape1.IsTuple()) {
|
if (shape0.IsTuple() || shape1.IsTuple()) {
|
||||||
status->status = tensorflow::errors::InvalidArgument(
|
*status = tensorflow::errors::InvalidArgument(
|
||||||
"XlaTensors should not contain nested tuples. Shape: ",
|
"XlaTensors should not contain nested tuples. Shape: ",
|
||||||
padded_shape.DebugString());
|
padded_shape.DebugString());
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
if (!xla::ShapeUtil::Equal(shape0, shape1)) {
|
if (!xla::ShapeUtil::Equal(shape0, shape1)) {
|
||||||
status->status = tensorflow::errors::InvalidArgument(
|
*status = tensorflow::errors::InvalidArgument(
|
||||||
"Subshapes of XlaTensors should be the same. Shape: ",
|
"Subshapes of XlaTensors should be the same. Shape: ",
|
||||||
padded_shape.DebugString());
|
padded_shape.DebugString());
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@ -131,15 +139,15 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
|
|||||||
dev_dims.push_back(padded_shape.dimensions(dim_index));
|
dev_dims.push_back(padded_shape.dimensions(dim_index));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
status->status = tensorflow::Status::OK();
|
*status = tensorflow::Status::OK();
|
||||||
return new TFE_TensorDebugInfo(dev_dims);
|
return new TFE_TensorDebugInfo(dev_dims);
|
||||||
}
|
}
|
||||||
#endif // TENSORFLOW_EAGER_USE_XLA
|
#endif // TENSORFLOW_EAGER_USE_XLA
|
||||||
|
|
||||||
// If the tensor is not an XLA tensor, the device shape is
|
// If the tensor is not an XLA tensor, the device shape is
|
||||||
// the same as regular tensor shape.
|
// the same as regular tensor shape.
|
||||||
std::vector<int64> dev_dims = TensorShapeAsVector(handle, status);
|
std::vector<int64> dev_dims = TensorShapeAsVector(*handle_, status);
|
||||||
if (TF_GetCode(status) != TF_OK) {
|
if (!status->ok()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
return new TFE_TensorDebugInfo(dev_dims);
|
return new TFE_TensorDebugInfo(dev_dims);
|
||||||
|
@ -22,18 +22,18 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/monitoring/gauge.h"
|
#include "tensorflow/core/lib/monitoring/gauge.h"
|
||||||
#include "tensorflow/core/lib/monitoring/sampler.h"
|
#include "tensorflow/core/lib/monitoring/sampler.h"
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
|
#include "tensorflow/core/platform/casts.h"
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
#include "tensorflow/core/profiler/rpc/client/capture_profile.h"
|
#include "tensorflow/core/profiler/rpc/client/capture_profile.h"
|
||||||
#include "tensorflow/core/profiler/rpc/profiler_server.h"
|
#include "tensorflow/core/profiler/rpc/profiler_server.h"
|
||||||
|
|
||||||
using tensorflow::string;
|
using tensorflow::string;
|
||||||
|
|
||||||
void TFE_OpReset(TFE_Context* ctx, const char* op_or_function_name,
|
void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
|
||||||
const char* raw_device_name, TF_Status* status,
|
const char* raw_device_name, TF_Status* status) {
|
||||||
TFE_Op* op_to_reset) {
|
|
||||||
if (op_to_reset) {
|
if (op_to_reset) {
|
||||||
NewOrResetOp(ctx, op_or_function_name, raw_device_name, status,
|
status->status = op_to_reset->operation.Reset(
|
||||||
op_to_reset);
|
op_or_function_name, raw_device_name, false, nullptr);
|
||||||
} else {
|
} else {
|
||||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||||
"op_to_reset should not be nullptr");
|
"op_to_reset should not be nullptr");
|
||||||
@ -41,7 +41,9 @@ void TFE_OpReset(TFE_Context* ctx, const char* op_or_function_name,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
|
void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
|
||||||
op->operation.ConsumeInput(h->handle);
|
op->operation.ConsumeInput(
|
||||||
|
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
|
||||||
|
->Handle());
|
||||||
}
|
}
|
||||||
|
|
||||||
TFE_Profiler* TFE_NewProfiler() { return new TFE_Profiler(); }
|
TFE_Profiler* TFE_NewProfiler() { return new TFE_Profiler(); }
|
||||||
|
@ -29,10 +29,10 @@ extern "C" {
|
|||||||
// and set the device name. It's effectively `TFE_OpSetDevice`, but it is faster
|
// and set the device name. It's effectively `TFE_OpSetDevice`, but it is faster
|
||||||
// than seperately calling it because if the existing op has the same
|
// than seperately calling it because if the existing op has the same
|
||||||
// `raw_device_name`, it skips parsing and just leave as it is.
|
// `raw_device_name`, it skips parsing and just leave as it is.
|
||||||
TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Context* ctx,
|
TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Op* op_to_reset,
|
||||||
const char* op_or_function_name,
|
const char* op_or_function_name,
|
||||||
const char* raw_device_name,
|
const char* raw_device_name,
|
||||||
TF_Status* status, TFE_Op* op_to_reset);
|
TF_Status* status);
|
||||||
|
|
||||||
TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,
|
TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,
|
||||||
TF_Status* status);
|
TF_Status* status);
|
||||||
|
@ -1,66 +0,0 @@
|
|||||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
#include "tensorflow/c/eager/c_api_internal.h"
|
|
||||||
|
|
||||||
#include "tensorflow/core/platform/errors.h"
|
|
||||||
#include "tensorflow/core/platform/host_info.h"
|
|
||||||
|
|
||||||
TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
|
|
||||||
const char* raw_device_name, TF_Status* status,
|
|
||||||
TFE_Op* op_to_reset) {
|
|
||||||
const char* name = op_or_function_name; // Shorthand
|
|
||||||
const tensorflow::AttrTypeMap* types;
|
|
||||||
bool is_function = false;
|
|
||||||
status->status = tensorflow::AttrTypeMapForOp(name, &types, &is_function);
|
|
||||||
if (!status->status.ok()) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (op_to_reset && op_to_reset->ctx != ctx) {
|
|
||||||
status->status = tensorflow::errors::Internal(
|
|
||||||
"Cannot reset a TFE_Op from another TFE_Context");
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
|
|
||||||
if (!is_function) {
|
|
||||||
const tensorflow::OpDef* op_def;
|
|
||||||
status->status = tensorflow::OpDefForOp(op_or_function_name, &op_def);
|
|
||||||
if (!status->status.ok()) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
inference_ctx.reset(new TFE_OpInferenceContext(op_def));
|
|
||||||
} else if (!ctx->context->FindFunctionByName(name)) {
|
|
||||||
status->status = tensorflow::errors::NotFound(
|
|
||||||
"'", name,
|
|
||||||
"' is neither a type of a primitive operation nor a name "
|
|
||||||
"of a function registered in binary running on ",
|
|
||||||
tensorflow::port::Hostname(),
|
|
||||||
". Make sure the operation or function is "
|
|
||||||
"registered in the binary running in this process.");
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (op_to_reset) {
|
|
||||||
status->status = op_to_reset->Reset(
|
|
||||||
name, is_function, types, raw_device_name, std::move(inference_ctx));
|
|
||||||
return op_to_reset;
|
|
||||||
}
|
|
||||||
|
|
||||||
TFE_Op* new_op =
|
|
||||||
new TFE_Op(ctx, name, is_function, types, std::move(inference_ctx));
|
|
||||||
status->status = new_op->operation.SetDeviceName(raw_device_name);
|
|
||||||
return new_op;
|
|
||||||
}
|
|
@ -27,6 +27,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/c_api_internal.h"
|
#include "tensorflow/c/c_api_internal.h"
|
||||||
#include "tensorflow/c/eager/c_api.h"
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||||
|
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||||
@ -62,36 +63,10 @@ struct TFE_ContextOptions {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct TFE_Context {
|
struct TFE_Context {
|
||||||
TFE_Context(const tensorflow::SessionOptions& opts,
|
|
||||||
TFE_ContextDevicePlacementPolicy default_device_placement_policy,
|
|
||||||
TFE_ContextMirroringPolicy default_mirroring_policy, bool async,
|
|
||||||
const bool lazy_remote_inputs_copy,
|
|
||||||
const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned,
|
|
||||||
tensorflow::Rendezvous* rendezvous,
|
|
||||||
const tensorflow::CustomKernelCreator* custom_kernel_creator)
|
|
||||||
: context(new tensorflow::EagerContext(
|
|
||||||
opts,
|
|
||||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
|
||||||
default_device_placement_policy),
|
|
||||||
static_cast<tensorflow::ContextMirroringPolicy>(
|
|
||||||
default_mirroring_policy),
|
|
||||||
async, lazy_remote_inputs_copy, device_mgr, device_mgr_owned,
|
|
||||||
rendezvous, custom_kernel_creator)) {}
|
|
||||||
|
|
||||||
~TFE_Context() {
|
|
||||||
// TODO(iga): Add a separate API method to shutdown TFE_Context so that we
|
|
||||||
// don't send RPCs and block in destructor.
|
|
||||||
context->WaitForAndCloseRemoteContexts();
|
|
||||||
// context->RefCountIsOne() should be true here.
|
|
||||||
// TODO(iga): Remove EagerContext refcounting.
|
|
||||||
context->Unref();
|
|
||||||
}
|
|
||||||
|
|
||||||
tensorflow::EagerContext* context;
|
tensorflow::EagerContext* context;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct TFE_TensorHandle {
|
struct TFE_TensorHandle {
|
||||||
explicit TFE_TensorHandle(tensorflow::TensorHandle* h) : handle(h) {}
|
|
||||||
static TFE_TensorHandle* CreateLocalHandle(const class tensorflow::Tensor& t,
|
static TFE_TensorHandle* CreateLocalHandle(const class tensorflow::Tensor& t,
|
||||||
TF_Status* s) {
|
TF_Status* s) {
|
||||||
tensorflow::TensorHandle* handle;
|
tensorflow::TensorHandle* handle;
|
||||||
@ -99,10 +74,11 @@ struct TFE_TensorHandle {
|
|||||||
if (!s->status.ok()) {
|
if (!s->status.ok()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
return new TFE_TensorHandle(handle);
|
return new TFE_TensorHandle{
|
||||||
|
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
|
||||||
}
|
}
|
||||||
|
|
||||||
tensorflow::TensorHandle* handle;
|
std::unique_ptr<AbstractTensorHandleInterface> handle;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct TFE_TensorDebugInfo {
|
struct TFE_TensorDebugInfo {
|
||||||
@ -113,46 +89,10 @@ struct TFE_TensorDebugInfo {
|
|||||||
std::vector<tensorflow::int64> dev_dims;
|
std::vector<tensorflow::int64> dev_dims;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct TFE_OpInferenceContext {
|
|
||||||
explicit TFE_OpInferenceContext(const tensorflow::OpDef* op_def)
|
|
||||||
: op_def(op_def) {}
|
|
||||||
|
|
||||||
const tensorflow::OpDef* op_def; // op definition from protobuf
|
|
||||||
int input_arg_idx = 0; // arg definition index for the next input to be added
|
|
||||||
tensorflow::gtl::FlatSet<std::string> attrs; // attributes inferred so far
|
|
||||||
};
|
|
||||||
|
|
||||||
struct TFE_Op {
|
struct TFE_Op {
|
||||||
TFE_Op(TFE_Context* ctx, const char* op, bool is_function,
|
|
||||||
const tensorflow::AttrTypeMap* t,
|
|
||||||
std::unique_ptr<TFE_OpInferenceContext> inference_ctx)
|
|
||||||
: ctx(ctx),
|
|
||||||
operation(ctx->context, op, is_function, t),
|
|
||||||
inference_ctx(std::move(inference_ctx)) {}
|
|
||||||
|
|
||||||
void Clear() {
|
|
||||||
operation.Clear();
|
|
||||||
inference_ctx.reset();
|
|
||||||
}
|
|
||||||
|
|
||||||
tensorflow::Status Reset(const char* op, bool is_function,
|
|
||||||
const tensorflow::AttrTypeMap* t,
|
|
||||||
const char* raw_device_name,
|
|
||||||
std::unique_ptr<TFE_OpInferenceContext> infer_ctx) {
|
|
||||||
inference_ctx = std::move(infer_ctx);
|
|
||||||
return operation.Reset(ctx->context, op, is_function, t, raw_device_name,
|
|
||||||
nullptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
TFE_Context* ctx;
|
|
||||||
tensorflow::EagerOperation operation;
|
tensorflow::EagerOperation operation;
|
||||||
std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
|
|
||||||
const char* raw_device_name, TF_Status* status,
|
|
||||||
TFE_Op* op_to_reset = nullptr);
|
|
||||||
|
|
||||||
struct TFE_Profiler {
|
struct TFE_Profiler {
|
||||||
explicit TFE_Profiler() { profiler = tensorflow::ProfilerSession::Create(); }
|
explicit TFE_Profiler() { profiler = tensorflow::ProfilerSession::Create(); }
|
||||||
|
|
||||||
|
@ -1362,10 +1362,11 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) {
|
|||||||
TFE_TensorHandle* inputs[] = {input1, input2};
|
TFE_TensorHandle* inputs[] = {input1, input2};
|
||||||
TFE_OpAddInput(concatOp, dim, status);
|
TFE_OpAddInput(concatOp, dim, status);
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
CHECK(concatOp->inference_ctx);
|
CHECK(concatOp->operation.OpDef());
|
||||||
TFE_OpAddInput(concatOp, inputs[0], status);
|
TFE_OpAddInput(concatOp, inputs[0], status);
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
EXPECT_FALSE(concatOp->inference_ctx) << "Inference context is still present";
|
EXPECT_FALSE(concatOp->operation.OpDef())
|
||||||
|
<< "Inference context is still present";
|
||||||
TFE_OpAddInput(concatOp, inputs[1], status);
|
TFE_OpAddInput(concatOp, inputs[1], status);
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
|
||||||
|
@ -284,7 +284,7 @@ class ForwardAccumulator {
|
|||||||
// Temporarily push or pop transient state for this accumulator.
|
// Temporarily push or pop transient state for this accumulator.
|
||||||
//
|
//
|
||||||
// Allows an accumulator which is currently processing an operation to
|
// Allows an accumulator which is currently processing an operation to
|
||||||
// temporarily reset its state. Without pushing and poping, accumulators
|
// temporarily reset its state. Without pushing and popping, accumulators
|
||||||
// ignore operations executed as a direct result of their own jvp
|
// ignore operations executed as a direct result of their own jvp
|
||||||
// computations.
|
// computations.
|
||||||
void PushState() { call_state_.emplace(nullptr, false); }
|
void PushState() { call_state_.emplace(nullptr, false); }
|
||||||
|
90
tensorflow/c/eager/tensor_handle_interface.h
Normal file
90
tensorflow/c/eager/tensor_handle_interface.h
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
|
||||||
|
#define TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
|
||||||
|
|
||||||
|
#include "tensorflow/c/c_api.h"
|
||||||
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
|
#include "tensorflow/c/tf_datatype.h"
|
||||||
|
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||||
|
|
||||||
|
// Abstract interface to a TensorHandle.
|
||||||
|
//
|
||||||
|
// A TensorHandle is management class around a Tensor which may track additional
|
||||||
|
// metadata and synchronization.
|
||||||
|
//
|
||||||
|
// This allows us to hide concrete implementations of TensorHandle from header
|
||||||
|
// files. The interface lists the common functionality that must be provided by
|
||||||
|
// any concrete implementation. However, in cases where the true concrete class
|
||||||
|
// is needed a static_cast can be applied.
|
||||||
|
class AbstractTensorHandleInterface {
|
||||||
|
public:
|
||||||
|
virtual ~AbstractTensorHandleInterface() {}
|
||||||
|
|
||||||
|
// Check if the handle is in a valid initialized state.
|
||||||
|
virtual bool IsValid(tensorflow::Status* status) const = 0;
|
||||||
|
// Returns tensor dtype.
|
||||||
|
virtual TF_DataType DataType() const = 0;
|
||||||
|
// Returns number of dimensions.
|
||||||
|
virtual int NumDims(tensorflow::Status* status) const = 0;
|
||||||
|
// Returns number of elements across all dimensions.
|
||||||
|
virtual int64_t NumElements(tensorflow::Status* status) const = 0;
|
||||||
|
// Returns size of specified dimension
|
||||||
|
virtual int64_t Dim(int dim_index, tensorflow::Status* status) const = 0;
|
||||||
|
|
||||||
|
// Returns the device which created the handle.
|
||||||
|
virtual const char* DeviceName(tensorflow::Status* status) const = 0;
|
||||||
|
// Returns the device where the tensor was placed.
|
||||||
|
virtual const char* BackingDeviceName(tensorflow::Status* status) const = 0;
|
||||||
|
// Returns a tensor for the handle. If tensor is remote, it will be copied.
|
||||||
|
virtual TF_Tensor* Resolve(tensorflow::Status* status) = 0;
|
||||||
|
// Returns debug information about the tensor.
|
||||||
|
virtual TFE_TensorDebugInfo* TensorDebugInfo(tensorflow::Status* status) = 0;
|
||||||
|
|
||||||
|
// Return a copy of the handle.
|
||||||
|
virtual AbstractTensorHandleInterface* Copy() = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
class TensorHandleInterface : public AbstractTensorHandleInterface {
|
||||||
|
public:
|
||||||
|
explicit TensorHandleInterface(TensorHandle* h) : handle_(h) {}
|
||||||
|
~TensorHandleInterface() override;
|
||||||
|
|
||||||
|
bool IsValid(Status* status) const override;
|
||||||
|
TF_DataType DataType() const override;
|
||||||
|
int NumDims(Status* status) const override;
|
||||||
|
int64_t NumElements(Status* status) const override;
|
||||||
|
int64_t Dim(int dim_index, Status* status) const override;
|
||||||
|
|
||||||
|
const char* DeviceName(Status* status) const override;
|
||||||
|
const char* BackingDeviceName(Status* status) const override;
|
||||||
|
TF_Tensor* Resolve(Status* status) override;
|
||||||
|
TFE_TensorDebugInfo* TensorDebugInfo(Status* status) override;
|
||||||
|
|
||||||
|
AbstractTensorHandleInterface* Copy() override;
|
||||||
|
|
||||||
|
// TODO(gjn): This is not a very generic interface, but is needed for specific
|
||||||
|
// use cases.
|
||||||
|
TensorHandle* Handle() { return handle_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
TensorHandle* handle_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
|
@ -18,37 +18,23 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Core TensorFlow depends on this, this will be included in main library
|
|
||||||
cc_library(
|
|
||||||
name = "filesystem_interface_impl",
|
|
||||||
srcs = ["filesystem_interface.cc"],
|
|
||||||
hdrs = ["filesystem_interface.h"],
|
|
||||||
deps = [
|
|
||||||
":modular_filesystem",
|
|
||||||
"//tensorflow/c:tf_file_statistics",
|
|
||||||
"//tensorflow/c:tf_status",
|
|
||||||
"//tensorflow/c:tf_status_internal",
|
|
||||||
"//tensorflow/core:ptr_util",
|
|
||||||
"//tensorflow/core/platform:env",
|
|
||||||
"//tensorflow/core/platform:logging",
|
|
||||||
"//tensorflow/core/platform:strcat",
|
|
||||||
"//tensorflow/core/platform:stringpiece",
|
|
||||||
],
|
|
||||||
alwayslink = 1,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Core TensorFlow depends on this, will be included in main library
|
# Core TensorFlow depends on this, will be included in main library
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "modular_filesystem",
|
name = "modular_filesystem",
|
||||||
srcs = ["modular_filesystem.cc"],
|
srcs = [
|
||||||
|
"modular_filesystem.cc",
|
||||||
|
"modular_filesystem_registration.cc",
|
||||||
|
"modular_filesystem_registration.h",
|
||||||
|
],
|
||||||
hdrs = ["modular_filesystem.h"],
|
hdrs = ["modular_filesystem.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":filesystem_interface",
|
":filesystem_interface",
|
||||||
"//tensorflow/c:tf_status_helper",
|
"//tensorflow/c:tf_status_helper",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/c:tf_status_internal",
|
||||||
"//tensorflow/core:ptr_util",
|
"//tensorflow/core:ptr_util",
|
||||||
"//tensorflow/core/platform:env",
|
"//tensorflow/core/platform:env",
|
||||||
"//tensorflow/core/platform:strcat",
|
"//tensorflow/core/platform:errors",
|
||||||
|
"//tensorflow/core/platform:status",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -63,16 +49,12 @@ tf_cc_test(
|
|||||||
"notap", # b/139060984, requires implementing modular support for Google filesystem
|
"notap", # b/139060984, requires implementing modular support for Google filesystem
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":filesystem_interface_impl",
|
":modular_filesystem",
|
||||||
"//tensorflow/c:tf_status",
|
|
||||||
"//tensorflow/c:tf_status_internal",
|
|
||||||
"//tensorflow/core:framework_internal",
|
"//tensorflow/core:framework_internal",
|
||||||
"//tensorflow/core/lib/io:path",
|
"//tensorflow/core/lib/io:path",
|
||||||
"//tensorflow/core/platform:env",
|
"//tensorflow/core/platform:env",
|
||||||
"//tensorflow/core/platform:error",
|
"//tensorflow/core/platform:error",
|
||||||
"//tensorflow/core/platform:stacktrace_handler",
|
"//tensorflow/core/platform:stacktrace_handler",
|
||||||
"//tensorflow/core/platform:str_util",
|
|
||||||
"//tensorflow/core/platform:strcat",
|
|
||||||
"//tensorflow/core/platform:test",
|
"//tensorflow/core/platform:test",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -1,366 +0,0 @@
|
|||||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
|
||||||
|
|
||||||
#include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
|
|
||||||
#include "tensorflow/c/tf_status_internal.h"
|
|
||||||
#include "tensorflow/core/platform/env.h"
|
|
||||||
#include "tensorflow/core/platform/logging.h"
|
|
||||||
#include "tensorflow/core/platform/strcat.h"
|
|
||||||
#include "tensorflow/core/platform/stringpiece.h"
|
|
||||||
#include "tensorflow/core/util/ptr_util.h"
|
|
||||||
|
|
||||||
/// This translation unit is linked in core TensorFlow and provides the
|
|
||||||
/// functionality needed for plugin registration to check ABI/API compatibility,
|
|
||||||
/// to ensure required methods are present, to ensure plugins are not allowed to
|
|
||||||
/// change functionality after being loaded and to register the filesystems
|
|
||||||
/// provided by a plugin. Consult the header file for more information about
|
|
||||||
/// how this is achieved.
|
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
// Checks if the plugin and core ABI numbers match, filling in `status`.
|
|
||||||
//
|
|
||||||
// If the numbers don't match, plugin cannot be loaded.
|
|
||||||
static bool CheckABIHelper(int pluginABI, int coreABI, StringPiece where,
|
|
||||||
TF_Status* status) {
|
|
||||||
if (pluginABI != coreABI) {
|
|
||||||
TF_SetStatus(
|
|
||||||
status, TF_FAILED_PRECONDITION,
|
|
||||||
strings::StrCat("Plugin ABI (", pluginABI, ") for ", where,
|
|
||||||
" operations doesn't match expected core ABI (",
|
|
||||||
coreABI, "). Plugin cannot be loaded.")
|
|
||||||
.c_str());
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Checks if the plugin and core ABI numbers match, for all operations.
|
|
||||||
//
|
|
||||||
// If the numbers don't match, plugin cannot be loaded.
|
|
||||||
//
|
|
||||||
// Uses the simpler `CheckABIHelper(int, int, StringPiece, TF_Status*)`
|
|
||||||
static bool CheckABI(
|
|
||||||
int plugin_filesystem_ops_ABI,
|
|
||||||
const TF_RandomAccessFileOps* plugin_random_access_file_ops,
|
|
||||||
int plugin_random_access_file_ops_ABI,
|
|
||||||
const TF_WritableFileOps* plugin_writable_file_ops,
|
|
||||||
int plugin_writable_file_ops_ABI,
|
|
||||||
const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
|
|
||||||
int plugin_read_only_memory_region_ops_ABI, TF_Status* status) {
|
|
||||||
if (!CheckABIHelper(plugin_filesystem_ops_ABI, TF_FILESYSTEM_OPS_ABI,
|
|
||||||
"filesystem", status))
|
|
||||||
return false;
|
|
||||||
|
|
||||||
if (plugin_random_access_file_ops != nullptr &&
|
|
||||||
!CheckABIHelper(plugin_random_access_file_ops_ABI,
|
|
||||||
TF_RANDOM_ACCESS_FILE_OPS_ABI, "random access file",
|
|
||||||
status))
|
|
||||||
return false;
|
|
||||||
|
|
||||||
if (plugin_writable_file_ops != nullptr &&
|
|
||||||
!CheckABIHelper(plugin_writable_file_ops_ABI, TF_WRITABLE_FILE_OPS_ABI,
|
|
||||||
"writable file", status))
|
|
||||||
return false;
|
|
||||||
|
|
||||||
if (plugin_read_only_memory_region_ops != nullptr &&
|
|
||||||
!CheckABIHelper(plugin_read_only_memory_region_ops_ABI,
|
|
||||||
TF_READ_ONLY_MEMORY_REGION_OPS_ABI,
|
|
||||||
"read only memory region", status))
|
|
||||||
return false;
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Checks if the plugin and core API numbers match, logging mismatches.
|
|
||||||
static void CheckAPIHelper(int plugin_API, int core_API, StringPiece where) {
|
|
||||||
if (plugin_API != core_API) {
|
|
||||||
VLOG(0) << "Plugin API (" << plugin_API << ") for " << where
|
|
||||||
<< " operations doesn't match expected core API (" << core_API
|
|
||||||
<< "). Plugin will be loaded but functionality might be missing.";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Checks if the plugin and core API numbers match, for all operations.
|
|
||||||
//
|
|
||||||
// Uses the simpler `CheckAPIHelper(int, int, StringPiece)`.
|
|
||||||
static void CheckAPI(
|
|
||||||
int plugin_filesystem_ops_API,
|
|
||||||
const TF_RandomAccessFileOps* plugin_random_access_file_ops,
|
|
||||||
int plugin_random_access_file_ops_API,
|
|
||||||
const TF_WritableFileOps* plugin_writable_file_ops,
|
|
||||||
int plugin_writable_file_ops_API,
|
|
||||||
const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
|
|
||||||
int plugin_read_only_memory_region_ops_API) {
|
|
||||||
CheckAPIHelper(plugin_filesystem_ops_API, TF_FILESYSTEM_OPS_API,
|
|
||||||
"filesystem");
|
|
||||||
|
|
||||||
if (plugin_random_access_file_ops != nullptr)
|
|
||||||
CheckAPIHelper(plugin_random_access_file_ops_API,
|
|
||||||
TF_RANDOM_ACCESS_FILE_OPS_API, "random access file");
|
|
||||||
|
|
||||||
if (plugin_writable_file_ops != nullptr)
|
|
||||||
CheckAPIHelper(plugin_writable_file_ops_API, TF_WRITABLE_FILE_OPS_API,
|
|
||||||
"writable file");
|
|
||||||
|
|
||||||
if (plugin_read_only_memory_region_ops != nullptr)
|
|
||||||
CheckAPIHelper(plugin_read_only_memory_region_ops_API,
|
|
||||||
TF_READ_ONLY_MEMORY_REGION_OPS_API,
|
|
||||||
"read only memory region");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validates the filesystem operations supplied by the plugin.
|
|
||||||
static bool ValidateHelper(const TF_FilesystemOps* ops, TF_Status* status) {
|
|
||||||
if (ops == nullptr) {
|
|
||||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
|
||||||
"Trying to register filesystem without operations");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ops->init == nullptr) {
|
|
||||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
|
||||||
"Trying to register filesystem without `init` operation");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ops->cleanup == nullptr) {
|
|
||||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
|
||||||
"Trying to register filesystem without `cleanup` operation");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validates the random access file operations supplied by the plugin.
|
|
||||||
static bool ValidateHelper(const TF_RandomAccessFileOps* ops,
|
|
||||||
TF_Status* status) {
|
|
||||||
if (ops == nullptr) {
|
|
||||||
// We allow filesystems where files can only be written to (from TF code)
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ops->cleanup == nullptr) {
|
|
||||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
|
||||||
"Trying to register filesystem without `cleanup` operation on "
|
|
||||||
"random access files");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validates the writable file operations supplied by the plugin.
|
|
||||||
static bool ValidateHelper(const TF_WritableFileOps* ops, TF_Status* status) {
|
|
||||||
if (ops == nullptr) {
|
|
||||||
// We allow read-only filesystems
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ops->cleanup == nullptr) {
|
|
||||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
|
||||||
"Trying to register filesystem without `cleanup` operation on "
|
|
||||||
"writable files");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validates the read only memory region operations given by the plugin.
|
|
||||||
static bool ValidateHelper(const TF_ReadOnlyMemoryRegionOps* ops,
|
|
||||||
TF_Status* status) {
|
|
||||||
if (ops == nullptr) {
|
|
||||||
// read only memory region support is always optional
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ops->cleanup == nullptr) {
|
|
||||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
|
||||||
"Trying to register filesystem without `cleanup` operation on "
|
|
||||||
"read only memory regions");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ops->data == nullptr) {
|
|
||||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
|
||||||
"Trying to register filesystem without `data` operation on "
|
|
||||||
"read only memory regions");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ops->length == nullptr) {
|
|
||||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
|
||||||
"Trying to register filesystem without `length` operation on "
|
|
||||||
"read only memory regions");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validates the operations supplied by the plugin.
|
|
||||||
//
|
|
||||||
// Uses the 4 simpler `ValidateHelper(const TF_..., TF_Status*)` to validate
|
|
||||||
// each individual function table and then checks that the function table for a
|
|
||||||
// specific file type exists if the plugin offers support for creating that
|
|
||||||
// type of files.
|
|
||||||
static bool Validate(
|
|
||||||
const TF_FilesystemOps* plugin_filesystem_ops,
|
|
||||||
const TF_RandomAccessFileOps* plugin_random_access_file_ops,
|
|
||||||
const TF_WritableFileOps* plugin_writable_file_ops,
|
|
||||||
const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
|
|
||||||
TF_Status* status) {
|
|
||||||
if (!ValidateHelper(plugin_filesystem_ops, status)) return false;
|
|
||||||
if (!ValidateHelper(plugin_random_access_file_ops, status)) return false;
|
|
||||||
if (!ValidateHelper(plugin_writable_file_ops, status)) return false;
|
|
||||||
if (!ValidateHelper(plugin_read_only_memory_region_ops, status)) return false;
|
|
||||||
|
|
||||||
if (plugin_filesystem_ops->new_random_access_file != nullptr &&
|
|
||||||
plugin_random_access_file_ops == nullptr) {
|
|
||||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
|
||||||
"Filesystem allows creation of random access files but no "
|
|
||||||
"operations on them have been supplied.");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if ((plugin_filesystem_ops->new_writable_file != nullptr ||
|
|
||||||
plugin_filesystem_ops->new_appendable_file != nullptr) &&
|
|
||||||
plugin_writable_file_ops == nullptr) {
|
|
||||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
|
||||||
"Filesystem allows creation of writable files but no "
|
|
||||||
"operations on them have been supplied.");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (plugin_filesystem_ops->new_read_only_memory_region_from_file != nullptr &&
|
|
||||||
plugin_read_only_memory_region_ops == nullptr) {
|
|
||||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
|
||||||
"Filesystem allows creation of readonly memory regions but no "
|
|
||||||
"operations on them have been supplied.");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copies a function table from plugin memory space to core memory space.
|
|
||||||
//
|
|
||||||
// This has three benefits:
|
|
||||||
// * allows having newer plugins than the current core TensorFlow: the
|
|
||||||
// additional entries in the plugin's table are just discarded;
|
|
||||||
// * allows having older plugins than the current core TensorFlow (though
|
|
||||||
// we are still warning users): the entries that core TensorFlow expects
|
|
||||||
// but plugins didn't provide will be set to `nullptr` values and core
|
|
||||||
// TensorFlow will know to not call these on behalf of users;
|
|
||||||
// * increased security as plugins will not be able to alter function table
|
|
||||||
// after loading up. Thus, malicious plugins can't alter functionality to
|
|
||||||
// probe for gadgets inside core TensorFlow. We can even protect the area
|
|
||||||
// of memory where the copies reside to not allow any more writes to it
|
|
||||||
// after all copies are created.
|
|
||||||
template <typename T>
|
|
||||||
static std::unique_ptr<const T> CopyToCore(const T* plugin_ops,
|
|
||||||
size_t plugin_size) {
|
|
||||||
if (plugin_ops == nullptr) return nullptr;
|
|
||||||
|
|
||||||
size_t copy_size = sizeof(T);
|
|
||||||
if (plugin_size < copy_size) {
|
|
||||||
copy_size = plugin_size;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto core_ops = tensorflow::MakeUnique<T>();
|
|
||||||
memcpy(const_cast<T*>(core_ops.get()), plugin_ops, copy_size);
|
|
||||||
return core_ops;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
} // namespace tensorflow
|
|
||||||
|
|
||||||
void RegisterFilesystemPlugin(
|
|
||||||
int plugin_filesystem_ops_ABI, int plugin_filesystem_ops_API,
|
|
||||||
size_t plugin_filesystem_ops_size, int plugin_random_access_file_ops_ABI,
|
|
||||||
int plugin_random_access_file_ops_API,
|
|
||||||
size_t plugin_random_access_file_ops_size, int plugin_writable_file_ops_ABI,
|
|
||||||
int plugin_writable_file_ops_API, size_t plugin_writable_file_ops_size,
|
|
||||||
int plugin_read_only_memory_region_ops_ABI,
|
|
||||||
int plugin_read_only_memory_region_ops_API,
|
|
||||||
size_t plugin_read_only_memory_region_ops_size, const char* scheme,
|
|
||||||
const TF_FilesystemOps* plugin_filesystem_ops,
|
|
||||||
const TF_RandomAccessFileOps* plugin_random_access_file_ops,
|
|
||||||
const TF_WritableFileOps* plugin_writable_file_ops,
|
|
||||||
const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
|
|
||||||
TF_Status* status) {
|
|
||||||
if (scheme == nullptr) {
|
|
||||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
|
||||||
"`scheme` argument must not be `nullptr`.");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// ABI numbers must match exactly for plugin to be loaded
|
|
||||||
if (!tensorflow::CheckABI(
|
|
||||||
plugin_filesystem_ops_ABI, plugin_random_access_file_ops,
|
|
||||||
plugin_random_access_file_ops_ABI, plugin_writable_file_ops,
|
|
||||||
plugin_writable_file_ops_ABI, plugin_read_only_memory_region_ops,
|
|
||||||
plugin_read_only_memory_region_ops_ABI, status)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// API numbers should match but mismatch doesn't block plugin load
|
|
||||||
tensorflow::CheckAPI(plugin_filesystem_ops_API, plugin_random_access_file_ops,
|
|
||||||
plugin_random_access_file_ops_API,
|
|
||||||
plugin_writable_file_ops, plugin_writable_file_ops_API,
|
|
||||||
plugin_read_only_memory_region_ops,
|
|
||||||
plugin_read_only_memory_region_ops_API);
|
|
||||||
|
|
||||||
// Plugin can only be loaded if all supplied ops are valid
|
|
||||||
if (!tensorflow::Validate(plugin_filesystem_ops,
|
|
||||||
plugin_random_access_file_ops,
|
|
||||||
plugin_writable_file_ops,
|
|
||||||
plugin_read_only_memory_region_ops, status)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copy all the function tables to core TensorFlow memory space
|
|
||||||
auto core_filesystem_ops = tensorflow::CopyToCore<TF_FilesystemOps>(
|
|
||||||
plugin_filesystem_ops, plugin_filesystem_ops_size);
|
|
||||||
auto core_random_access_file_ops =
|
|
||||||
tensorflow::CopyToCore<TF_RandomAccessFileOps>(
|
|
||||||
plugin_random_access_file_ops, plugin_random_access_file_ops_size);
|
|
||||||
auto core_writable_file_ops = tensorflow::CopyToCore<TF_WritableFileOps>(
|
|
||||||
plugin_writable_file_ops, plugin_writable_file_ops_size);
|
|
||||||
auto core_read_only_memory_region_ops =
|
|
||||||
tensorflow::CopyToCore<TF_ReadOnlyMemoryRegionOps>(
|
|
||||||
plugin_read_only_memory_region_ops,
|
|
||||||
plugin_read_only_memory_region_ops_size);
|
|
||||||
|
|
||||||
// Initialize the opaque filesystem structure
|
|
||||||
auto filesystem = tensorflow::MakeUnique<TF_Filesystem>();
|
|
||||||
core_filesystem_ops->init(filesystem.get(), status);
|
|
||||||
if (!status->status.ok()) {
|
|
||||||
core_filesystem_ops->cleanup(filesystem.get());
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register new filesystem
|
|
||||||
status->status = tensorflow::Env::Default()->RegisterFileSystem(
|
|
||||||
scheme, tensorflow::MakeUnique<tensorflow::ModularFileSystem>(
|
|
||||||
std::move(filesystem), std::move(core_filesystem_ops),
|
|
||||||
std::move(core_random_access_file_ops),
|
|
||||||
std::move(core_writable_file_ops),
|
|
||||||
std::move(core_read_only_memory_region_ops)));
|
|
||||||
}
|
|
@ -529,7 +529,7 @@ typedef struct TF_FilesystemOps {
|
|||||||
/// If `statuses` is not null, plugins must fill each element with detailed
|
/// If `statuses` is not null, plugins must fill each element with detailed
|
||||||
/// status for each file, as if calling `path_exists` on each one. Core
|
/// status for each file, as if calling `path_exists` on each one. Core
|
||||||
/// TensorFlow initializes the `statuses` array and plugins must use
|
/// TensorFlow initializes the `statuses` array and plugins must use
|
||||||
/// `TF_SetStatus` to set each element instead of dirrectly assigning.
|
/// `TF_SetStatus` to set each element instead of directly assigning.
|
||||||
///
|
///
|
||||||
/// DEFAULT IMPLEMENTATION: Checks existence of every file. Needs
|
/// DEFAULT IMPLEMENTATION: Checks existence of every file. Needs
|
||||||
/// `path_exists`.
|
/// `path_exists`.
|
||||||
@ -736,95 +736,108 @@ constexpr size_t TF_FILESYSTEM_OPS_SIZE = sizeof(TF_FilesystemOps);
|
|||||||
/// SECTION 4. Plugin registration and initialization
|
/// SECTION 4. Plugin registration and initialization
|
||||||
/// ----------------------------------------------------------------------------
|
/// ----------------------------------------------------------------------------
|
||||||
///
|
///
|
||||||
/// In this section we define two functions:
|
/// In this section we define the API used by core TensorFlow to initialize a
|
||||||
/// * `TF_InitPlugin`: must be present in the plugin shared object as it will
|
/// filesystem provided by a plugin. That is, we define the following:
|
||||||
/// be called by core TensorFlow when the filesystem plugin is loaded;
|
/// * `TF_InitPlugin` function: must be present in the plugin shared object as
|
||||||
/// * `RegisterFilesystemPlugin`: it is implemented by core TensorFlow but
|
/// it will be called by core TensorFlow when the filesystem plugin is
|
||||||
/// plugins must call it in their `TF_InitPlugin`, usually using the macro
|
/// loaded;
|
||||||
/// `TF_REGISTER_FILESYSTEM_PLUGIN`.
|
/// * `TF_FilesystemPluginInfo` struct: used to transfer information between
|
||||||
|
/// plugins and core TensorFlow about the operations provided and metadata;
|
||||||
|
/// * `TF_SetFilesystemVersionMetadata` function: must be called by plugins in
|
||||||
|
/// their `TF_InitPlugin` to record the versioning information the plugins
|
||||||
|
/// are compiled against.
|
||||||
///
|
///
|
||||||
/// The `TF_InitPlugin` function is used by plugins to set up the data
|
/// The `TF_InitPlugin` function is used by plugins to set up the data
|
||||||
/// structures that implement this interface, as presented in Section 2.
|
/// structures that implement this interface, as presented in Section 2. In
|
||||||
///
|
/// order to not have plugin shared objects call back symbols defined in core
|
||||||
/// The `RegisterFilesystemPlugin` is used by core TensorFlow to check that
|
/// TensorFlow, `TF_InitPlugin` has a `TF_FilesystemPluginInfo` argument which
|
||||||
/// plugins satisfy the requirements expected by core TensorFlow, as follows:
|
/// the plugin must fill (using the `TF_SetFilesystemVersionMetadata` for the
|
||||||
/// 1. If ABI numbers don't match we don't load the plugin, else we continue.
|
/// metadata and setting up all the supported operations and the URI schemes
|
||||||
/// 2. If the API numbers are mismatched, we warn the user and continue
|
/// that are supported).
|
||||||
/// loading the plugin.
|
|
||||||
/// 3. If any required operation is missing, we stop loading the plugin.
|
|
||||||
///
|
|
||||||
/// If all these checks succeed, we copy the plugin operations to a different
|
|
||||||
/// memory location so that core TensorFlow has the guarantee that they won't be
|
|
||||||
/// changed by plugins at a later time. Finally, we initialize the opaque
|
|
||||||
/// pointer of `TF_Filesystem` by calling the required `init` function of
|
|
||||||
/// `TF_FilesystemOps` and if that succeeds we register the filesystem.
|
|
||||||
|
|
||||||
// Initializes a TensorFlow plugin.
|
/// This structure incorporates the operations defined in Section 2 and the
|
||||||
//
|
/// metadata defined in section 3, allowing plugins to define different ops
|
||||||
// Must be implemented by the plugin DSO. It is called by TensorFlow runtime.
|
/// for different URI schemes.
|
||||||
//
|
///
|
||||||
// Filesystem plugins can be loaded on demand by users via
|
/// Every URI scheme is of the form "fs" for URIs of form "fs:///path/to/file".
|
||||||
// `Env::LoadLibrary` or during TensorFlow's startup if they are on certain
|
/// For local filesystems (i.e., when the URI is "/path/to/file"), the scheme
|
||||||
// paths (although this has a security risk if two plugins register for the
|
/// must be "". The scheme must never be `nullptr`.
|
||||||
// same filesystem and the malicious one loads before the legimitate one -
|
///
|
||||||
// but we consider this to be something that users should care about and
|
/// Every plugin fills this in `TF_InitPlugin`, using the alocator passed as
|
||||||
// manage themselves). In both of these cases, core TensorFlow looks for
|
/// argument to allocate memory. After `TF_InitPlugin` finishes, core
|
||||||
// the `TF_InitPlugin` symbol and calls that function.
|
/// TensorFlow uses the information present in this to initialize filesystems
|
||||||
//
|
/// for the URI schemes that the plugin requests.
|
||||||
// A plugin is loaded only if this `status` is `TF_OK` after the call.
|
///
|
||||||
TF_CAPI_EXPORT extern void TF_InitPlugin(TF_Status* status);
|
/// All pointers defined in this structure point to memory allocated by the DSO
|
||||||
|
/// using an allocator provided by core TensorFlow when calling `TF_InitPlugin`.
|
||||||
|
///
|
||||||
|
/// IMPORTANT: To maintain binary compatibility, the layout of this structure
|
||||||
|
/// must not change! In the unlikely case that a new type of file needs to be
|
||||||
|
/// supported, add the new ops and metadata at the end of the structure.
|
||||||
|
typedef struct TF_FilesystemPluginInfo {
|
||||||
|
char* scheme;
|
||||||
|
int filesystem_ops_abi;
|
||||||
|
int filesystem_ops_api;
|
||||||
|
size_t filesystem_ops_size;
|
||||||
|
TF_FilesystemOps* filesystem_ops;
|
||||||
|
int random_access_file_ops_abi;
|
||||||
|
int random_access_file_ops_api;
|
||||||
|
size_t random_access_file_ops_size;
|
||||||
|
TF_RandomAccessFileOps* random_access_file_ops;
|
||||||
|
int writable_file_ops_abi;
|
||||||
|
int writable_file_ops_api;
|
||||||
|
size_t writable_file_ops_size;
|
||||||
|
TF_WritableFileOps* writable_file_ops;
|
||||||
|
int read_only_memory_region_ops_abi;
|
||||||
|
int read_only_memory_region_ops_api;
|
||||||
|
size_t read_only_memory_region_ops_size;
|
||||||
|
TF_ReadOnlyMemoryRegionOps* read_only_memory_region_ops;
|
||||||
|
} TF_FilesystemPluginInfo;
|
||||||
|
|
||||||
/// Registers a filesystem plugin so that core TensorFlow can use it.
|
/// Convenience function for setting the versioning metadata.
|
||||||
///
|
///
|
||||||
/// Must be called by the plugin during `TF_InitPlugin`, usually by using the
|
/// The argument is guaranteed to not be `nullptr`.
|
||||||
/// convenience `TF_REGISTER_FILESYSTEM_PLUGIN` macro.
|
|
||||||
///
|
///
|
||||||
/// Arguments (grouped by category):
|
/// We want this to be defined in the plugin's memory space and we guarantee
|
||||||
/// * `..ABI`: ABI compatibility numbers (see Section 3.).
|
/// that core TensorFlow will never call this.
|
||||||
/// * `..API`: API compatibility numbers (see Section 3.).
|
static inline void TF_SetFilesystemVersionMetadata(
|
||||||
/// * `..Size`: Sizes of the operation tables (see Section 3.).
|
TF_FilesystemPluginInfo* info) {
|
||||||
/// * `scheme`: The URI scheme that plugin is registering filesystems for.
|
info->filesystem_ops_abi = TF_FILESYSTEM_OPS_ABI;
|
||||||
/// Must be of the form "fs" for URIs of form "fs:///path/to/file". For
|
info->filesystem_ops_api = TF_FILESYSTEM_OPS_API;
|
||||||
/// local filesystems (i.e., when the URI is "/path/to/file"), `scheme`
|
info->filesystem_ops_size = TF_FILESYSTEM_OPS_SIZE;
|
||||||
/// must be "". Must never be `nullptr`.
|
info->random_access_file_ops_abi = TF_RANDOM_ACCESS_FILE_OPS_ABI;
|
||||||
/// * `..Ops`: The function tables provided by the plugin. Owned by the
|
info->random_access_file_ops_api = TF_RANDOM_ACCESS_FILE_OPS_API;
|
||||||
/// plugin, but core TensorFlow makes a copy of these.
|
info->random_access_file_ops_size = TF_RANDOM_ACCESS_FILE_OPS_SIZE;
|
||||||
/// * `status`: The output variable for representing success/failure.
|
info->writable_file_ops_abi = TF_WRITABLE_FILE_OPS_ABI;
|
||||||
///
|
info->writable_file_ops_api = TF_WRITABLE_FILE_OPS_API;
|
||||||
/// Sets `status` to `TF_OK` if plugin was registered and filesystem operations
|
info->writable_file_ops_size = TF_WRITABLE_FILE_OPS_SIZE;
|
||||||
/// can be invoked from anywhere during TensorFlow's runtime. Any other value of
|
info->read_only_memory_region_ops_abi = TF_READ_ONLY_MEMORY_REGION_OPS_ABI;
|
||||||
/// `status` means that plugin failed to load properly and as such the
|
info->read_only_memory_region_ops_api = TF_READ_ONLY_MEMORY_REGION_OPS_API;
|
||||||
/// operations it provides cannot be used at all (i.e., core TensorFlow will
|
info->read_only_memory_region_ops_size = TF_READ_ONLY_MEMORY_REGION_OPS_SIZE;
|
||||||
/// never run them, returning early with `TF_UNIMPLEMENTED` or similar error
|
}
|
||||||
/// values).
|
|
||||||
TF_CAPI_EXPORT extern void RegisterFilesystemPlugin(
|
|
||||||
int pluginFilesystemOpsABI, int pluginFilesystemOpsAPI,
|
|
||||||
size_t pluginFilesystemOpsSize, int pluginRandomAccessFileOpsABI,
|
|
||||||
int pluginRandomAccessFileOpsAPI, size_t pluginRandomAccessFileOpsSize,
|
|
||||||
int pluginWritableFileOpsABI, int pluginWritableFileOpsAPI,
|
|
||||||
size_t pluginWritableFileOpsSize, int pluginReadOnlyMemoryRegionOpsABI,
|
|
||||||
int pluginReadOnlyMemoryRegionOpsAPI,
|
|
||||||
size_t pluginReadOnlyMemoryRegionOpsSize, const char* scheme,
|
|
||||||
const TF_FilesystemOps* pluginFilesystemOps,
|
|
||||||
const TF_RandomAccessFileOps* pluginRandomAccessFileOps,
|
|
||||||
const TF_WritableFileOps* pluginWritableFileOps,
|
|
||||||
const TF_ReadOnlyMemoryRegionOps* pluginReadOnlyMemoryRegionOps,
|
|
||||||
TF_Status* status);
|
|
||||||
|
|
||||||
/// This macro is just a convenience wrapper around `RegisterFilesystemPlugin`.
|
/// Initializes a TensorFlow plugin.
|
||||||
/// Plugins should prefer using this macro instead of a direct call.
|
///
|
||||||
#define TF_REGISTER_FILESYSTEM_PLUGIN( \
|
/// Must be implemented by the plugin DSO. It is called by TensorFlow runtime.
|
||||||
scheme, pluginFilesystemOps, pluginRandomAccessFileOps, \
|
///
|
||||||
pluginWritableFileOps, pluginReadOnlyMemoryRegionOps, status) \
|
/// Filesystem plugins can be loaded on demand by users via
|
||||||
RegisterFilesystemPlugin( \
|
/// `Env::LoadLibrary` or during TensorFlow's startup if they are on certain
|
||||||
TF_FILESYSTEM_OPS_ABI, TF_FILESYSTEM_OPS_API, TF_FILESYSTEM_OPS_SIZE, \
|
/// paths (although this has a security risk if two plugins register for the
|
||||||
TF_RANDOM_ACCESS_FILE_OPS_ABI, TF_RANDOM_ACCESS_FILE_OPS_API, \
|
/// same filesystem and the malicious one loads before the legimitate one -
|
||||||
TF_RANDOM_ACCESS_FILE_OPS_SIZE, TF_WRITABLE_FILE_OPS_ABI, \
|
/// but we consider this to be something that users should care about and
|
||||||
TF_WRITABLE_FILE_OPS_API, TF_WRITABLE_FILE_OPS_SIZE, \
|
/// manage themselves). In both of these cases, core TensorFlow looks for
|
||||||
TF_READ_ONLY_MEMORY_REGION_OPS_ABI, TF_READ_ONLY_MEMORY_REGION_OPS_API, \
|
/// the `TF_InitPlugin` symbol and calls this function.
|
||||||
TF_READ_ONLY_MEMORY_REGION_OPS_SIZE, scheme, pluginFilesystemOps, \
|
///
|
||||||
pluginRandomAccessFileOps, pluginWritableFileOps, \
|
/// All memory allocated by this function must be allocated via the `allocator`
|
||||||
pluginReadOnlyMemoryRegionOps, status)
|
/// argument.
|
||||||
|
///
|
||||||
|
/// For every filesystem URI scheme that this plugin supports, the plugin must
|
||||||
|
/// add one `TF_FilesystemPluginInfo` entry in `plugin_info`.
|
||||||
|
///
|
||||||
|
/// Returns number of entries in `plugin_info` (i.e., number of URI schemes
|
||||||
|
/// supported).
|
||||||
|
TF_CAPI_EXPORT extern int TF_InitPlugin(void* (*allocator)(size_t size),
|
||||||
|
TF_FilesystemPluginInfo** plugin_info);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
} // end extern "C"
|
} // end extern "C"
|
||||||
|
@ -18,11 +18,10 @@ limitations under the License.
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h"
|
||||||
#include "tensorflow/c/tf_status_helper.h"
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
|
||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
#include "tensorflow/core/platform/file_system_helper.h"
|
#include "tensorflow/core/platform/file_system_helper.h"
|
||||||
#include "tensorflow/core/platform/strcat.h"
|
|
||||||
#include "tensorflow/core/util/ptr_util.h"
|
#include "tensorflow/core/util/ptr_util.h"
|
||||||
|
|
||||||
// TODO(mihaimaruseac): After all filesystems are converted, all calls to
|
// TODO(mihaimaruseac): After all filesystems are converted, all calls to
|
||||||
@ -435,4 +434,8 @@ Status ModularWritableFile::Tell(int64* position) {
|
|||||||
return StatusFromTF_Status(plugin_status.get());
|
return StatusFromTF_Status(plugin_status.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status RegisterFilesystemPlugin(const std::string& dso_path) {
|
||||||
|
return filesystem_registration::RegisterFilesystemPluginImpl(dso_path);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -32,7 +32,7 @@ namespace tensorflow {
|
|||||||
// TODO(b/143949615): After all filesystems are converted, this file will be
|
// TODO(b/143949615): After all filesystems are converted, this file will be
|
||||||
// moved to core/platform, and this class can become a singleton and replace the
|
// moved to core/platform, and this class can become a singleton and replace the
|
||||||
// need for `Env::Default()`. At that time, we might decide to remove the need
|
// need for `Env::Default()`. At that time, we might decide to remove the need
|
||||||
// for `Env::Default()` altoghether, but that's a different project, not in
|
// for `Env::Default()` altogether, but that's a different project, not in
|
||||||
// scope for now. I'm just mentioning this here as that transition will mean
|
// scope for now. I'm just mentioning this here as that transition will mean
|
||||||
// removal of the registration part from `Env` and adding it here instead: we
|
// removal of the registration part from `Env` and adding it here instead: we
|
||||||
// will need tables to hold for each scheme the function tables that implement
|
// will need tables to hold for each scheme the function tables that implement
|
||||||
@ -156,6 +156,9 @@ class ModularReadOnlyMemoryRegion final : public ReadOnlyMemoryRegion {
|
|||||||
TF_DISALLOW_COPY_AND_ASSIGN(ModularReadOnlyMemoryRegion);
|
TF_DISALLOW_COPY_AND_ASSIGN(ModularReadOnlyMemoryRegion);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Registers a filesystem plugin so that core TensorFlow can use it.
|
||||||
|
Status RegisterFilesystemPlugin(const std::string& dso_path);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_H_
|
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_H_
|
||||||
|
@ -0,0 +1,325 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h"
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
||||||
|
#include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
|
||||||
|
#include "tensorflow/c/tf_status_internal.h"
|
||||||
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
#include "tensorflow/core/util/ptr_util.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Checks that all schemes provided by a plugin are valid.
|
||||||
|
// TODO(mihaimaruseac): More validation could be done here, based on supported
|
||||||
|
// charset, maximum length, etc. Punting it for later.
|
||||||
|
static Status ValidateScheme(const char* scheme) {
|
||||||
|
if (scheme == nullptr)
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Attempted to register filesystem with `nullptr` URI scheme");
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Checks if the plugin and core ABI numbers match.
|
||||||
|
//
|
||||||
|
// If the numbers don't match, plugin cannot be loaded.
|
||||||
|
static Status CheckABI(int pluginABI, int coreABI, StringPiece where) {
|
||||||
|
if (pluginABI != coreABI)
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
strings::StrCat("Plugin ABI (", pluginABI, ") for ", where,
|
||||||
|
" operations doesn't match expected core ABI (",
|
||||||
|
coreABI, "). Plugin cannot be loaded."));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Checks if the plugin and core ABI numbers match, for all operations.
|
||||||
|
//
|
||||||
|
// If the numbers don't match, plugin cannot be loaded.
|
||||||
|
//
|
||||||
|
// Uses the simpler `CheckABI(int, int, StringPiece)`.
|
||||||
|
static Status ValidateABI(const TF_FilesystemPluginInfo* info) {
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
CheckABI(info->filesystem_ops_abi, TF_FILESYSTEM_OPS_ABI, "filesystem"));
|
||||||
|
|
||||||
|
if (info->random_access_file_ops != nullptr)
|
||||||
|
TF_RETURN_IF_ERROR(CheckABI(info->random_access_file_ops_abi,
|
||||||
|
TF_RANDOM_ACCESS_FILE_OPS_ABI,
|
||||||
|
"random access file"));
|
||||||
|
|
||||||
|
if (info->writable_file_ops != nullptr)
|
||||||
|
TF_RETURN_IF_ERROR(CheckABI(info->writable_file_ops_abi,
|
||||||
|
TF_WRITABLE_FILE_OPS_ABI, "writable file"));
|
||||||
|
|
||||||
|
if (info->read_only_memory_region_ops != nullptr)
|
||||||
|
TF_RETURN_IF_ERROR(CheckABI(info->read_only_memory_region_ops_abi,
|
||||||
|
TF_READ_ONLY_MEMORY_REGION_OPS_ABI,
|
||||||
|
"read only memory region"));
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Checks if the plugin and core API numbers match, logging mismatches.
|
||||||
|
static void CheckAPI(int plugin_API, int core_API, StringPiece where) {
|
||||||
|
if (plugin_API != core_API) {
|
||||||
|
VLOG(0) << "Plugin API (" << plugin_API << ") for " << where
|
||||||
|
<< " operations doesn't match expected core API (" << core_API
|
||||||
|
<< "). Plugin will be loaded but functionality might be missing.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Checks if the plugin and core API numbers match, for all operations.
|
||||||
|
//
|
||||||
|
// Uses the simpler `CheckAPIHelper(int, int, StringPiece)`.
|
||||||
|
static void ValidateAPI(const TF_FilesystemPluginInfo* info) {
|
||||||
|
CheckAPI(info->filesystem_ops_api, TF_FILESYSTEM_OPS_API, "filesystem");
|
||||||
|
|
||||||
|
if (info->random_access_file_ops != nullptr)
|
||||||
|
CheckAPI(info->random_access_file_ops_api, TF_RANDOM_ACCESS_FILE_OPS_API,
|
||||||
|
"random access file");
|
||||||
|
|
||||||
|
if (info->writable_file_ops != nullptr)
|
||||||
|
CheckAPI(info->writable_file_ops_api, TF_WRITABLE_FILE_OPS_API,
|
||||||
|
"writable file");
|
||||||
|
|
||||||
|
if (info->read_only_memory_region_ops != nullptr)
|
||||||
|
CheckAPI(info->read_only_memory_region_ops_api,
|
||||||
|
TF_READ_ONLY_MEMORY_REGION_OPS_API, "read only memory region");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validates the filesystem operations supplied by the plugin.
|
||||||
|
static Status ValidateHelper(const TF_FilesystemOps* ops) {
|
||||||
|
if (ops == nullptr)
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"Trying to register filesystem without operations");
|
||||||
|
|
||||||
|
if (ops->init == nullptr)
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"Trying to register filesystem without `init` operation");
|
||||||
|
|
||||||
|
if (ops->cleanup == nullptr)
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"Trying to register filesystem without `cleanup` operation");
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validates the random access file operations supplied by the plugin.
|
||||||
|
static Status ValidateHelper(const TF_RandomAccessFileOps* ops) {
|
||||||
|
if (ops == nullptr) {
|
||||||
|
// We allow filesystems where files can only be written to (from TF code)
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ops->cleanup == nullptr)
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"Trying to register filesystem without `cleanup` operation on random "
|
||||||
|
"access files");
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validates the writable file operations supplied by the plugin.
|
||||||
|
static Status ValidateHelper(const TF_WritableFileOps* ops) {
|
||||||
|
if (ops == nullptr) {
|
||||||
|
// We allow read-only filesystems
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ops->cleanup == nullptr)
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"Trying to register filesystem without `cleanup` operation on writable "
|
||||||
|
"files");
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validates the read only memory region operations given by the plugin.
|
||||||
|
static Status ValidateHelper(const TF_ReadOnlyMemoryRegionOps* ops) {
|
||||||
|
if (ops == nullptr) {
|
||||||
|
// read only memory region support is always optional
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ops->cleanup == nullptr)
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"Trying to register filesystem without `cleanup` operation on read "
|
||||||
|
"only memory regions");
|
||||||
|
|
||||||
|
if (ops->data == nullptr)
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"Trying to register filesystem without `data` operation on read only "
|
||||||
|
"memory regions");
|
||||||
|
|
||||||
|
if (ops->length == nullptr)
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"Trying to register filesystem without `length` operation on read only "
|
||||||
|
"memory regions");
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validates the operations supplied by the plugin.
|
||||||
|
//
|
||||||
|
// Uses the 4 simpler `ValidateHelper(const TF_...*)` to validate each
|
||||||
|
// individual function table and then checks that the function table for a
|
||||||
|
// specific file type exists if the plugin offers support for creating that
|
||||||
|
// type of files.
|
||||||
|
static Status ValidateOperations(const TF_FilesystemPluginInfo* info) {
|
||||||
|
TF_RETURN_IF_ERROR(ValidateHelper(info->filesystem_ops));
|
||||||
|
TF_RETURN_IF_ERROR(ValidateHelper(info->random_access_file_ops));
|
||||||
|
TF_RETURN_IF_ERROR(ValidateHelper(info->writable_file_ops));
|
||||||
|
TF_RETURN_IF_ERROR(ValidateHelper(info->read_only_memory_region_ops));
|
||||||
|
|
||||||
|
if (info->filesystem_ops->new_random_access_file != nullptr &&
|
||||||
|
info->random_access_file_ops == nullptr)
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"Filesystem allows creation of random access files but no "
|
||||||
|
"operations on them have been supplied.");
|
||||||
|
|
||||||
|
if ((info->filesystem_ops->new_writable_file != nullptr ||
|
||||||
|
info->filesystem_ops->new_appendable_file != nullptr) &&
|
||||||
|
info->writable_file_ops == nullptr)
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"Filesystem allows creation of writable files but no "
|
||||||
|
"operations on them have been supplied.");
|
||||||
|
|
||||||
|
if (info->filesystem_ops->new_read_only_memory_region_from_file != nullptr &&
|
||||||
|
info->read_only_memory_region_ops == nullptr)
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"Filesystem allows creation of readonly memory regions but no "
|
||||||
|
"operations on them have been supplied.");
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copies a function table from plugin memory space to core memory space.
|
||||||
|
//
|
||||||
|
// This has three benefits:
|
||||||
|
// * allows having newer plugins than the current core TensorFlow: the
|
||||||
|
// additional entries in the plugin's table are just discarded;
|
||||||
|
// * allows having older plugins than the current core TensorFlow (though
|
||||||
|
// we are still warning users): the entries that core TensorFlow expects
|
||||||
|
// but plugins didn't provide will be set to `nullptr` values and core
|
||||||
|
// TensorFlow will know to not call these on behalf of users;
|
||||||
|
// * increased security as plugins will not be able to alter function table
|
||||||
|
// after loading up. Thus, malicious plugins can't alter functionality to
|
||||||
|
// probe for gadgets inside core TensorFlow. We can even protect the area
|
||||||
|
// of memory where the copies reside to not allow any more writes to it
|
||||||
|
// after all copies are created.
|
||||||
|
template <typename T>
|
||||||
|
static std::unique_ptr<const T> CopyToCore(const T* plugin_ops,
|
||||||
|
size_t plugin_size) {
|
||||||
|
if (plugin_ops == nullptr) return nullptr;
|
||||||
|
|
||||||
|
size_t copy_size = std::min(plugin_size, sizeof(T));
|
||||||
|
auto core_ops = tensorflow::MakeUnique<T>();
|
||||||
|
memset(core_ops.get(), 0, sizeof(T));
|
||||||
|
memcpy(core_ops.get(), plugin_ops, copy_size);
|
||||||
|
return core_ops;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Registers one filesystem from the plugin.
|
||||||
|
static Status RegisterFileSystem(const TF_FilesystemPluginInfo* info) {
|
||||||
|
// Step 1: Copy all the function tables to core TensorFlow memory space
|
||||||
|
auto core_filesystem_ops = CopyToCore<TF_FilesystemOps>(
|
||||||
|
info->filesystem_ops, info->filesystem_ops_size);
|
||||||
|
auto core_random_access_file_ops = CopyToCore<TF_RandomAccessFileOps>(
|
||||||
|
info->random_access_file_ops, info->random_access_file_ops_size);
|
||||||
|
auto core_writable_file_ops = CopyToCore<TF_WritableFileOps>(
|
||||||
|
info->writable_file_ops, info->writable_file_ops_size);
|
||||||
|
auto core_read_only_memory_region_ops =
|
||||||
|
CopyToCore<TF_ReadOnlyMemoryRegionOps>(
|
||||||
|
info->read_only_memory_region_ops,
|
||||||
|
info->read_only_memory_region_ops_size);
|
||||||
|
|
||||||
|
// Step 2: Initialize the opaque filesystem structure
|
||||||
|
auto filesystem = tensorflow::MakeUnique<TF_Filesystem>();
|
||||||
|
TF_Status* c_status = TF_NewStatus();
|
||||||
|
Status status = Status::OK();
|
||||||
|
core_filesystem_ops->init(filesystem.get(), c_status);
|
||||||
|
status = Status(c_status->status);
|
||||||
|
TF_DeleteStatus(c_status);
|
||||||
|
if (!status.ok()) return status;
|
||||||
|
|
||||||
|
// Step 3: Actual registration
|
||||||
|
return Env::Default()->RegisterFileSystem(
|
||||||
|
info->scheme, tensorflow::MakeUnique<tensorflow::ModularFileSystem>(
|
||||||
|
std::move(filesystem), std::move(core_filesystem_ops),
|
||||||
|
std::move(core_random_access_file_ops),
|
||||||
|
std::move(core_writable_file_ops),
|
||||||
|
std::move(core_read_only_memory_region_ops)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Registers all filesystems, if plugin is providing valid information.
|
||||||
|
//
|
||||||
|
// Extracted to a separate function so that pointers inside `info` are freed
|
||||||
|
// by the caller regardless of whether validation/registration failed or not.
|
||||||
|
static Status ValidateAndRegisterFilesystems(
|
||||||
|
const TF_FilesystemPluginInfo* info) {
|
||||||
|
TF_RETURN_IF_ERROR(ValidateScheme(info->scheme));
|
||||||
|
TF_RETURN_IF_ERROR(ValidateABI(info));
|
||||||
|
ValidateAPI(info); // we just warn on API number mismatch
|
||||||
|
TF_RETURN_IF_ERROR(ValidateOperations(info));
|
||||||
|
TF_RETURN_IF_ERROR(RegisterFileSystem(info));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Alocates memory in plugin DSO.
|
||||||
|
//
|
||||||
|
// Provided by core TensorFlow so that it can free this memory after DSO is
|
||||||
|
// loaded and filesystem information has been used to register the filesystem.
|
||||||
|
static void* basic_allocator(size_t size) { return calloc(1, size); }
|
||||||
|
|
||||||
|
namespace filesystem_registration {
|
||||||
|
|
||||||
|
Status RegisterFilesystemPluginImpl(const std::string& dso_path) {
|
||||||
|
// Step 1: Load plugin
|
||||||
|
Env* env = Env::Default();
|
||||||
|
void* dso_handle;
|
||||||
|
TF_RETURN_IF_ERROR(env->LoadLibrary(dso_path.c_str(), &dso_handle));
|
||||||
|
|
||||||
|
// Step 2: Load symbol for `TF_InitPlugin`
|
||||||
|
void* dso_symbol;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
env->GetSymbolFromLibrary(dso_handle, "TF_InitPlugin", &dso_symbol));
|
||||||
|
|
||||||
|
// Step 3: Call `TF_InitPlugin`
|
||||||
|
TF_FilesystemPluginInfo* info = nullptr;
|
||||||
|
auto TF_InitPlugin = reinterpret_cast<int (*)(
|
||||||
|
decltype(&basic_allocator), TF_FilesystemPluginInfo**)>(dso_symbol);
|
||||||
|
int num_schemes = TF_InitPlugin(&basic_allocator, &info);
|
||||||
|
if (num_schemes < 0 || info == nullptr)
|
||||||
|
return errors::InvalidArgument("DSO returned invalid filesystem data");
|
||||||
|
|
||||||
|
// Step 4: Validate and register all filesystems
|
||||||
|
// Try to register as many filesystems as possible.
|
||||||
|
// Free memory once we no longer need it
|
||||||
|
Status status;
|
||||||
|
for (int i = 0; i < num_schemes; i++) {
|
||||||
|
status.Update(ValidateAndRegisterFilesystems(&info[i]));
|
||||||
|
free(info[i].scheme);
|
||||||
|
free(info[i].filesystem_ops);
|
||||||
|
free(info[i].random_access_file_ops);
|
||||||
|
free(info[i].writable_file_ops);
|
||||||
|
free(info[i].read_only_memory_region_ops);
|
||||||
|
}
|
||||||
|
free(info);
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace filesystem_registration
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
@ -0,0 +1,28 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_
|
||||||
|
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/platform/status.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace filesystem_registration {
|
||||||
|
|
||||||
|
Status RegisterFilesystemPluginImpl(const std::string& dso_path);
|
||||||
|
|
||||||
|
} // namespace filesystem_registration
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_
|
File diff suppressed because it is too large
Load Diff
@ -1,35 +1,47 @@
|
|||||||
# Experimental posix filesystem plugin.
|
# Experimental posix filesystem plugin.
|
||||||
|
load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
|
||||||
|
|
||||||
package(
|
package(
|
||||||
|
default_visibility = ["//visibility:private"],
|
||||||
licenses = ["notice"], # Apache 2.0
|
licenses = ["notice"], # Apache 2.0
|
||||||
)
|
)
|
||||||
|
|
||||||
# Although this target results in a shared object that will be loaded at
|
# Filesystem implementation for POSIX environments: Linux, MacOS, Android, etc.
|
||||||
# runtime, this target must be a `cc_library` instead of a `cc_binary`. Making
|
tf_cc_shared_object(
|
||||||
# it a `cc_binary` requires `linkshared = True`. In turn, this brings in several
|
name = "libposix_filesystem.so",
|
||||||
# TensorFlow symbols under `tensorflow::` namespace, for which we have no ABI
|
framework_so = [],
|
||||||
# guarantees. Hence, in order to maintain ABI compatibility, this is marked as a
|
linkstatic = False,
|
||||||
# `cc_library` for now and we will revisit in the future.
|
visibility = ["//visibility:public"],
|
||||||
# TODO(mihaimaruseac): Determine if `cc_binary` makes more sense (when all
|
deps = [":posix_filesystem_impl"],
|
||||||
# filesystems are converted and BUILD files are refactored to be modular).
|
)
|
||||||
# TODO(b/144585140): The helpers should be separated into a different BUILD target
|
|
||||||
# but doing that would result in symbols not being visible when loading plugin.
|
# The real implementation of the filesystem.
|
||||||
# Revisit this once POSIX filesystem completely lands. See also the other TODO.
|
|
||||||
# This also has the unfortunate effect that both versions of copy_file get
|
|
||||||
# compiled, regardless of which one actually gets used!
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "posix_filesystem",
|
name = "posix_filesystem_impl",
|
||||||
srcs = [
|
srcs = ["posix_filesystem.cc"],
|
||||||
"posix_filesystem.cc",
|
|
||||||
"posix_filesystem_helper.cc",
|
|
||||||
"posix_filesystem_helper.h",
|
|
||||||
"copy_file.h",
|
|
||||||
] + select({
|
|
||||||
"//tensorflow:linux_x86_64": ["copy_file_linux.cc"],
|
|
||||||
"//conditions:default": ["copy_file_portable.cc"],
|
|
||||||
}),
|
|
||||||
deps = [
|
deps = [
|
||||||
|
":posix_filesystem_helper",
|
||||||
"//tensorflow/c:tf_status",
|
"//tensorflow/c:tf_status",
|
||||||
"//tensorflow/c/experimental/filesystem:filesystem_interface",
|
"//tensorflow/c/experimental/filesystem:filesystem_interface",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Library implementing helper functionality, so that the above only contains
|
||||||
|
# the API implementation for modular filesystems.
|
||||||
|
cc_library(
|
||||||
|
name = "posix_filesystem_helper",
|
||||||
|
srcs = ["posix_filesystem_helper.cc"],
|
||||||
|
hdrs = ["posix_filesystem_helper.h"],
|
||||||
|
deps = [":copy_file"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# On Linux, we can copy files faster using `sendfile`. But not elsewhere.
|
||||||
|
# Hence, this private library to select which implementation to use.
|
||||||
|
cc_library(
|
||||||
|
name = "copy_file",
|
||||||
|
srcs = select({
|
||||||
|
"//tensorflow:linux_x86_64": ["copy_file_linux.cc"],
|
||||||
|
"//conditions:default": ["copy_file_portable.cc"],
|
||||||
|
}),
|
||||||
|
hdrs = ["copy_file.h"],
|
||||||
|
)
|
||||||
|
@ -24,8 +24,6 @@ limitations under the License.
|
|||||||
#include <sys/stat.h>
|
#include <sys/stat.h>
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
||||||
#include "tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem_helper.h"
|
#include "tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem_helper.h"
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
@ -396,48 +394,65 @@ static int GetChildren(const TF_Filesystem* filesystem, const char* path,
|
|||||||
|
|
||||||
} // namespace tf_posix_filesystem
|
} // namespace tf_posix_filesystem
|
||||||
|
|
||||||
void TF_InitPlugin(TF_Status* status) {
|
int TF_InitPlugin(void* (*allocator)(size_t), TF_FilesystemPluginInfo** info) {
|
||||||
TF_RandomAccessFileOps random_access_file_ops = {
|
const int num_schemes = 2;
|
||||||
tf_random_access_file::Cleanup,
|
*info = static_cast<TF_FilesystemPluginInfo*>(
|
||||||
tf_random_access_file::Read,
|
allocator(num_schemes * sizeof((*info)[0])));
|
||||||
};
|
|
||||||
TF_WritableFileOps writable_file_ops = {
|
|
||||||
tf_writable_file::Cleanup, tf_writable_file::Append,
|
|
||||||
tf_writable_file::Tell, tf_writable_file::Flush,
|
|
||||||
tf_writable_file::Sync, tf_writable_file::Close,
|
|
||||||
};
|
|
||||||
TF_ReadOnlyMemoryRegionOps read_only_memory_region_ops = {
|
|
||||||
tf_read_only_memory_region::Cleanup,
|
|
||||||
tf_read_only_memory_region::Data,
|
|
||||||
tf_read_only_memory_region::Length,
|
|
||||||
};
|
|
||||||
TF_FilesystemOps filesystem_ops = {
|
|
||||||
tf_posix_filesystem::Init,
|
|
||||||
tf_posix_filesystem::Cleanup,
|
|
||||||
tf_posix_filesystem::NewRandomAccessFile,
|
|
||||||
tf_posix_filesystem::NewWritableFile,
|
|
||||||
tf_posix_filesystem::NewAppendableFile,
|
|
||||||
tf_posix_filesystem::NewReadOnlyMemoryRegionFromFile,
|
|
||||||
tf_posix_filesystem::CreateDir,
|
|
||||||
/*recursively_create_dir=*/nullptr,
|
|
||||||
tf_posix_filesystem::DeleteFile,
|
|
||||||
tf_posix_filesystem::DeleteDir,
|
|
||||||
/*delete_recursively=*/nullptr,
|
|
||||||
tf_posix_filesystem::RenameFile,
|
|
||||||
tf_posix_filesystem::CopyFile,
|
|
||||||
tf_posix_filesystem::PathExists,
|
|
||||||
/*paths_exist=*/nullptr,
|
|
||||||
tf_posix_filesystem::Stat,
|
|
||||||
/*is_directory=*/nullptr,
|
|
||||||
/*get_file_size=*/nullptr,
|
|
||||||
/*translate_name=*/nullptr,
|
|
||||||
tf_posix_filesystem::GetChildren,
|
|
||||||
/*get_matching_paths=*/nullptr,
|
|
||||||
/*flush_caches=*/nullptr,
|
|
||||||
};
|
|
||||||
|
|
||||||
for (const char* scheme : {"", "file"})
|
for (int i = 0; i < num_schemes; i++) {
|
||||||
TF_REGISTER_FILESYSTEM_PLUGIN(scheme, &filesystem_ops,
|
TF_FilesystemPluginInfo* current_info = &((*info)[i]);
|
||||||
&random_access_file_ops, &writable_file_ops,
|
TF_SetFilesystemVersionMetadata(current_info);
|
||||||
&read_only_memory_region_ops, status);
|
|
||||||
|
current_info->random_access_file_ops = static_cast<TF_RandomAccessFileOps*>(
|
||||||
|
allocator(TF_RANDOM_ACCESS_FILE_OPS_SIZE));
|
||||||
|
current_info->random_access_file_ops->cleanup =
|
||||||
|
tf_random_access_file::Cleanup;
|
||||||
|
current_info->random_access_file_ops->read = tf_random_access_file::Read;
|
||||||
|
|
||||||
|
current_info->writable_file_ops =
|
||||||
|
static_cast<TF_WritableFileOps*>(allocator(TF_WRITABLE_FILE_OPS_SIZE));
|
||||||
|
current_info->writable_file_ops->cleanup = tf_writable_file::Cleanup;
|
||||||
|
current_info->writable_file_ops->append = tf_writable_file::Append;
|
||||||
|
current_info->writable_file_ops->tell = tf_writable_file::Tell;
|
||||||
|
current_info->writable_file_ops->flush = tf_writable_file::Flush;
|
||||||
|
current_info->writable_file_ops->sync = tf_writable_file::Sync;
|
||||||
|
current_info->writable_file_ops->close = tf_writable_file::Close;
|
||||||
|
|
||||||
|
current_info->read_only_memory_region_ops =
|
||||||
|
static_cast<TF_ReadOnlyMemoryRegionOps*>(
|
||||||
|
allocator(TF_READ_ONLY_MEMORY_REGION_OPS_SIZE));
|
||||||
|
current_info->read_only_memory_region_ops->cleanup =
|
||||||
|
tf_read_only_memory_region::Cleanup;
|
||||||
|
current_info->read_only_memory_region_ops->data =
|
||||||
|
tf_read_only_memory_region::Data;
|
||||||
|
current_info->read_only_memory_region_ops->length =
|
||||||
|
tf_read_only_memory_region::Length;
|
||||||
|
|
||||||
|
current_info->filesystem_ops =
|
||||||
|
static_cast<TF_FilesystemOps*>(allocator(TF_FILESYSTEM_OPS_SIZE));
|
||||||
|
current_info->filesystem_ops->init = tf_posix_filesystem::Init;
|
||||||
|
current_info->filesystem_ops->cleanup = tf_posix_filesystem::Cleanup;
|
||||||
|
current_info->filesystem_ops->new_random_access_file =
|
||||||
|
tf_posix_filesystem::NewRandomAccessFile;
|
||||||
|
current_info->filesystem_ops->new_writable_file =
|
||||||
|
tf_posix_filesystem::NewWritableFile;
|
||||||
|
current_info->filesystem_ops->new_appendable_file =
|
||||||
|
tf_posix_filesystem::NewAppendableFile;
|
||||||
|
current_info->filesystem_ops->new_read_only_memory_region_from_file =
|
||||||
|
tf_posix_filesystem::NewReadOnlyMemoryRegionFromFile;
|
||||||
|
current_info->filesystem_ops->create_dir = tf_posix_filesystem::CreateDir;
|
||||||
|
current_info->filesystem_ops->delete_file = tf_posix_filesystem::DeleteFile;
|
||||||
|
current_info->filesystem_ops->delete_dir = tf_posix_filesystem::DeleteDir;
|
||||||
|
current_info->filesystem_ops->rename_file = tf_posix_filesystem::RenameFile;
|
||||||
|
current_info->filesystem_ops->copy_file = tf_posix_filesystem::CopyFile;
|
||||||
|
current_info->filesystem_ops->path_exists = tf_posix_filesystem::PathExists;
|
||||||
|
current_info->filesystem_ops->stat = tf_posix_filesystem::Stat;
|
||||||
|
current_info->filesystem_ops->get_children =
|
||||||
|
tf_posix_filesystem::GetChildren;
|
||||||
|
}
|
||||||
|
|
||||||
|
(*info)[0].scheme = strdup("");
|
||||||
|
(*info)[1].scheme = strdup("file");
|
||||||
|
|
||||||
|
return num_schemes;
|
||||||
}
|
}
|
||||||
|
@ -44,7 +44,7 @@ int TransferFileContents(const char* src, const char* dst, mode_t mode,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Both files have been opened, do the transfer.
|
// Both files have been opened, do the transfer.
|
||||||
// Since errno would be overriden by `close` below, save it here.
|
// Since errno would be overridden by `close` below, save it here.
|
||||||
int error_code = 0;
|
int error_code = 0;
|
||||||
if (CopyFileContents(dst_fd, src_fd, size) < 0) error_code = errno;
|
if (CopyFileContents(dst_fd, src_fd, size) < 0) error_code = errno;
|
||||||
|
|
||||||
|
36
tensorflow/c/experimental/filesystem/plugins/windows/BUILD
Normal file
36
tensorflow/c/experimental/filesystem/plugins/windows/BUILD
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
# Experimental windows filesystem plugin.
|
||||||
|
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object")
|
||||||
|
|
||||||
|
package(
|
||||||
|
licenses = ["notice"], # Apache 2.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filesystem implementation for Windows environment
|
||||||
|
tf_cc_shared_object(
|
||||||
|
name = "windows_filesystem.dll",
|
||||||
|
framework_so = [],
|
||||||
|
linkstatic = False,
|
||||||
|
tags = [
|
||||||
|
"manual",
|
||||||
|
"nobuilder",
|
||||||
|
"notap",
|
||||||
|
],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [":windows_filesystem_impl"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# The real implementation of the filesystem.
|
||||||
|
cc_library(
|
||||||
|
name = "windows_filesystem_impl",
|
||||||
|
srcs = ["windows_filesystem.cc"],
|
||||||
|
copts = get_win_copts(),
|
||||||
|
tags = [
|
||||||
|
"manual",
|
||||||
|
"nobuilder",
|
||||||
|
"notap",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/c:tf_status",
|
||||||
|
"//tensorflow/c/experimental/filesystem:filesystem_interface",
|
||||||
|
],
|
||||||
|
)
|
@ -0,0 +1,70 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
||||||
|
#include "tensorflow/c/tf_status.h"
|
||||||
|
|
||||||
|
// Implementation of a filesystem for POSIX environments.
|
||||||
|
// This filesystem will support `file://` and empty (local) URI schemes.
|
||||||
|
|
||||||
|
// SECTION 1. Implementation for `TF_RandomAccessFile`
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
namespace tf_random_access_file {
|
||||||
|
|
||||||
|
// TODO(mihaimaruseac): Implement later
|
||||||
|
|
||||||
|
} // namespace tf_random_access_file
|
||||||
|
|
||||||
|
// SECTION 2. Implementation for `TF_WritableFile`
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
namespace tf_writable_file {
|
||||||
|
|
||||||
|
// TODO(mihaimaruseac): Implement later
|
||||||
|
|
||||||
|
} // namespace tf_writable_file
|
||||||
|
|
||||||
|
// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion`
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
namespace tf_read_only_memory_region {
|
||||||
|
|
||||||
|
// TODO(mihaimaruseac): Implement later
|
||||||
|
|
||||||
|
} // namespace tf_read_only_memory_region
|
||||||
|
|
||||||
|
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
namespace tf_windows_filesystem {
|
||||||
|
|
||||||
|
// TODO(mihaimaruseac): Implement later
|
||||||
|
|
||||||
|
} // namespace tf_windows_filesystem
|
||||||
|
|
||||||
|
int TF_InitPlugin(void* (*allocator)(size_t), TF_FilesystemPluginInfo** info) {
|
||||||
|
const int num_schemes = 2;
|
||||||
|
*info = static_cast<TF_FilesystemPluginInfo*>(
|
||||||
|
allocator(num_schemes * sizeof((*info)[0])));
|
||||||
|
|
||||||
|
for (int i = 0; i < num_schemes; i++) {
|
||||||
|
TF_FilesystemPluginInfo* current_info = &((*info)[i]);
|
||||||
|
TF_SetFilesystemVersionMetadata(current_info);
|
||||||
|
}
|
||||||
|
|
||||||
|
(*info)[0].scheme = strdup("");
|
||||||
|
(*info)[1].scheme = strdup("file");
|
||||||
|
|
||||||
|
return num_schemes;
|
||||||
|
}
|
@ -181,7 +181,8 @@ void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor,
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i));
|
const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i));
|
||||||
TF_Tensor* result = ::tensorflow::TF_TensorFromTensor(cc_tensor, status);
|
TF_Tensor* result =
|
||||||
|
::tensorflow::TF_TensorFromTensor(cc_tensor, &status->status);
|
||||||
if (TF_GetCode(status) == TF_OK) {
|
if (TF_GetCode(status) == TF_OK) {
|
||||||
*tensor = result;
|
*tensor = result;
|
||||||
}
|
}
|
||||||
|
@ -133,7 +133,7 @@ TEST(OpsTest, TestShapeInference_VectorizeFunction) {
|
|||||||
|
|
||||||
TEST(OpsTest, AttributeAccessors) {
|
TEST(OpsTest, AttributeAccessors) {
|
||||||
TF_OpDefinitionBuilder* builder =
|
TF_OpDefinitionBuilder* builder =
|
||||||
TF_NewOpDefinitionBuilder("AttributeAccesorsOp");
|
TF_NewOpDefinitionBuilder("AttributeAccessorsOp");
|
||||||
TF_OpDefinitionBuilderAddAttr(builder, "foo1: int >= 2");
|
TF_OpDefinitionBuilderAddAttr(builder, "foo1: int >= 2");
|
||||||
TF_OpDefinitionBuilderAddAttr(builder, "foo2: string=\"my string\"");
|
TF_OpDefinitionBuilderAddAttr(builder, "foo2: string=\"my string\"");
|
||||||
TF_OpDefinitionBuilderSetIsCommutative(builder, true);
|
TF_OpDefinitionBuilderSetIsCommutative(builder, true);
|
||||||
@ -151,7 +151,7 @@ TEST(OpsTest, AttributeAccessors) {
|
|||||||
op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length);
|
op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length);
|
||||||
bool found = false;
|
bool found = false;
|
||||||
for (const auto& op : op_list.op()) {
|
for (const auto& op : op_list.op()) {
|
||||||
if (op.name() == "AttributeAccesorsOp") {
|
if (op.name() == "AttributeAccessorsOp") {
|
||||||
ASSERT_TRUE(op.is_commutative());
|
ASSERT_TRUE(op.is_commutative());
|
||||||
ASSERT_TRUE(op.is_aggregate());
|
ASSERT_TRUE(op.is_aggregate());
|
||||||
ASSERT_TRUE(op.allows_uninitialized_input());
|
ASSERT_TRUE(op.allows_uninitialized_input());
|
||||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/c/tf_tensor.h"
|
#include "tensorflow/c/tf_tensor.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
#include "tensorflow/c/tf_status_helper.h"
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
#include "tensorflow/c/tf_tensor_internal.h"
|
#include "tensorflow/c/tf_tensor_internal.h"
|
||||||
@ -103,49 +105,35 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
|
|||||||
buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
|
buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_Tensor* ret =
|
// TODO(gjn): Make the choice of interface a compile-time configuration.
|
||||||
new TF_Tensor{Tensor(static_cast<tensorflow::DataType>(dtype),
|
tensorflow::TensorInterface ret(
|
||||||
tensorflow::TensorShape(dimvec), buf)};
|
Tensor(static_cast<tensorflow::DataType>(dtype),
|
||||||
|
tensorflow::TensorShape(dimvec), buf));
|
||||||
buf->Unref();
|
buf->Unref();
|
||||||
size_t elem_size = TF_DataTypeSize(dtype);
|
size_t elem_size = TF_DataTypeSize(dtype);
|
||||||
if (elem_size > 0 && len < (elem_size * ret->tensor.NumElements())) {
|
if (elem_size > 0 && len < (elem_size * ret.NumElements())) {
|
||||||
delete ret;
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
return ret;
|
return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(ret)};
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor) {
|
TF_Tensor* TF_TensorMaybeMove(TF_Tensor* t) {
|
||||||
// It is safe to move the Tensor if and only if we own the unique reference to
|
return t->tensor->CanMove() ? t : nullptr;
|
||||||
// it. In that case, we might as well not delete and reallocate, but a future
|
|
||||||
// implementation might need to do so.
|
|
||||||
TensorBuffer* buf = tensorflow::TensorCApi::Buffer(tensor->tensor);
|
|
||||||
if (buf->RefCountIsOne() && buf->root_buffer()->RefCountIsOne() &&
|
|
||||||
buf->OwnsMemory()) {
|
|
||||||
return tensor;
|
|
||||||
}
|
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void TF_DeleteTensor(TF_Tensor* t) { delete t; }
|
void TF_DeleteTensor(TF_Tensor* t) { delete t; }
|
||||||
|
|
||||||
TF_DataType TF_TensorType(const TF_Tensor* t) {
|
TF_DataType TF_TensorType(const TF_Tensor* t) { return t->tensor->Type(); }
|
||||||
return static_cast<TF_DataType>(t->tensor.dtype());
|
|
||||||
}
|
|
||||||
|
|
||||||
int TF_NumDims(const TF_Tensor* t) { return t->tensor.dims(); }
|
int TF_NumDims(const TF_Tensor* t) { return t->tensor->NumDims(); }
|
||||||
|
|
||||||
int64_t TF_Dim(const TF_Tensor* t, int dim_index) {
|
int64_t TF_Dim(const TF_Tensor* t, int dim_index) {
|
||||||
return static_cast<int64_t>(t->tensor.dim_size(dim_index));
|
return t->tensor->Dim(dim_index);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t TF_TensorByteSize(const TF_Tensor* t) {
|
size_t TF_TensorByteSize(const TF_Tensor* t) { return t->tensor->ByteSize(); }
|
||||||
return tensorflow::TensorCApi::Buffer(t->tensor)->size();
|
|
||||||
}
|
|
||||||
|
|
||||||
void* TF_TensorData(const TF_Tensor* t) {
|
void* TF_TensorData(const TF_Tensor* t) { return t->tensor->Data(); }
|
||||||
return tensorflow::TensorCApi::Buffer(t->tensor)->data();
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t TF_TensorElementCount(const TF_Tensor* t) {
|
int64_t TF_TensorElementCount(const TF_Tensor* t) {
|
||||||
int64_t result = 1;
|
int64_t result = 1;
|
||||||
@ -160,16 +148,69 @@ void TF_TensorBitcastFrom(const TF_Tensor* from, TF_DataType type,
|
|||||||
TF_Tensor* to, const int64_t* new_dims,
|
TF_Tensor* to, const int64_t* new_dims,
|
||||||
int num_new_dims, TF_Status* status) {
|
int num_new_dims, TF_Status* status) {
|
||||||
TF_SetStatus(status, TF_OK, "");
|
TF_SetStatus(status, TF_OK, "");
|
||||||
|
Status cc_status(
|
||||||
|
static_cast<tensorflow::TensorInterface*>(to->tensor.get())
|
||||||
|
->BitcastFrom(*static_cast<const tensorflow::TensorInterface*>(
|
||||||
|
from->tensor.get()),
|
||||||
|
type, new_dims, num_new_dims));
|
||||||
|
Set_TF_Status_from_Status(status, cc_status);
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
bool TensorInterface::CanMove() const {
|
||||||
|
// It is safe to move the Tensor if and only if we own the unique reference to
|
||||||
|
// it. In that case, we might as well not delete and reallocate, but a future
|
||||||
|
// implementation might need to do so.
|
||||||
|
TensorBuffer* buf = tensorflow::TensorCApi::Buffer(tensor_);
|
||||||
|
if (buf->RefCountIsOne() && buf->root_buffer()->RefCountIsOne() &&
|
||||||
|
buf->OwnsMemory()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_DataType TensorInterface::Type() const {
|
||||||
|
return static_cast<TF_DataType>(tensor_.dtype());
|
||||||
|
}
|
||||||
|
|
||||||
|
int TensorInterface::NumDims() const { return tensor_.dims(); }
|
||||||
|
|
||||||
|
int64_t TensorInterface::Dim(int dim_index) const {
|
||||||
|
return static_cast<int64_t>(tensor_.dim_size(dim_index));
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t TensorInterface::NumElements() const {
|
||||||
|
return static_cast<int64_t>(tensor_.NumElements());
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t TensorInterface::ByteSize() const {
|
||||||
|
return tensorflow::TensorCApi::Buffer(tensor_)->size();
|
||||||
|
}
|
||||||
|
|
||||||
|
void* TensorInterface::Data() const {
|
||||||
|
return tensorflow::TensorCApi::Buffer(tensor_)->data();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status TensorInterface::BitcastFrom(const TensorInterface& from,
|
||||||
|
TF_DataType type, const int64_t* new_dims,
|
||||||
|
int num_new_dims) {
|
||||||
tensorflow::TensorShape s;
|
tensorflow::TensorShape s;
|
||||||
for (int i = 0; i < num_new_dims; ++i) {
|
for (int i = 0; i < num_new_dims; ++i) {
|
||||||
s.AddDim(new_dims[i]);
|
s.AddDim(new_dims[i]);
|
||||||
}
|
}
|
||||||
Status cc_status(to->tensor.BitcastFrom(
|
return tensor_.BitcastFrom(from.tensor_,
|
||||||
from->tensor, static_cast<tensorflow::DataType>(type), s));
|
static_cast<tensorflow::DataType>(type), s);
|
||||||
Set_TF_Status_from_Status(status, cc_status);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
// --------------------------------------------------------------------------
|
||||||
|
void StringEncode(const char* src, size_t src_len, char* dst) {
|
||||||
|
dst = tensorflow::core::EncodeVarint64(dst, src_len);
|
||||||
|
memcpy(dst, src, src_len);
|
||||||
|
}
|
||||||
|
|
||||||
size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
|
size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
|
||||||
size_t dst_len, TF_Status* status) {
|
size_t dst_len, TF_Status* status) {
|
||||||
const size_t sz = TF_StringEncodedSize(src_len);
|
const size_t sz = TF_StringEncodedSize(src_len);
|
||||||
@ -185,8 +226,7 @@ size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
|
|||||||
src_len, "-byte string"));
|
src_len, "-byte string"));
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
dst = tensorflow::core::EncodeVarint64(dst, src_len);
|
StringEncode(src, src_len, dst);
|
||||||
memcpy(dst, src, src_len);
|
|
||||||
return sz;
|
return sz;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -245,13 +285,11 @@ static TF_Tensor* EmptyTensor(TF_DataType dtype,
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
// Non-static for testing.
|
// Non-static for testing.
|
||||||
TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
|
TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status) {
|
||||||
TF_Status* status) {
|
*status = tensorflow::Status::OK();
|
||||||
TF_SetStatus(status, TF_OK, "");
|
|
||||||
if (!src.IsInitialized()) {
|
if (!src.IsInitialized()) {
|
||||||
Set_TF_Status_from_Status(
|
*status = FailedPrecondition(
|
||||||
status, FailedPrecondition(
|
"attempt to use a tensor with an uninitialized value");
|
||||||
"attempt to use a tensor with an uninitialized value"));
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
if (src.NumElements() == 0) {
|
if (src.NumElements() == 0) {
|
||||||
@ -259,14 +297,13 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
|
|||||||
}
|
}
|
||||||
if (src.dtype() == tensorflow::DT_RESOURCE) {
|
if (src.dtype() == tensorflow::DT_RESOURCE) {
|
||||||
if (src.shape().dims() != 0) {
|
if (src.shape().dims() != 0) {
|
||||||
Set_TF_Status_from_Status(
|
*status = InvalidArgument(
|
||||||
status, InvalidArgument(
|
"Unexpected non-scalar DT_RESOURCE tensor seen (shape: ",
|
||||||
"Unexpected non-scalar DT_RESOURCE tensor seen (shape: ",
|
src.shape().DebugString(),
|
||||||
src.shape().DebugString(),
|
"). Please file a bug at "
|
||||||
"). Please file a bug at "
|
"https://github.com/tensorflow/tensorflow/issues/new, "
|
||||||
"https://github.com/tensorflow/tensorflow/issues/new, "
|
"ideally with a "
|
||||||
"ideally with a "
|
"short code snippet that reproduces this error.");
|
||||||
"short code snippet that reproduces this error."));
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
const string str =
|
const string str =
|
||||||
@ -276,12 +313,11 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
|
|||||||
return t;
|
return t;
|
||||||
}
|
}
|
||||||
if (src.dtype() != tensorflow::DT_STRING) {
|
if (src.dtype() != tensorflow::DT_STRING) {
|
||||||
auto* result = new TF_Tensor();
|
Tensor tensor;
|
||||||
if (!result->tensor.CopyFrom(src, src.shape())) {
|
if (!tensor.CopyFrom(src, src.shape())) {
|
||||||
delete result;
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
return result;
|
return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(tensor)};
|
||||||
}
|
}
|
||||||
// DT_STRING tensors require a copying since TF_Tensor.buffer expects a flatly
|
// DT_STRING tensors require a copying since TF_Tensor.buffer expects a flatly
|
||||||
// encoded sequence of strings.
|
// encoded sequence of strings.
|
||||||
@ -305,23 +341,15 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
|
|||||||
*offsets = (dst - data_start);
|
*offsets = (dst - data_start);
|
||||||
offsets++;
|
offsets++;
|
||||||
const string& s = srcarray(i);
|
const string& s = srcarray(i);
|
||||||
size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status);
|
const size_t consumed = TF_StringEncodedSize(s.size());
|
||||||
if (TF_GetCode(status) != TF_OK) {
|
StringEncode(s.data(), s.size(), dst);
|
||||||
Set_TF_Status_from_Status(
|
|
||||||
status,
|
|
||||||
InvalidArgument("invalid string tensor encoding (string #", i, " of ",
|
|
||||||
srcarray.size(), "): ", TF_Message(status)));
|
|
||||||
delete[] base;
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
dst += consumed;
|
dst += consumed;
|
||||||
dst_len -= consumed;
|
dst_len -= consumed;
|
||||||
}
|
}
|
||||||
if (dst != base + size) {
|
if (dst != base + size) {
|
||||||
Set_TF_Status_from_Status(
|
*status = InvalidArgument(
|
||||||
status, InvalidArgument(
|
"invalid string tensor encoding (decoded ", (dst - base),
|
||||||
"invalid string tensor encoding (decoded ", (dst - base),
|
" bytes, but the tensor is encoded in ", size, " bytes");
|
||||||
" bytes, but the tensor is encoded in ", size, " bytes"));
|
|
||||||
delete[] base;
|
delete[] base;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
@ -339,31 +367,35 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
|
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
|
||||||
if (src->tensor.dtype() == DT_RESOURCE) {
|
return static_cast<const tensorflow::TensorInterface*>(src->tensor.get())
|
||||||
if (src->tensor.dims() != 0) {
|
->ToTensor(dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status TensorInterface::ToTensor(Tensor* dst) const {
|
||||||
|
if (tensor_.dtype() == DT_RESOURCE) {
|
||||||
|
if (tensor_.dims() != 0) {
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
"Malformed TF_RESOURCE tensor: expected a scalar, got a tensor with "
|
"Malformed TF_RESOURCE tensor: expected a scalar, got a tensor with "
|
||||||
"shape ",
|
"shape ",
|
||||||
src->tensor.shape().DebugString());
|
tensor_.shape().DebugString());
|
||||||
}
|
}
|
||||||
*dst = Tensor(tensorflow::DT_RESOURCE, src->tensor.shape());
|
*dst = Tensor(tensorflow::DT_RESOURCE, tensor_.shape());
|
||||||
if (!dst->scalar<tensorflow::ResourceHandle>()().ParseFromString(
|
if (!dst->scalar<tensorflow::ResourceHandle>()().ParseFromString(
|
||||||
string(static_cast<const char*>(TF_TensorData(src)),
|
string(static_cast<const char*>(Data()), ByteSize()))) {
|
||||||
TF_TensorByteSize(src)))) {
|
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
"Malformed TF_RESOUCE tensor: unable to parse resource handle");
|
"Malformed TF_RESOURCE tensor: unable to parse resource handle");
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
if (src->tensor.dtype() != DT_STRING) {
|
if (tensor_.dtype() != DT_STRING) {
|
||||||
*dst = src->tensor;
|
*dst = tensor_;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
// TF_STRING tensors require copying since Tensor class expects a sequence of
|
// TF_STRING tensors require copying since Tensor class expects a sequence of
|
||||||
// string objects.
|
// string objects.
|
||||||
const tensorflow::int64 num_elements = src->tensor.NumElements();
|
const tensorflow::int64 num_elements = tensor_.NumElements();
|
||||||
const char* input = reinterpret_cast<const char*>(TF_TensorData(src));
|
const char* input = reinterpret_cast<const char*>(Data());
|
||||||
const size_t src_size = TF_TensorByteSize(src);
|
const size_t src_size = ByteSize();
|
||||||
if (static_cast<tensorflow::int64>(src_size / sizeof(tensorflow::uint64)) <
|
if (static_cast<tensorflow::int64>(src_size / sizeof(tensorflow::uint64)) <
|
||||||
num_elements) {
|
num_elements) {
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
@ -372,7 +404,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
|
|||||||
const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
|
const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
|
||||||
const char* limit = input + src_size;
|
const char* limit = input + src_size;
|
||||||
|
|
||||||
*dst = Tensor(src->tensor.dtype(), src->tensor.shape());
|
*dst = Tensor(tensor_.dtype(), tensor_.shape());
|
||||||
auto dstarray = dst->flat<tstring>();
|
auto dstarray = dst->flat<tstring>();
|
||||||
for (tensorflow::int64 i = 0; i < num_elements; ++i) {
|
for (tensorflow::int64 i = 0; i < num_elements; ++i) {
|
||||||
tensorflow::uint64 offset =
|
tensorflow::uint64 offset =
|
||||||
@ -391,8 +423,8 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool TensorInterface::IsAligned() const { return tensor_.IsAligned(); }
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
bool TF_TensorIsAligned(const TF_Tensor* tensor) {
|
bool TF_TensorIsAligned(const TF_Tensor* t) { return t->tensor->IsAligned(); }
|
||||||
return tensor->tensor.IsAligned();
|
|
||||||
}
|
|
||||||
|
@ -16,9 +16,12 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
|
#ifndef TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
|
||||||
#define TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
|
#define TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
#include "tensorflow/c/tf_datatype.h"
|
#include "tensorflow/c/tf_datatype.h"
|
||||||
#include "tensorflow/core/framework/allocation_description.pb.h"
|
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_interface.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
|
||||||
// Internal structures used by the C API. These are likely to change and should
|
// Internal structures used by the C API. These are likely to change and should
|
||||||
@ -28,7 +31,7 @@ limitations under the License.
|
|||||||
// passed to or returned from C functions *by pointer*. Otherwise, changes to
|
// passed to or returned from C functions *by pointer*. Otherwise, changes to
|
||||||
// its internal structure will break the C API's binary interface.
|
// its internal structure will break the C API's binary interface.
|
||||||
typedef struct TF_Tensor {
|
typedef struct TF_Tensor {
|
||||||
::tensorflow::Tensor tensor;
|
std::unique_ptr<AbstractTensorInterface> tensor;
|
||||||
} TF_Tensor;
|
} TF_Tensor;
|
||||||
|
|
||||||
class TF_ManagedBuffer : public tensorflow::TensorBuffer {
|
class TF_ManagedBuffer : public tensorflow::TensorBuffer {
|
||||||
@ -83,4 +86,5 @@ void* allocate_tensor(const char* operation, size_t len, Allocator* allocator);
|
|||||||
// a different Allocator as `arg`.
|
// a different Allocator as `arg`.
|
||||||
void deallocate_buffer(void* data, size_t len, void* arg);
|
void deallocate_buffer(void* data, size_t len, void* arg);
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
|
#endif // TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
|
||||||
|
@ -96,7 +96,7 @@ class SymbolicGradientBuilder {
|
|||||||
// Used to identify nodes at which to stop backprop.
|
// Used to identify nodes at which to stop backprop.
|
||||||
std::unordered_set<int> GetStopBackpropNodes(
|
std::unordered_set<int> GetStopBackpropNodes(
|
||||||
const std::vector<bool>& reachable_nodes,
|
const std::vector<bool>& reachable_nodes,
|
||||||
const std::unordered_set<int>& output_nodes);
|
const std::unordered_set<int>& output_nodes) const;
|
||||||
|
|
||||||
const Scope& scope_;
|
const Scope& scope_;
|
||||||
const ops::GradOpRegistry* registry_;
|
const ops::GradOpRegistry* registry_;
|
||||||
@ -190,7 +190,7 @@ std::vector<bool> SymbolicGradientBuilder::GetReachableNodes() {
|
|||||||
|
|
||||||
std::unordered_set<int> SymbolicGradientBuilder::GetStopBackpropNodes(
|
std::unordered_set<int> SymbolicGradientBuilder::GetStopBackpropNodes(
|
||||||
const std::vector<bool>& reachable_nodes,
|
const std::vector<bool>& reachable_nodes,
|
||||||
const std::unordered_set<int>& output_nodes) {
|
const std::unordered_set<int>& output_nodes) const {
|
||||||
// Output nodes that get transitively consumed by other `outputs_` are stored
|
// Output nodes that get transitively consumed by other `outputs_` are stored
|
||||||
// in `internal_outputs`.
|
// in `internal_outputs`.
|
||||||
std::unordered_set<int> internal_outputs;
|
std::unordered_set<int> internal_outputs;
|
||||||
@ -346,8 +346,8 @@ Status SymbolicGradientBuilder::SumGradients(const Output& src, Output* grad) {
|
|||||||
"Unable to find backprop list for node.id ", src.node()->name());
|
"Unable to find backprop list for node.id ", src.node()->name());
|
||||||
}
|
}
|
||||||
const auto& grads = iter->second;
|
const auto& grads = iter->second;
|
||||||
// Filter any backproped 'NoGradient' Outputs from 'grads' (if needed).
|
// Filter any backpropped 'NoGradient' Outputs from 'grads' (if needed).
|
||||||
// Return any valid backproped gradients that remain after filtering,
|
// Return any valid backpropped gradients that remain after filtering,
|
||||||
// or 'NoGradient' otherwise.
|
// or 'NoGradient' otherwise.
|
||||||
std::vector<Output> grads_to_keep;
|
std::vector<Output> grads_to_keep;
|
||||||
for (const Output& o : grads) {
|
for (const Output& o : grads) {
|
||||||
@ -519,7 +519,7 @@ Status SymbolicGradientBuilder::AddGradients() {
|
|||||||
// Backprop along the in edges.
|
// Backprop along the in edges.
|
||||||
// TODO(andydavis) Find cleaner way to map each grad output returned by
|
// TODO(andydavis) Find cleaner way to map each grad output returned by
|
||||||
// gradient function to the src node/output to which it should be
|
// gradient function to the src node/output to which it should be
|
||||||
// backproped. Maybe grad functions can return a vector of Output pairs to
|
// backpropped. Maybe grad functions can return a vector of Output pairs to
|
||||||
// make this association explicit.
|
// make this association explicit.
|
||||||
size_t dx_index = 0;
|
size_t dx_index = 0;
|
||||||
for (const Edge* e : n->in_edges()) {
|
for (const Edge* e : n->in_edges()) {
|
||||||
|
@ -64,7 +64,7 @@ bool IsZero(const Scope& scope, const Output& grad) {
|
|||||||
// Multiply after broadcasting vec to match dimensions of mat.
|
// Multiply after broadcasting vec to match dimensions of mat.
|
||||||
// Args:
|
// Args:
|
||||||
// vec: A 1-D tensor of dimension [D0]
|
// vec: A 1-D tensor of dimension [D0]
|
||||||
// mat: A 2-D tensor of dimesnion [D0, D1]
|
// mat: A 2-D tensor of dimension [D0, D1]
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// A tensor of dimension [D0, D1], the result fo vec * mat.
|
// A tensor of dimension [D0, D1], the result fo vec * mat.
|
||||||
|
@ -259,6 +259,9 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) {
|
|||||||
RunTest(x, x_init_value, y, y_shape);
|
RunTest(x, x_init_value, y, y_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(rocm):
|
||||||
|
// Re-enable this test once 3D pooling is supported on ROCm platform
|
||||||
|
#ifndef TENSORFLOW_USE_ROCM
|
||||||
TEST_F(NNGradTest, MaxPool3DGradHelper) {
|
TEST_F(NNGradTest, MaxPool3DGradHelper) {
|
||||||
TensorShape x_shape({1, 3, 3, 3, 1});
|
TensorShape x_shape({1, 3, 3, 3, 1});
|
||||||
TensorShape y_shape({1, 1, 1, 1, 1});
|
TensorShape y_shape({1, 1, 1, 1, 1});
|
||||||
@ -271,6 +274,7 @@ TEST_F(NNGradTest, MaxPool3DGradHelper) {
|
|||||||
SetRandomValuesForMaxPooling<float>(&x_init_value);
|
SetRandomValuesForMaxPooling<float>(&x_init_value);
|
||||||
RunTest(x, x_init_value, y, y_shape);
|
RunTest(x, x_init_value, y, y_shape);
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
TEST_F(NNGradTest, AvgPoolGradHelper) {
|
TEST_F(NNGradTest, AvgPoolGradHelper) {
|
||||||
TensorShape x_shape({1, 2, 2, 1});
|
TensorShape x_shape({1, 2, 2, 1});
|
||||||
@ -283,6 +287,9 @@ TEST_F(NNGradTest, AvgPoolGradHelper) {
|
|||||||
RunTest(x, x_shape, y, y_shape);
|
RunTest(x, x_shape, y, y_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(rocm):
|
||||||
|
// Re-enable this test once 3D pooling is supported on ROCm platform
|
||||||
|
#ifndef TENSORFLOW_USE_ROCM
|
||||||
TEST_F(NNGradTest, AvgPool3DGradHelper) {
|
TEST_F(NNGradTest, AvgPool3DGradHelper) {
|
||||||
TensorShape x_shape({1, 3, 3, 3, 1});
|
TensorShape x_shape({1, 3, 3, 3, 1});
|
||||||
TensorShape y_shape({1, 1, 1, 1, 1});
|
TensorShape y_shape({1, 1, 1, 1, 1});
|
||||||
@ -293,6 +300,7 @@ TEST_F(NNGradTest, AvgPool3DGradHelper) {
|
|||||||
auto y = AvgPool3D(scope_, x, ksize, strides, "SAME");
|
auto y = AvgPool3D(scope_, x, ksize, strides, "SAME");
|
||||||
RunTest(x, x_shape, y, y_shape);
|
RunTest(x, x_shape, y, y_shape);
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
TEST_F(NNGradTest, LRN) {
|
TEST_F(NNGradTest, LRN) {
|
||||||
TensorShape x_shape({1, 1, 2, 1});
|
TensorShape x_shape({1, 1, 2, 1});
|
||||||
|
@ -124,13 +124,12 @@ cc_library(
|
|||||||
hdrs = ["bundle_v2.h"],
|
hdrs = ["bundle_v2.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":constants",
|
":constants",
|
||||||
"@com_google_absl//absl/container:flat_hash_set",
|
|
||||||
] + if_not_mobile([
|
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/platform:strcat",
|
"//tensorflow/core/platform:strcat",
|
||||||
"//tensorflow/core/util/tensor_bundle",
|
"//tensorflow/core/util/tensor_bundle",
|
||||||
]),
|
"@com_google_absl//absl/container:flat_hash_set",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_cc_test(
|
tf_cc_test(
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
|
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
|
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
|
||||||
|
load("//tensorflow/core/platform:build_config.bzl", "if_llvm_aarch64_available")
|
||||||
|
|
||||||
package(
|
package(
|
||||||
default_visibility = ["//visibility:private"],
|
default_visibility = ["//visibility:private"],
|
||||||
@ -27,9 +28,14 @@ cc_library(
|
|||||||
"compile.h",
|
"compile.h",
|
||||||
"flags.h",
|
"flags.h",
|
||||||
],
|
],
|
||||||
|
defines = if_llvm_aarch64_available(["TF_LLVM_AARCH64_AVAILABLE=1"]),
|
||||||
|
visibility = ["//tensorflow/python:__pkg__"],
|
||||||
deps = [
|
deps = [
|
||||||
":aot_only_var_handle_op",
|
":aot_only_var_handle_op",
|
||||||
":embedded_protocol_buffers",
|
":embedded_protocol_buffers",
|
||||||
|
"@com_google_absl//absl/memory",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
"@com_google_absl//absl/types:span",
|
||||||
"//tensorflow/compiler/tf2xla",
|
"//tensorflow/compiler/tf2xla",
|
||||||
"//tensorflow/compiler/tf2xla:mlir_tf2xla",
|
"//tensorflow/compiler/tf2xla:mlir_tf2xla",
|
||||||
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
|
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
|
||||||
@ -53,10 +59,13 @@ cc_library(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"@com_google_absl//absl/memory",
|
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
|
||||||
"@com_google_absl//absl/strings",
|
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
|
||||||
"@com_google_absl//absl/types:span",
|
"@llvm-project//llvm:target",
|
||||||
],
|
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
|
||||||
|
] + if_llvm_aarch64_available([
|
||||||
|
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
|
||||||
|
]),
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_cc_test(
|
tf_cc_test(
|
||||||
@ -86,6 +95,19 @@ tf_cc_binary(
|
|||||||
deps = [":tfcompile_main"],
|
deps = [":tfcompile_main"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "llvm_targets",
|
||||||
|
visibility = ["//tensorflow/python:__pkg__"],
|
||||||
|
deps = [
|
||||||
|
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
|
||||||
|
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
|
||||||
|
"@llvm-project//llvm:target",
|
||||||
|
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
|
||||||
|
] + if_llvm_aarch64_available([
|
||||||
|
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
|
||||||
|
]),
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tfcompile_main",
|
name = "tfcompile_main",
|
||||||
srcs = ["tfcompile_main.cc"],
|
srcs = ["tfcompile_main.cc"],
|
||||||
@ -104,11 +126,6 @@ cc_library(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@llvm-project//llvm:aarch64_code_gen", # fixdeps: keep
|
|
||||||
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
|
|
||||||
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
|
|
||||||
"@llvm-project//llvm:target",
|
|
||||||
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -214,8 +231,13 @@ cc_library(
|
|||||||
cc_library(
|
cc_library(
|
||||||
name = "aot_only_var_handle_op",
|
name = "aot_only_var_handle_op",
|
||||||
srcs = ["aot_only_var_handle_op.cc"],
|
srcs = ["aot_only_var_handle_op.cc"],
|
||||||
|
hdrs = ["aot_only_var_handle_op.h"],
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow/compiler/tf2xla:__pkg__",
|
||||||
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
@ -13,9 +13,12 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/aot/aot_only_var_handle_op.h"
|
||||||
|
|
||||||
#include "tensorflow/compiler/tf2xla/xla_context.h"
|
#include "tensorflow/compiler/tf2xla/xla_context.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||||
|
#include "tensorflow/core/framework/shape_inference.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace {
|
namespace {
|
||||||
@ -51,6 +54,31 @@ void XlaAotOnlyVarHandleOp::Compile(XlaOpKernelContext* context) {
|
|||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
REGISTER_XLA_OP(Name("VarHandleOp").CompilationOnly(), XlaAotOnlyVarHandleOp);
|
REGISTER_OP(tfcompile::kXlaAotOnlyVarHandleOp)
|
||||||
|
.Doc(R"doc(
|
||||||
|
Internal VarHandleOp registration used for XLA AOT compilation.
|
||||||
|
)doc")
|
||||||
|
.Attr("container: string = ''")
|
||||||
|
.Attr("shared_name: string = ''")
|
||||||
|
.Attr("dtype: type")
|
||||||
|
.Attr("shape: shape")
|
||||||
|
.Output("resource: resource")
|
||||||
|
.SetIsStateful()
|
||||||
|
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||||
|
c->set_output(0, c->Scalar());
|
||||||
|
DataType t;
|
||||||
|
TF_RETURN_IF_ERROR(c->GetAttr("dtype", &t));
|
||||||
|
PartialTensorShape p;
|
||||||
|
TF_RETURN_IF_ERROR(c->GetAttr("shape", &p));
|
||||||
|
shape_inference::ShapeHandle s;
|
||||||
|
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s));
|
||||||
|
c->set_output_handle_shapes_and_types(
|
||||||
|
0, std::vector<shape_inference::ShapeAndType>{{s, t}});
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
});
|
||||||
|
|
||||||
|
REGISTER_XLA_OP(Name(tfcompile::kXlaAotOnlyVarHandleOp).CompilationOnly(),
|
||||||
|
XlaAotOnlyVarHandleOp);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
27
tensorflow/compiler/aot/aot_only_var_handle_op.h
Normal file
27
tensorflow/compiler/aot/aot_only_var_handle_op.h
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_
|
||||||
|
#define TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace tfcompile {
|
||||||
|
|
||||||
|
static constexpr const char* const kXlaAotOnlyVarHandleOp =
|
||||||
|
"_XlaAotOnlyVarHandleOp";
|
||||||
|
|
||||||
|
} // namespace tfcompile
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_
|
@ -74,16 +74,16 @@ void DumpStatsToStdout(const Stats& stats) {
|
|||||||
const int kBufSize = 1000;
|
const int kBufSize = 1000;
|
||||||
char buf[kBufSize];
|
char buf[kBufSize];
|
||||||
snprintf(buf, kBufSize, "Mean with %2.0f%% trimmed:", trim_ratio * 100);
|
snprintf(buf, kBufSize, "Mean with %2.0f%% trimmed:", trim_ratio * 100);
|
||||||
const string label_trimmed(buf);
|
std::string label_trimmed(buf);
|
||||||
snprintf(buf, kBufSize, "Mean of %2.0f%% best:", best_ratio * 100);
|
snprintf(buf, kBufSize, "Mean of %2.0f%% best:", best_ratio * 100);
|
||||||
const string label_best(buf);
|
std::string label_best(buf);
|
||||||
std::vector<std::pair<string, double>> groups = {
|
std::vector<std::pair<std::string, double>> groups = {
|
||||||
{"Best:", sorted_us.front()},
|
{"Best:", sorted_us.front()},
|
||||||
{"Worst:", sorted_us.back()},
|
{"Worst:", sorted_us.back()},
|
||||||
{"Median:", sorted_us[count_us / 2]},
|
{"Median:", sorted_us[count_us / 2]},
|
||||||
{"Mean:", sum_us / count_us},
|
{"Mean:", sum_us / count_us},
|
||||||
{label_trimmed, sum_us_trimmed / count_us_trimmed},
|
{std::move(label_trimmed), sum_us_trimmed / count_us_trimmed},
|
||||||
{label_best, sum_us_best / count_us_best},
|
{std::move(label_best), sum_us_best / count_us_best},
|
||||||
};
|
};
|
||||||
int max_label_size = 0;
|
int max_label_size = 0;
|
||||||
double max_us = 0;
|
double max_us = 0;
|
||||||
@ -102,7 +102,7 @@ void DumpStatsToStdout(const Stats& stats) {
|
|||||||
}
|
}
|
||||||
// Dump stats out.
|
// Dump stats out.
|
||||||
printf("Benchmark ran %zu iterations over %lld us\n", count_us,
|
printf("Benchmark ran %zu iterations over %lld us\n", count_us,
|
||||||
stats.total_us);
|
static_cast<long long>(stats.total_us)); // NOLINT
|
||||||
for (const auto& g : groups) {
|
for (const auto& g : groups) {
|
||||||
printf(" %-*s %*.3f us\n", max_label_size, g.first.c_str(), max_digits + 4,
|
printf(" %-*s %*.3f us\n", max_label_size, g.first.c_str(), max_digits + 4,
|
||||||
g.second);
|
g.second);
|
||||||
@ -114,7 +114,8 @@ void Benchmark(const Options& options, const BenchmarkFn& fn, Stats* stats) {
|
|||||||
const int64 max_us = (options.max_micros <= 0 && options.max_iters <= 0)
|
const int64 max_us = (options.max_micros <= 0 && options.max_iters <= 0)
|
||||||
? Options::kDefaultMicros
|
? Options::kDefaultMicros
|
||||||
: options.max_micros;
|
: options.max_micros;
|
||||||
printf("Running benchmark for %lld us\n", max_us);
|
// NOLINTNEXTLINE
|
||||||
|
printf("Running benchmark for %lld us\n", static_cast<long long>(max_us));
|
||||||
const int64 start_us = NowMicros();
|
const int64 start_us = NowMicros();
|
||||||
int64 iters = 0;
|
int64 iters = 0;
|
||||||
while (true) {
|
while (true) {
|
||||||
|
@ -423,8 +423,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
|
|||||||
GenNameToIndexCode(config.fetch(), opts.gen_name_to_index);
|
GenNameToIndexCode(config.fetch(), opts.gen_name_to_index);
|
||||||
const string include_xla_data_proto =
|
const string include_xla_data_proto =
|
||||||
opts.gen_program_shape
|
opts.gen_program_shape
|
||||||
?
|
? R"(#include "tensorflow/compiler/xla/xla_data.pb.h")"
|
||||||
R"(#include "tensorflow/compiler/xla/xla_data.pb.h")"
|
|
||||||
: "";
|
: "";
|
||||||
|
|
||||||
const string include_hlo_profile_printer_data_proto =
|
const string include_hlo_profile_printer_data_proto =
|
||||||
|
@ -20,6 +20,8 @@ limitations under the License.
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "llvm-c/Target.h"
|
||||||
|
#include "tensorflow/compiler/aot/codegen.h"
|
||||||
#include "tensorflow/compiler/aot/flags.h"
|
#include "tensorflow/compiler/aot/flags.h"
|
||||||
#include "tensorflow/compiler/tf2xla/tf2xla.h"
|
#include "tensorflow/compiler/tf2xla/tf2xla.h"
|
||||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||||
@ -90,7 +92,7 @@ Status CompileXla(xla::CompileOnlyClient* client,
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
|
Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
|
||||||
const MainFlags& flags, CompileResult* compile_result) {
|
const MainFlags& flags, CompileResult* compile_result) {
|
||||||
// Converts the graph into an XLA computation, and compiles the
|
// Converts the graph into an XLA computation, and compiles the
|
||||||
// computation.
|
// computation.
|
||||||
@ -108,8 +110,8 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
|
|||||||
if (!flags.mlir_components.empty()) {
|
if (!flags.mlir_components.empty()) {
|
||||||
return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
|
return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
|
||||||
}
|
}
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(ConvertGraphDefToXla(std::move(graph_def), config,
|
||||||
ConvertGraphDefToXla(graph_def, config, client, &computation));
|
client, &computation));
|
||||||
}
|
}
|
||||||
if (!flags.out_session_module.empty()) {
|
if (!flags.out_session_module.empty()) {
|
||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
|
||||||
@ -132,5 +134,96 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
|
|||||||
return CompileXla(client, computation, aot_opts, compile_result);
|
return CompileXla(client, computation, aot_opts, compile_result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
|
||||||
|
if (absl::EndsWith(fname, ".pbtxt")) {
|
||||||
|
return ReadTextProto(Env::Default(), fname, proto);
|
||||||
|
} else {
|
||||||
|
return ReadBinaryProto(Env::Default(), fname, proto);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::once_flag targets_init;
|
||||||
|
|
||||||
|
static void InitializeTargets() {
|
||||||
|
// Initialize all LLVM targets so we can cross compile.
|
||||||
|
#if TF_LLVM_AARCH64_AVAILABLE
|
||||||
|
LLVMInitializeAArch64Target();
|
||||||
|
LLVMInitializeAArch64TargetInfo();
|
||||||
|
LLVMInitializeAArch64TargetMC();
|
||||||
|
LLVMInitializeAArch64AsmPrinter();
|
||||||
|
#endif
|
||||||
|
LLVMInitializeARMTarget();
|
||||||
|
LLVMInitializeARMTargetInfo();
|
||||||
|
LLVMInitializeARMTargetMC();
|
||||||
|
LLVMInitializeARMAsmPrinter();
|
||||||
|
LLVMInitializePowerPCTarget();
|
||||||
|
LLVMInitializePowerPCTargetInfo();
|
||||||
|
LLVMInitializePowerPCTargetMC();
|
||||||
|
LLVMInitializePowerPCAsmPrinter();
|
||||||
|
LLVMInitializeX86Target();
|
||||||
|
LLVMInitializeX86TargetInfo();
|
||||||
|
LLVMInitializeX86TargetMC();
|
||||||
|
LLVMInitializeX86AsmPrinter();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Main(const MainFlags& flags) {
|
||||||
|
std::call_once(targets_init, &InitializeTargets);
|
||||||
|
|
||||||
|
// Process config.
|
||||||
|
tf2xla::Config config;
|
||||||
|
if (flags.config.empty()) {
|
||||||
|
return errors::InvalidArgument("Must specify --config");
|
||||||
|
}
|
||||||
|
TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
|
||||||
|
TF_RETURN_IF_ERROR(ValidateConfig(config));
|
||||||
|
if (flags.dump_fetch_nodes) {
|
||||||
|
std::set<string> nodes;
|
||||||
|
for (const tf2xla::Fetch& fetch : config.fetch()) {
|
||||||
|
nodes.insert(fetch.id().node_name());
|
||||||
|
}
|
||||||
|
std::cout << absl::StrJoin(nodes, ",");
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read and initialize the graph.
|
||||||
|
if (flags.graph.empty()) {
|
||||||
|
return errors::InvalidArgument("Must specify --graph");
|
||||||
|
}
|
||||||
|
GraphDef graph_def;
|
||||||
|
TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
|
||||||
|
CompileResult compile_result;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
CompileGraph(std::move(graph_def), config, flags, &compile_result));
|
||||||
|
|
||||||
|
// Write output files.
|
||||||
|
Env* env = Env::Default();
|
||||||
|
const std::vector<char>& obj = compile_result.aot->object_file_data();
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
WriteStringToFile(env, flags.out_function_object,
|
||||||
|
absl::string_view(obj.data(), obj.size())));
|
||||||
|
CodegenOpts codegen_opts;
|
||||||
|
codegen_opts.gen_name_to_index = flags.gen_name_to_index;
|
||||||
|
codegen_opts.gen_program_shape = flags.gen_program_shape;
|
||||||
|
codegen_opts.target_triple = flags.target_triple;
|
||||||
|
if (flags.cpp_class.empty()) {
|
||||||
|
return errors::InvalidArgument("Must specify --cpp_class");
|
||||||
|
}
|
||||||
|
codegen_opts.gen_hlo_profile_printer_data =
|
||||||
|
xla::GetDebugOptionsFromFlags().xla_hlo_profile();
|
||||||
|
TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
|
||||||
|
&codegen_opts.namespaces));
|
||||||
|
|
||||||
|
MetadataResult metadata_result;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
GenerateMetadata(codegen_opts, compile_result, &metadata_result));
|
||||||
|
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_metadata_object,
|
||||||
|
metadata_result.object_file_data));
|
||||||
|
string header;
|
||||||
|
TF_RETURN_IF_ERROR(GenerateHeader(codegen_opts, config, compile_result,
|
||||||
|
metadata_result, &header));
|
||||||
|
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace tfcompile
|
} // namespace tfcompile
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -42,9 +42,12 @@ struct CompileResult {
|
|||||||
// that performs the graph operations.
|
// that performs the graph operations.
|
||||||
//
|
//
|
||||||
// The XLA compilation options are specified in the flags.
|
// The XLA compilation options are specified in the flags.
|
||||||
Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
|
Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
|
||||||
const MainFlags& flags, CompileResult* compile_result);
|
const MainFlags& flags, CompileResult* compile_result);
|
||||||
|
|
||||||
|
// The full compilation method, for reuse in a library setting.
|
||||||
|
Status Main(const MainFlags& flags);
|
||||||
|
|
||||||
} // namespace tfcompile
|
} // namespace tfcompile
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
@ -25,6 +25,7 @@ namespace tensorflow {
|
|||||||
namespace tfcompile {
|
namespace tfcompile {
|
||||||
|
|
||||||
// Flags for the tfcompile binary. See *.cc file for descriptions.
|
// Flags for the tfcompile binary. See *.cc file for descriptions.
|
||||||
|
|
||||||
struct MainFlags {
|
struct MainFlags {
|
||||||
string graph;
|
string graph;
|
||||||
string config;
|
string config;
|
||||||
|
@ -25,6 +25,7 @@ test_suite(
|
|||||||
":test_graph_tfmatmulandadd_test",
|
":test_graph_tfmatmulandadd_test",
|
||||||
":test_graph_tfsplits_test",
|
":test_graph_tfsplits_test",
|
||||||
":test_graph_tftop_k_test",
|
":test_graph_tftop_k_test",
|
||||||
|
":test_graph_tfvariable_readonly_test",
|
||||||
":test_graph_tfvariable_sequential_updates_test",
|
":test_graph_tfvariable_sequential_updates_test",
|
||||||
":test_graph_tfvariable_test",
|
":test_graph_tfvariable_test",
|
||||||
":tfcompile_test",
|
":tfcompile_test",
|
||||||
@ -73,6 +74,7 @@ genrule(
|
|||||||
"test_graph_tfsplits.pb",
|
"test_graph_tfsplits.pb",
|
||||||
"test_graph_tftop_k.pb",
|
"test_graph_tftop_k.pb",
|
||||||
"test_graph_tfvariable.pb",
|
"test_graph_tfvariable.pb",
|
||||||
|
"test_graph_tfvariable_readonly.pb",
|
||||||
"test_graph_tfvariable_sequential_updates.pb",
|
"test_graph_tfvariable_sequential_updates.pb",
|
||||||
],
|
],
|
||||||
# Set CUDA_VISIBLE_DEVICES='' to prevent the code we launch from using any
|
# Set CUDA_VISIBLE_DEVICES='' to prevent the code we launch from using any
|
||||||
@ -238,6 +240,17 @@ tf_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_library(
|
||||||
|
name = "test_graph_tfvariable_readonly",
|
||||||
|
testonly = 1,
|
||||||
|
config = "test_graph_tfvariable_readonly.config.pbtxt",
|
||||||
|
cpp_class = "VariableReadonlyComp",
|
||||||
|
graph = "test_graph_tfvariable_readonly.pb",
|
||||||
|
tags = [
|
||||||
|
"manual",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_library(
|
tf_library(
|
||||||
name = "test_graph_tfvariable_sequential_updates",
|
name = "test_graph_tfvariable_sequential_updates",
|
||||||
testonly = 1,
|
testonly = 1,
|
||||||
@ -269,6 +282,7 @@ tf_cc_test(
|
|||||||
":test_graph_tfsplits",
|
":test_graph_tfsplits",
|
||||||
":test_graph_tftop_k",
|
":test_graph_tftop_k",
|
||||||
":test_graph_tfvariable",
|
":test_graph_tfvariable",
|
||||||
|
":test_graph_tfvariable_readonly",
|
||||||
":test_graph_tfvariable_sequential_updates",
|
":test_graph_tfvariable_sequential_updates",
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla:test",
|
"//tensorflow/compiler/xla:test",
|
||||||
@ -323,6 +337,42 @@ tf_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_library(
|
||||||
|
name = "test_graph_tfcond_mlir_bridge",
|
||||||
|
testonly = 1,
|
||||||
|
config = "test_graph_tfcond.config.pbtxt",
|
||||||
|
cpp_class = "CondComp",
|
||||||
|
graph = "test_graph_tfcond.pb",
|
||||||
|
mlir_components = "Bridge",
|
||||||
|
tags = [
|
||||||
|
"manual",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_library(
|
||||||
|
name = "test_graph_tfassert_eq_mlir_bridge",
|
||||||
|
testonly = 1,
|
||||||
|
config = "test_graph_tfassert_eq.config.pbtxt",
|
||||||
|
cpp_class = "AssertComp",
|
||||||
|
graph = "test_graph_tfassert_eq.pb",
|
||||||
|
mlir_components = "Bridge",
|
||||||
|
tags = [
|
||||||
|
"manual",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_library(
|
||||||
|
name = "test_graph_tfgather_mlir_bridge",
|
||||||
|
testonly = 1,
|
||||||
|
config = "test_graph_tfgather.config.pbtxt",
|
||||||
|
cpp_class = "GatherComp",
|
||||||
|
graph = "test_graph_tfgather.pb",
|
||||||
|
mlir_components = "Bridge",
|
||||||
|
tags = [
|
||||||
|
"manual",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_library(
|
tf_library(
|
||||||
name = "test_graph_tfmatmul_mlir_bridge",
|
name = "test_graph_tfmatmul_mlir_bridge",
|
||||||
testonly = 1,
|
testonly = 1,
|
||||||
@ -361,6 +411,42 @@ tf_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_library(
|
||||||
|
name = "test_graph_tfsplits_mlir_bridge",
|
||||||
|
testonly = 1,
|
||||||
|
config = "test_graph_tfsplits.config.pbtxt",
|
||||||
|
cpp_class = "SplitsComp",
|
||||||
|
graph = "test_graph_tfsplits.pb",
|
||||||
|
mlir_components = "Bridge",
|
||||||
|
tags = [
|
||||||
|
"manual",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_library(
|
||||||
|
name = "test_graph_tftop_k_mlir_bridge",
|
||||||
|
testonly = 1,
|
||||||
|
config = "test_graph_tftop_k.config.pbtxt",
|
||||||
|
cpp_class = "TopKComp",
|
||||||
|
graph = "test_graph_tftop_k.pb",
|
||||||
|
mlir_components = "Bridge",
|
||||||
|
tags = [
|
||||||
|
"manual",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_library(
|
||||||
|
name = "test_graph_tfvariable_readonly_mlir_bridge",
|
||||||
|
testonly = 1,
|
||||||
|
config = "test_graph_tfvariable_readonly.config.pbtxt",
|
||||||
|
cpp_class = "VariableReadonlyComp",
|
||||||
|
graph = "test_graph_tfvariable_readonly.pb",
|
||||||
|
mlir_components = "Bridge",
|
||||||
|
tags = [
|
||||||
|
"manual",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_cc_test(
|
tf_cc_test(
|
||||||
name = "tfcompile_test_mlir_bridge",
|
name = "tfcompile_test_mlir_bridge",
|
||||||
srcs = ["tfcompile_test.cc"],
|
srcs = ["tfcompile_test.cc"],
|
||||||
@ -372,9 +458,15 @@ tf_cc_test(
|
|||||||
":test_graph_tfadd_mlir_bridge",
|
":test_graph_tfadd_mlir_bridge",
|
||||||
":test_graph_tfadd_with_ckpt_mlir_bridge",
|
":test_graph_tfadd_with_ckpt_mlir_bridge",
|
||||||
":test_graph_tfadd_with_ckpt_saver_mlir_bridge",
|
":test_graph_tfadd_with_ckpt_saver_mlir_bridge",
|
||||||
|
":test_graph_tfassert_eq_mlir_bridge",
|
||||||
|
":test_graph_tfcond_mlir_bridge",
|
||||||
|
":test_graph_tfgather_mlir_bridge",
|
||||||
":test_graph_tfmatmul_mlir_bridge",
|
":test_graph_tfmatmul_mlir_bridge",
|
||||||
":test_graph_tfmatmulandadd_mlir_bridge",
|
":test_graph_tfmatmulandadd_mlir_bridge",
|
||||||
":test_graph_tfmatmulandadd_with_profiling_mlir_bridge",
|
":test_graph_tfmatmulandadd_with_profiling_mlir_bridge",
|
||||||
|
":test_graph_tfsplits_mlir_bridge",
|
||||||
|
":test_graph_tftop_k_mlir_bridge",
|
||||||
|
":test_graph_tfvariable_readonly_mlir_bridge",
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla:test",
|
"//tensorflow/compiler/xla:test",
|
||||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||||
|
@ -34,6 +34,7 @@ from tensorflow.python.framework import function
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
|
from tensorflow.python.ops import control_flow_util
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import nn_ops
|
from tensorflow.python.ops import nn_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
@ -153,6 +154,14 @@ def tftop_k(_):
|
|||||||
array_ops.identity(output[1], name='indices')
|
array_ops.identity(output[1], name='indices')
|
||||||
|
|
||||||
|
|
||||||
|
def tfvariable_readonly(_):
|
||||||
|
x = variables.Variable(1000.0, name='x')
|
||||||
|
old_x = x.value()
|
||||||
|
with ops.control_dependencies([old_x]):
|
||||||
|
new_value = math_ops.add(old_x, 42.0)
|
||||||
|
array_ops.identity(new_value, name='result')
|
||||||
|
|
||||||
|
|
||||||
def tfvariable(_):
|
def tfvariable(_):
|
||||||
x = variables.Variable(1000.0, name='x')
|
x = variables.Variable(1000.0, name='x')
|
||||||
old_x = x.value()
|
old_x = x.value()
|
||||||
@ -184,6 +193,7 @@ def write_graph(build_graph, out_dir):
|
|||||||
|
|
||||||
|
|
||||||
def main(_):
|
def main(_):
|
||||||
|
control_flow_util.enable_control_flow_v2()
|
||||||
write_graph(tfadd, FLAGS.out_dir)
|
write_graph(tfadd, FLAGS.out_dir)
|
||||||
write_graph(tfadd_with_ckpt, FLAGS.out_dir)
|
write_graph(tfadd_with_ckpt, FLAGS.out_dir)
|
||||||
write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir)
|
write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir)
|
||||||
@ -196,6 +206,7 @@ def main(_):
|
|||||||
write_graph(tfsplits, FLAGS.out_dir)
|
write_graph(tfsplits, FLAGS.out_dir)
|
||||||
write_graph(tftop_k, FLAGS.out_dir)
|
write_graph(tftop_k, FLAGS.out_dir)
|
||||||
write_graph(tfvariable, FLAGS.out_dir)
|
write_graph(tfvariable, FLAGS.out_dir)
|
||||||
|
write_graph(tfvariable_readonly, FLAGS.out_dir)
|
||||||
write_graph(tfvariable_sequential_updates, FLAGS.out_dir)
|
write_graph(tfvariable_sequential_updates, FLAGS.out_dir)
|
||||||
|
|
||||||
|
|
||||||
|
@ -0,0 +1,12 @@
|
|||||||
|
# Text form of tensorflow.tf2xla.Config proto.
|
||||||
|
fetch {
|
||||||
|
id { node_name: "result" }
|
||||||
|
}
|
||||||
|
|
||||||
|
variable {
|
||||||
|
node_name: "x"
|
||||||
|
shape {
|
||||||
|
}
|
||||||
|
type: DT_FLOAT
|
||||||
|
readonly: true
|
||||||
|
}
|
@ -30,9 +30,15 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_mlir_bridge.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_mlir_bridge.h"
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_mlir_bridge.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_mlir_bridge.h"
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver_mlir_bridge.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver_mlir_bridge.h"
|
||||||
|
#include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq_mlir_bridge.h"
|
||||||
|
#include "tensorflow/compiler/aot/tests/test_graph_tfcond_mlir_bridge.h"
|
||||||
|
#include "tensorflow/compiler/aot/tests/test_graph_tfgather_mlir_bridge.h"
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul_mlir_bridge.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul_mlir_bridge.h"
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_mlir_bridge.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_mlir_bridge.h"
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling_mlir_bridge.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling_mlir_bridge.h"
|
||||||
|
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits_mlir_bridge.h"
|
||||||
|
#include "tensorflow/compiler/aot/tests/test_graph_tftop_k_mlir_bridge.h"
|
||||||
|
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly_mlir_bridge.h"
|
||||||
#else
|
#else
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
|
||||||
@ -47,6 +53,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h"
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tftop_k.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tftop_k.h"
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable.h"
|
||||||
|
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly.h"
|
||||||
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates.h"
|
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@ -167,8 +174,6 @@ TEST(TFCompileTest, AddWithCkptSaver) {
|
|||||||
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
|
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(bixia): the following tests failed with MLIR bridge.
|
|
||||||
#if !defined(ENABLE_MLIR_BRIDGE_TEST)
|
|
||||||
TEST(TFCompileTest, Cond) {
|
TEST(TFCompileTest, Cond) {
|
||||||
CondComp cond;
|
CondComp cond;
|
||||||
EXPECT_EQ(cond.arg0_data(), cond.arg_data(0));
|
EXPECT_EQ(cond.arg0_data(), cond.arg_data(0));
|
||||||
@ -233,7 +238,6 @@ TEST(TFCompileTest, Gather) {
|
|||||||
EXPECT_EQ(gather_const.result0_data(), gather.results()[0]);
|
EXPECT_EQ(gather_const.result0_data(), gather.results()[0]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
|
|
||||||
TEST(TFCompileTest, MatMul2) {
|
TEST(TFCompileTest, MatMul2) {
|
||||||
Eigen::ThreadPool tp(2);
|
Eigen::ThreadPool tp(2);
|
||||||
@ -439,6 +443,7 @@ TEST(TFCompileTest, Function) {
|
|||||||
EXPECT_EQ(add_fn.result0_data()[0], 3);
|
EXPECT_EQ(add_fn.result0_data()[0], 3);
|
||||||
EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]);
|
EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]);
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
TEST(TFCompileTest, Splits) {
|
TEST(TFCompileTest, Splits) {
|
||||||
Eigen::ThreadPool tp(1);
|
Eigen::ThreadPool tp(1);
|
||||||
@ -492,6 +497,22 @@ TEST(TFCompileTest, TopK) {
|
|||||||
EXPECT_EQ(expected_indices[1], fn.result1(1));
|
EXPECT_EQ(expected_indices[1], fn.result1(1));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(TFCompileTest, VariableReadonly) {
|
||||||
|
Eigen::ThreadPool tp(1);
|
||||||
|
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
|
||||||
|
|
||||||
|
VariableReadonlyComp fn;
|
||||||
|
float x = 23;
|
||||||
|
fn.set_var_x_data(&x);
|
||||||
|
|
||||||
|
fn.set_thread_pool(&device);
|
||||||
|
fn.Run();
|
||||||
|
EXPECT_EQ(fn.result0(), 65);
|
||||||
|
EXPECT_EQ(fn.var_x(), 23);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(bixia): the following tests failed with MLIR bridge.
|
||||||
|
#if !defined(ENABLE_MLIR_BRIDGE_TEST)
|
||||||
TEST(TFCompileTest, Variable) {
|
TEST(TFCompileTest, Variable) {
|
||||||
Eigen::ThreadPool tp(1);
|
Eigen::ThreadPool tp(1);
|
||||||
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
|
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
|
||||||
@ -564,6 +585,7 @@ TEST(TFCompileTest, VariableSequentialUpdatesNoAlloc) {
|
|||||||
fn.Run();
|
fn.Run();
|
||||||
EXPECT_NEAR(x, 0.594322f, 1e-6);
|
EXPECT_NEAR(x, 0.594322f, 1e-6);
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
TEST(TFCompileTest, AssertEqAndReturnDiff) {
|
TEST(TFCompileTest, AssertEqAndReturnDiff) {
|
||||||
// Assert is converted into a no-op in XLA, so there is no failure even if the
|
// Assert is converted into a no-op in XLA, so there is no failure even if the
|
||||||
@ -665,6 +687,11 @@ TEST(TFCompileTest, HloProfiling) {
|
|||||||
/*clock_rate_ghz=*/1.0);
|
/*clock_rate_ghz=*/1.0);
|
||||||
VLOG(1) << "Original HLO profile string:\n" << hlo_profile_as_string;
|
VLOG(1) << "Original HLO profile string:\n" << hlo_profile_as_string;
|
||||||
|
|
||||||
|
// Replace Arg_n with argn when the MLIR bridge is used.
|
||||||
|
#if defined(ENABLE_MLIR_BRIDGE_TEST)
|
||||||
|
RE2::GlobalReplace(&hlo_profile_as_string, "(Arg_)([0-9].)", "arg\\2");
|
||||||
|
#endif
|
||||||
|
|
||||||
// Strip away identifier details from the profile string to avoid this test
|
// Strip away identifier details from the profile string to avoid this test
|
||||||
// being a change detector for xla internals. Identifiers such as '%dot.0.7'
|
// being a change detector for xla internals. Identifiers such as '%dot.0.7'
|
||||||
// just become '%dot'.
|
// just become '%dot'.
|
||||||
@ -690,7 +717,6 @@ TEST(TFCompileTest, HloProfiling) {
|
|||||||
IsSupersetOf({header, total_cycles_profile_line, dot_profile_line,
|
IsSupersetOf({header, total_cycles_profile_line, dot_profile_line,
|
||||||
add_profile_line, tuple_profile_line}));
|
add_profile_line, tuple_profile_line}));
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tfcompile
|
} // namespace tfcompile
|
||||||
|
@ -21,7 +21,6 @@ limitations under the License.
|
|||||||
#include "absl/strings/match.h"
|
#include "absl/strings/match.h"
|
||||||
#include "absl/strings/str_join.h"
|
#include "absl/strings/str_join.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "llvm-c/Target.h"
|
|
||||||
#include "tensorflow/compiler/aot/codegen.h"
|
#include "tensorflow/compiler/aot/codegen.h"
|
||||||
#include "tensorflow/compiler/aot/compile.h"
|
#include "tensorflow/compiler/aot/compile.h"
|
||||||
#include "tensorflow/compiler/aot/flags.h"
|
#include "tensorflow/compiler/aot/flags.h"
|
||||||
@ -56,88 +55,6 @@ const char kUsageHeader[] =
|
|||||||
"--cpp_class=\"mynamespace::MyComputation\"\n"
|
"--cpp_class=\"mynamespace::MyComputation\"\n"
|
||||||
"\n";
|
"\n";
|
||||||
|
|
||||||
Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
|
|
||||||
if (absl::EndsWith(fname, ".pbtxt")) {
|
|
||||||
return ReadTextProto(Env::Default(), fname, proto);
|
|
||||||
} else {
|
|
||||||
return ReadBinaryProto(Env::Default(), fname, proto);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Status Main(const MainFlags& flags) {
|
|
||||||
// Initialize all LLVM targets so we can cross compile.
|
|
||||||
LLVMInitializeAArch64Target();
|
|
||||||
LLVMInitializeAArch64TargetInfo();
|
|
||||||
LLVMInitializeAArch64TargetMC();
|
|
||||||
LLVMInitializeAArch64AsmPrinter();
|
|
||||||
LLVMInitializeARMTarget();
|
|
||||||
LLVMInitializeARMTargetInfo();
|
|
||||||
LLVMInitializeARMTargetMC();
|
|
||||||
LLVMInitializeARMAsmPrinter();
|
|
||||||
LLVMInitializePowerPCTarget();
|
|
||||||
LLVMInitializePowerPCTargetInfo();
|
|
||||||
LLVMInitializePowerPCTargetMC();
|
|
||||||
LLVMInitializePowerPCAsmPrinter();
|
|
||||||
LLVMInitializeX86Target();
|
|
||||||
LLVMInitializeX86TargetInfo();
|
|
||||||
LLVMInitializeX86TargetMC();
|
|
||||||
LLVMInitializeX86AsmPrinter();
|
|
||||||
|
|
||||||
// Process config.
|
|
||||||
tf2xla::Config config;
|
|
||||||
if (flags.config.empty()) {
|
|
||||||
return errors::InvalidArgument("Must specify --config");
|
|
||||||
}
|
|
||||||
TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
|
|
||||||
TF_RETURN_IF_ERROR(ValidateConfig(config));
|
|
||||||
if (flags.dump_fetch_nodes) {
|
|
||||||
std::set<string> nodes;
|
|
||||||
for (const tf2xla::Fetch& fetch : config.fetch()) {
|
|
||||||
nodes.insert(fetch.id().node_name());
|
|
||||||
}
|
|
||||||
std::cout << absl::StrJoin(nodes, ",");
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read and initialize the graph.
|
|
||||||
if (flags.graph.empty()) {
|
|
||||||
return errors::InvalidArgument("Must specify --graph");
|
|
||||||
}
|
|
||||||
GraphDef graph_def;
|
|
||||||
TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
|
|
||||||
CompileResult compile_result;
|
|
||||||
TF_RETURN_IF_ERROR(CompileGraph(graph_def, config, flags, &compile_result));
|
|
||||||
|
|
||||||
// Write output files.
|
|
||||||
Env* env = Env::Default();
|
|
||||||
const std::vector<char>& obj = compile_result.aot->object_file_data();
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
WriteStringToFile(env, flags.out_function_object,
|
|
||||||
absl::string_view(obj.data(), obj.size())));
|
|
||||||
CodegenOpts codegen_opts;
|
|
||||||
codegen_opts.gen_name_to_index = flags.gen_name_to_index;
|
|
||||||
codegen_opts.gen_program_shape = flags.gen_program_shape;
|
|
||||||
codegen_opts.target_triple = flags.target_triple;
|
|
||||||
if (flags.cpp_class.empty()) {
|
|
||||||
return errors::InvalidArgument("Must specify --cpp_class");
|
|
||||||
}
|
|
||||||
codegen_opts.gen_hlo_profile_printer_data =
|
|
||||||
xla::GetDebugOptionsFromFlags().xla_hlo_profile();
|
|
||||||
TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
|
|
||||||
&codegen_opts.namespaces));
|
|
||||||
|
|
||||||
MetadataResult metadata_result;
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
GenerateMetadata(codegen_opts, compile_result, &metadata_result));
|
|
||||||
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_metadata_object,
|
|
||||||
metadata_result.object_file_data));
|
|
||||||
string header;
|
|
||||||
TF_RETURN_IF_ERROR(GenerateHeader(codegen_opts, config, compile_result,
|
|
||||||
metadata_result, &header));
|
|
||||||
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header));
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // end namespace tfcompile
|
} // end namespace tfcompile
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
|
||||||
|
@ -4,12 +4,7 @@ load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library", "tf_jit_compilati
|
|||||||
load("//tensorflow/core/platform:build_config.bzl", "tf_additional_all_protos", "tf_proto_library")
|
load("//tensorflow/core/platform:build_config.bzl", "tf_additional_all_protos", "tf_proto_library")
|
||||||
|
|
||||||
package(
|
package(
|
||||||
default_visibility = [
|
default_visibility = [":internal"],
|
||||||
":internal",
|
|
||||||
# BEGIN-GOOGLE-INTERNAL
|
|
||||||
"//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__",
|
|
||||||
# END-GOOGLE-INTERNAL
|
|
||||||
],
|
|
||||||
licenses = ["notice"], # Apache 2.0
|
licenses = ["notice"], # Apache 2.0
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -82,19 +77,6 @@ cc_library(
|
|||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "xla_mlir_gpu_jit",
|
|
||||||
visibility = ["//visibility:public"],
|
|
||||||
deps = if_cuda_or_rocm([
|
|
||||||
":jit_compilation_passes",
|
|
||||||
"//tensorflow/compiler/jit/kernels:xla_ops",
|
|
||||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
|
||||||
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
|
||||||
"//tensorflow/compiler/xla/service:mlir_gpu_plugin",
|
|
||||||
]),
|
|
||||||
alwayslink = 1,
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "xla_cpu_device",
|
name = "xla_cpu_device",
|
||||||
srcs = ["xla_cpu_device.cc"],
|
srcs = ["xla_cpu_device.cc"],
|
||||||
@ -120,6 +102,7 @@ cc_library(
|
|||||||
srcs = ["xla_gpu_device.cc"],
|
srcs = ["xla_gpu_device.cc"],
|
||||||
visibility = [":friends"],
|
visibility = [":friends"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":flags",
|
||||||
":jit_compilation_passes",
|
":jit_compilation_passes",
|
||||||
":xla_device",
|
":xla_device",
|
||||||
":xla_kernel_creator", # buildcleaner: keep
|
":xla_kernel_creator", # buildcleaner: keep
|
||||||
@ -128,6 +111,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||||
"//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep
|
"//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep
|
||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
|
"//tensorflow/core:gpu_init",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
|
@ -1584,7 +1584,6 @@ DeadnessAnalysis::~DeadnessAnalysis() {}
|
|||||||
absl::flat_hash_map<TensorId, string, TensorId::Hasher>
|
absl::flat_hash_map<TensorId, string, TensorId::Hasher>
|
||||||
DeadnessAnalysisImpl::PredicateMapAsString() const {
|
DeadnessAnalysisImpl::PredicateMapAsString() const {
|
||||||
absl::flat_hash_map<TensorId, string, TensorId::Hasher> result;
|
absl::flat_hash_map<TensorId, string, TensorId::Hasher> result;
|
||||||
std::vector<TensorId> tensor_ids;
|
|
||||||
for (const auto& kv_pair : predicate_map_) {
|
for (const auto& kv_pair : predicate_map_) {
|
||||||
CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second);
|
CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second);
|
||||||
}
|
}
|
||||||
|
@ -374,39 +374,6 @@ xla::StatusOr<NodeDef> BuildXlaHostComputeNodeDef(
|
|||||||
return new_def;
|
return new_def;
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_ATTRIBUTE_NOINLINE Status
|
|
||||||
ValidateOutsideCompilationCallNode(Node* call_node) {
|
|
||||||
// DT_INT64 as input/output for outside compilation is not supported yet:
|
|
||||||
// b/120809951.
|
|
||||||
for (const Edge* e : call_node->in_edges()) {
|
|
||||||
if (e->IsControlEdge()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
DataType dtype = e->src()->output_type(e->src_output());
|
|
||||||
if (dtype == DT_INT64) {
|
|
||||||
return errors::Unimplemented(
|
|
||||||
"int64 input for outside compilation is not supported yet: "
|
|
||||||
"b/120809951. Please cast output of node ",
|
|
||||||
e->src()->DebugString(),
|
|
||||||
" to int32 before feeding it into outside compilation.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (const Edge* e : call_node->out_edges()) {
|
|
||||||
if (e->IsControlEdge()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
DataType dtype = e->dst()->input_type(e->dst_input());
|
|
||||||
if (dtype == DT_INT64) {
|
|
||||||
return errors::Unimplemented(
|
|
||||||
"int64 output for outside compilation is not supported yet: "
|
|
||||||
"b/120809951. Please cast input of node ",
|
|
||||||
e->dst()->DebugString(),
|
|
||||||
" to int32 before returning it from outside compilation.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Replace outside compilation function call node with XlaHostCompute node.
|
// Replace outside compilation function call node with XlaHostCompute node.
|
||||||
TF_ATTRIBUTE_NOINLINE xla::StatusOr<Node*> ReplaceOutsideCompilationCallNode(
|
TF_ATTRIBUTE_NOINLINE xla::StatusOr<Node*> ReplaceOutsideCompilationCallNode(
|
||||||
Graph* g, Node* call_node, const std::map<string, int>& host_compute_core,
|
Graph* g, Node* call_node, const std::map<string, int>& host_compute_core,
|
||||||
@ -2384,7 +2351,6 @@ Status ExtractOutsideCompilationForFunction(
|
|||||||
}
|
}
|
||||||
std::map<string, Node*> host_compute_nodes;
|
std::map<string, Node*> host_compute_nodes;
|
||||||
for (Node* n : outside_compilation_nodes) {
|
for (Node* n : outside_compilation_nodes) {
|
||||||
TF_RETURN_IF_ERROR(ValidateOutsideCompilationCallNode(n));
|
|
||||||
auto host_compute_node_or = ReplaceOutsideCompilationCallNode(
|
auto host_compute_node_or = ReplaceOutsideCompilationCallNode(
|
||||||
graph_out.get(), n, host_compute_core, *cluster_deps);
|
graph_out.get(), n, host_compute_core, *cluster_deps);
|
||||||
TF_RETURN_IF_ERROR(host_compute_node_or.status());
|
TF_RETURN_IF_ERROR(host_compute_node_or.status());
|
||||||
|
@ -155,6 +155,7 @@ void AllocateAndParseFlags() {
|
|||||||
|
|
||||||
device_flags = new XlaDeviceFlags;
|
device_flags = new XlaDeviceFlags;
|
||||||
device_flags->tf_xla_compile_on_demand = false;
|
device_flags->tf_xla_compile_on_demand = false;
|
||||||
|
device_flags->tf_xla_enable_xla_devices = true;
|
||||||
|
|
||||||
ops_flags = new XlaOpsCommonFlags;
|
ops_flags = new XlaOpsCommonFlags;
|
||||||
ops_flags->tf_xla_always_defer_compilation = false;
|
ops_flags->tf_xla_always_defer_compilation = false;
|
||||||
@ -187,6 +188,12 @@ void AllocateAndParseFlags() {
|
|||||||
"Switch a device into 'on-demand' mode, where instead of "
|
"Switch a device into 'on-demand' mode, where instead of "
|
||||||
"autoclustering ops are compiled one by one just-in-time."),
|
"autoclustering ops are compiled one by one just-in-time."),
|
||||||
|
|
||||||
|
Flag("tf_xla_enable_xla_devices",
|
||||||
|
&device_flags->tf_xla_enable_xla_devices,
|
||||||
|
"Generate XLA_* devices, where placing a computation on such a "
|
||||||
|
"device"
|
||||||
|
"forces compilation by XLA. Deprecated."),
|
||||||
|
|
||||||
Flag("tf_xla_always_defer_compilation",
|
Flag("tf_xla_always_defer_compilation",
|
||||||
&ops_flags->tf_xla_always_defer_compilation, ""),
|
&ops_flags->tf_xla_always_defer_compilation, ""),
|
||||||
|
|
||||||
|
@ -87,6 +87,9 @@ struct XlaDeviceFlags {
|
|||||||
// Enabling this mode by a legacy flag is a temporary mechanism. When this
|
// Enabling this mode by a legacy flag is a temporary mechanism. When this
|
||||||
// feature is battle-tested, we will switch this to be a session option.
|
// feature is battle-tested, we will switch this to be a session option.
|
||||||
bool tf_xla_compile_on_demand;
|
bool tf_xla_compile_on_demand;
|
||||||
|
|
||||||
|
// Enables "XLA" devices if this flag is set.
|
||||||
|
bool tf_xla_enable_xla_devices;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Flags common to the _Xla* ops and their kernels.
|
// Flags common to the _Xla* ops and their kernels.
|
||||||
|
@ -1776,9 +1776,9 @@ absl::flat_hash_map<string, std::vector<string>>* GetWhitelistTable() {
|
|||||||
"Lgamma", "Digamma",
|
"Lgamma", "Digamma",
|
||||||
// Binary
|
// Binary
|
||||||
"Add", "AddV2", "Sub", "Mul", "Div", "Atan2", "Complex", "DivNoNan",
|
"Add", "AddV2", "Sub", "Mul", "Div", "Atan2", "Complex", "DivNoNan",
|
||||||
"MulNoNan", "FloorDiv", "Xlogy", "Xdivy", "FloorMod", "BitwiseAnd",
|
"MulNoNan", "FloorDiv", "Xlogy", "Xlog1py", "Xdivy", "FloorMod",
|
||||||
"BitwiseOr", "BitwiseXor", "LeftShift", "RightShift", "LogicalAnd",
|
"BitwiseAnd", "BitwiseOr", "BitwiseXor", "LeftShift", "RightShift",
|
||||||
"LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
|
"LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
|
||||||
"ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "TruncateDiv",
|
"ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "TruncateDiv",
|
||||||
"TruncateMod", "Equal", "NotEqual", "Greater", "GreaterEqual",
|
"TruncateMod", "Equal", "NotEqual", "Greater", "GreaterEqual",
|
||||||
"Less", "LessEqual", "SigmoidGrad", "SoftplusGrad", "SoftsignGrad",
|
"Less", "LessEqual", "SigmoidGrad", "SoftplusGrad", "SoftsignGrad",
|
||||||
@ -1872,6 +1872,8 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
|||||||
"Einsum",
|
"Einsum",
|
||||||
"EmptyTensorList",
|
"EmptyTensorList",
|
||||||
"ExtractImagePatches",
|
"ExtractImagePatches",
|
||||||
|
"Igamma",
|
||||||
|
"Igammac",
|
||||||
"FFT",
|
"FFT",
|
||||||
"FFT2D",
|
"FFT2D",
|
||||||
"FFT3D",
|
"FFT3D",
|
||||||
|
@ -36,8 +36,13 @@ class XlaCpuDeviceFactory : public DeviceFactory {
|
|||||||
};
|
};
|
||||||
|
|
||||||
Status XlaCpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
|
Status XlaCpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
|
||||||
devices->push_back(absl::StrCat("/physical_device:", DEVICE_XLA_CPU, ":0"));
|
XlaDeviceFlags* flags = GetXlaDeviceFlags();
|
||||||
|
if (!flags->tf_xla_enable_xla_devices) {
|
||||||
|
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
devices->push_back(absl::StrCat("/physical_device:", DEVICE_XLA_CPU, ":0"));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -45,6 +50,10 @@ Status XlaCpuDeviceFactory::CreateDevices(
|
|||||||
const SessionOptions& session_options, const string& name_prefix,
|
const SessionOptions& session_options, const string& name_prefix,
|
||||||
std::vector<std::unique_ptr<Device>>* devices) {
|
std::vector<std::unique_ptr<Device>>* devices) {
|
||||||
XlaDeviceFlags* flags = GetXlaDeviceFlags();
|
XlaDeviceFlags* flags = GetXlaDeviceFlags();
|
||||||
|
if (!flags->tf_xla_enable_xla_devices) {
|
||||||
|
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
bool compile_on_demand = flags->tf_xla_compile_on_demand;
|
bool compile_on_demand = flags->tf_xla_compile_on_demand;
|
||||||
|
|
||||||
XlaOpRegistry::DeviceRegistration registration;
|
XlaOpRegistry::DeviceRegistration registration;
|
||||||
|
@ -140,7 +140,6 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
|
|||||||
// The device tensor should always be fresh.
|
// The device tensor should always be fresh.
|
||||||
TF_RET_CHECK(!xla_tensor->has_shaped_buffer());
|
TF_RET_CHECK(!xla_tensor->has_shaped_buffer());
|
||||||
|
|
||||||
xla_tensor->set_host_tensor(*cpu_tensor);
|
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
|
xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
|
||||||
stream_->parent()->device_ordinal()));
|
stream_->parent()->device_ordinal()));
|
||||||
|
@ -14,17 +14,20 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
// Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
|
// Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
|
||||||
// operators using XLA via the XLA "CUDA" (GPU) backend.
|
// operators using XLA via the XLA "CUDA" or "ROCM" (GPU) backend.
|
||||||
|
|
||||||
#include <set>
|
#include <set>
|
||||||
|
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/strings/numbers.h"
|
#include "absl/strings/numbers.h"
|
||||||
#include "absl/strings/str_split.h"
|
#include "absl/strings/str_split.h"
|
||||||
|
#include "tensorflow/compiler/jit/flags.h"
|
||||||
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
|
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
|
||||||
#include "tensorflow/compiler/jit/xla_device.h"
|
#include "tensorflow/compiler/jit/xla_device.h"
|
||||||
#include "tensorflow/compiler/jit/xla_device_ops.h"
|
#include "tensorflow/compiler/jit/xla_device_ops.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||||
|
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -61,7 +64,14 @@ class XlaGpuDeviceFactory : public DeviceFactory {
|
|||||||
};
|
};
|
||||||
|
|
||||||
Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
|
Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
|
||||||
auto platform = se::MultiPlatformManager::PlatformWithName("CUDA");
|
XlaDeviceFlags* flags = GetXlaDeviceFlags();
|
||||||
|
if (!flags->tf_xla_enable_xla_devices) {
|
||||||
|
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto platform =
|
||||||
|
se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
|
||||||
if (!platform.ok()) {
|
if (!platform.ok()) {
|
||||||
// Treat failures as non-fatal; there might not be a GPU in the machine.
|
// Treat failures as non-fatal; there might not be a GPU in the machine.
|
||||||
VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
|
VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
|
||||||
@ -84,6 +94,12 @@ Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
|
|||||||
Status XlaGpuDeviceFactory::CreateDevices(
|
Status XlaGpuDeviceFactory::CreateDevices(
|
||||||
const SessionOptions& session_options, const string& name_prefix,
|
const SessionOptions& session_options, const string& name_prefix,
|
||||||
std::vector<std::unique_ptr<Device>>* devices) {
|
std::vector<std::unique_ptr<Device>>* devices) {
|
||||||
|
XlaDeviceFlags* flags = GetXlaDeviceFlags();
|
||||||
|
if (!flags->tf_xla_enable_xla_devices) {
|
||||||
|
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
XlaOpRegistry::DeviceRegistration registration;
|
XlaOpRegistry::DeviceRegistration registration;
|
||||||
registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
|
registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
|
||||||
registration.autoclustering_policy =
|
registration.autoclustering_policy =
|
||||||
@ -103,7 +119,8 @@ Status XlaGpuDeviceFactory::CreateDevices(
|
|||||||
RegisterXlaDeviceKernels(DEVICE_XLA_GPU, DEVICE_GPU_XLA_JIT);
|
RegisterXlaDeviceKernels(DEVICE_XLA_GPU, DEVICE_GPU_XLA_JIT);
|
||||||
(void)registrations;
|
(void)registrations;
|
||||||
|
|
||||||
auto platform = se::MultiPlatformManager::PlatformWithName("CUDA");
|
auto platform =
|
||||||
|
se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
|
||||||
if (!platform.ok()) {
|
if (!platform.ok()) {
|
||||||
// Treat failures as non-fatal; there might not be a GPU in the machine.
|
// Treat failures as non-fatal; there might not be a GPU in the machine.
|
||||||
VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
|
VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
|
||||||
|
@ -222,8 +222,9 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
|
|||||||
OpKernelConstruction construction(
|
OpKernelConstruction construction(
|
||||||
DeviceType(dev->device_type()), dev,
|
DeviceType(dev->device_type()), dev,
|
||||||
dev->GetAllocator(AllocatorAttributes()), &node_def,
|
dev->GetAllocator(AllocatorAttributes()), &node_def,
|
||||||
&fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types,
|
&fbody->fdef.signature(), flr, dev->resource_manager(), fbody->arg_types,
|
||||||
fbody->ret_types, output_memory_types, flr->graph_def_version(), &s);
|
input_memory_types, fbody->ret_types, output_memory_types,
|
||||||
|
flr->graph_def_version(), &s);
|
||||||
|
|
||||||
*kernel = absl::make_unique<XlaLocalLaunchBase>(
|
*kernel = absl::make_unique<XlaLocalLaunchBase>(
|
||||||
&construction, constant_arg_indices, resource_arg_indices, function);
|
&construction, constant_arg_indices, resource_arg_indices, function);
|
||||||
|
@ -44,8 +44,11 @@ cc_library(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core/platform:logging",
|
"//tensorflow/core/platform:logging",
|
||||||
"@llvm-project//llvm:support",
|
"@llvm-project//llvm:support",
|
||||||
|
"@llvm-project//mlir:AffineDialectRegistration",
|
||||||
|
"@llvm-project//mlir:LoopDialectRegistration",
|
||||||
"@llvm-project//mlir:MlirOptLib",
|
"@llvm-project//mlir:MlirOptLib",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
|
"@llvm-project//mlir:QuantOpsDialectRegistration",
|
||||||
"@llvm-project//mlir:Support",
|
"@llvm-project//mlir:Support",
|
||||||
"@llvm-project//mlir/test:TestTransforms",
|
"@llvm-project//mlir/test:TestTransforms",
|
||||||
],
|
],
|
||||||
@ -80,9 +83,10 @@ cc_library(
|
|||||||
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
|
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
|
||||||
"//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
|
"//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
|
||||||
"//tensorflow/compiler/mlir/xla:xla_lower",
|
"//tensorflow/compiler/mlir/xla:xla_lower",
|
||||||
"@llvm-project//mlir:AffineDialectRegistration",
|
"//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts",
|
||||||
|
"//tensorflow/compiler/mlir/xla:xla_test_passes",
|
||||||
|
"@llvm-project//mlir:AffineOps",
|
||||||
"@llvm-project//mlir:QuantOps",
|
"@llvm-project//mlir:QuantOps",
|
||||||
"@llvm-project//mlir:QuantOpsDialectRegistration",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -47,6 +47,14 @@ gentbl(
|
|||||||
"-gen-op-doc",
|
"-gen-op-doc",
|
||||||
"g3doc/tfl_ops.md",
|
"g3doc/tfl_ops.md",
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
"-gen-op-interface-decls",
|
||||||
|
"ir/tfl_ops_interface.h.inc",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"-gen-op-interface-defs",
|
||||||
|
"ir/tfl_ops_interface.cc.inc",
|
||||||
|
),
|
||||||
],
|
],
|
||||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||||
td_file = "ir/tfl_ops.td",
|
td_file = "ir/tfl_ops.td",
|
||||||
@ -177,11 +185,12 @@ cc_library(
|
|||||||
"ir/tfl_ops.cc",
|
"ir/tfl_ops.cc",
|
||||||
"ir/tfl_ops.cc.inc",
|
"ir/tfl_ops.cc.inc",
|
||||||
"ir/tfl_ops.h.inc",
|
"ir/tfl_ops.h.inc",
|
||||||
|
"ir/tfl_ops_interface.cc.inc",
|
||||||
|
"ir/tfl_ops_interface.h.inc",
|
||||||
"utils/attribute_utils.cc",
|
"utils/attribute_utils.cc",
|
||||||
],
|
],
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"ir/tfl_ops.h",
|
"ir/tfl_ops.h",
|
||||||
"ir/tfl_traits.h",
|
|
||||||
"transforms/passes.h",
|
"transforms/passes.h",
|
||||||
"utils/attribute_utils.h",
|
"utils/attribute_utils.h",
|
||||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h",
|
"//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h",
|
||||||
@ -330,6 +339,7 @@ cc_library(
|
|||||||
cc_library(
|
cc_library(
|
||||||
name = "tensorflow_lite_quantize",
|
name = "tensorflow_lite_quantize",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
"transforms/default_quant_params.cc",
|
||||||
"transforms/generated_post_quantize.inc",
|
"transforms/generated_post_quantize.inc",
|
||||||
"transforms/generated_quantize.inc",
|
"transforms/generated_quantize.inc",
|
||||||
"transforms/load_quantization_recipe.cc",
|
"transforms/load_quantization_recipe.cc",
|
||||||
@ -506,6 +516,7 @@ cc_library(
|
|||||||
"//tensorflow/lite:schema_fbs_version",
|
"//tensorflow/lite:schema_fbs_version",
|
||||||
"//tensorflow/lite:string_util",
|
"//tensorflow/lite:string_util",
|
||||||
"//tensorflow/lite/delegates/flex:whitelisted_flex_ops_lib",
|
"//tensorflow/lite/delegates/flex:whitelisted_flex_ops_lib",
|
||||||
|
"//tensorflow/lite/kernels/internal:kernel_utils",
|
||||||
"//tensorflow/lite/schema:schema_fbs",
|
"//tensorflow/lite/schema:schema_fbs",
|
||||||
"//tensorflow/lite/tools/versioning:op_version",
|
"//tensorflow/lite/tools/versioning:op_version",
|
||||||
"@com_google_absl//absl/base",
|
"@com_google_absl//absl/base",
|
||||||
@ -671,12 +682,16 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
exports_files(
|
cc_library(
|
||||||
["transforms/passes.h"],
|
name = "empty_passes",
|
||||||
|
hdrs = ["transforms/passes.h"],
|
||||||
visibility = [
|
visibility = [
|
||||||
"//configs/devtools/hawkeye/tflite:__subpackages__",
|
"//configs/devtools/hawkeye/tflite:__subpackages__",
|
||||||
"//learning/brain/models/app_benchmarks:__subpackages__",
|
"//learning/brain/models/app_benchmarks:__subpackages__",
|
||||||
"//tensorflow/compiler/mlir/lite:friends",
|
"//tensorflow/compiler/mlir/lite:friends",
|
||||||
"//tensorflow/lite/experimental/mlir:__subpackages__",
|
"//tensorflow/lite/experimental/mlir:__subpackages__",
|
||||||
],
|
],
|
||||||
|
deps = [
|
||||||
|
"@llvm-project//llvm:support",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
@ -31,10 +31,11 @@ struct PassConfig {
|
|||||||
: emit_builtin_tflite_ops(true),
|
: emit_builtin_tflite_ops(true),
|
||||||
lower_tensor_list_ops(false),
|
lower_tensor_list_ops(false),
|
||||||
trim_functions_whitelist({}),
|
trim_functions_whitelist({}),
|
||||||
quant_specs(specs),
|
quant_specs(std::move(specs)),
|
||||||
skip_control_dialect(false),
|
skip_control_dialect(false),
|
||||||
form_clusters(false),
|
form_clusters(false),
|
||||||
inline_functions(false) {}
|
inline_functions(false),
|
||||||
|
unfold_batch_matmul(true) {}
|
||||||
|
|
||||||
// If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
|
// If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
|
||||||
// added, which produces TF Lite ops.
|
// added, which produces TF Lite ops.
|
||||||
@ -57,6 +58,9 @@ struct PassConfig {
|
|||||||
// Inline function calls within the main function in the MLIR module, prior
|
// Inline function calls within the main function in the MLIR module, prior
|
||||||
// to legalization to TFLite.
|
// to legalization to TFLite.
|
||||||
bool inline_functions;
|
bool inline_functions;
|
||||||
|
// if `unfold_batch_matmul` is true, the tf.BatchMatMul is unfolded to a set
|
||||||
|
// of tfl.fully_connected ops.
|
||||||
|
bool unfold_batch_matmul;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace TFL
|
} // namespace TFL
|
||||||
|
@ -389,7 +389,6 @@ StatusOr<mlir::ElementsAttr> ConvertIntBuffer(
|
|||||||
mlir::RankedTensorType shaped_type, mlir::Type elem_type,
|
mlir::RankedTensorType shaped_type, mlir::Type elem_type,
|
||||||
const std::vector<uint8_t>& buffer) {
|
const std::vector<uint8_t>& buffer) {
|
||||||
unsigned bit_width;
|
unsigned bit_width;
|
||||||
mlir::RankedTensorType buffer_type;
|
|
||||||
if (auto itype = elem_type.dyn_cast<mlir::IntegerType>()) {
|
if (auto itype = elem_type.dyn_cast<mlir::IntegerType>()) {
|
||||||
bit_width = itype.getWidth();
|
bit_width = itype.getWidth();
|
||||||
} else if (auto qtype = elem_type.dyn_cast<QuantizedType>()) {
|
} else if (auto qtype = elem_type.dyn_cast<QuantizedType>()) {
|
||||||
@ -920,15 +919,13 @@ StatusOr<FuncOp> ConvertSubgraph(
|
|||||||
// represents TFLite, this entry point must be called "main"
|
// represents TFLite, this entry point must be called "main"
|
||||||
// TODO(b/131175224,b/132239787) Support multiple entry points
|
// TODO(b/131175224,b/132239787) Support multiple entry points
|
||||||
std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) {
|
std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) {
|
||||||
if (subgraph.name.empty()) {
|
if (index == 0) {
|
||||||
if (index == 0) {
|
return "main";
|
||||||
return "main";
|
|
||||||
} else {
|
|
||||||
return llvm::formatv("fn_{0}", index).str();
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return subgraph.name;
|
|
||||||
}
|
}
|
||||||
|
if (subgraph.name.empty()) {
|
||||||
|
return llvm::formatv("fn_{0}", index).str();
|
||||||
|
}
|
||||||
|
return subgraph.name;
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
@ -259,9 +259,9 @@ Status mlir::CustomOptionsToAttributes(
|
|||||||
attributes->emplace_back(builder.getNamedAttr(
|
attributes->emplace_back(builder.getNamedAttr(
|
||||||
"stride_w", builder.getI32IntegerAttr(pool_params->stride_width)));
|
"stride_w", builder.getI32IntegerAttr(pool_params->stride_width)));
|
||||||
attributes->emplace_back(builder.getNamedAttr(
|
attributes->emplace_back(builder.getNamedAttr(
|
||||||
"filter_w", builder.getI32IntegerAttr(pool_params->filter_height)));
|
"filter_h", builder.getI32IntegerAttr(pool_params->filter_height)));
|
||||||
attributes->emplace_back(builder.getNamedAttr(
|
attributes->emplace_back(builder.getNamedAttr(
|
||||||
"filter_h", builder.getI32IntegerAttr(pool_params->filter_width)));
|
"filter_w", builder.getI32IntegerAttr(pool_params->filter_width)));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
|
||||||
} else if (op_name == "tfl.convolution_2d_transpose_bias") {
|
} else if (op_name == "tfl.convolution_2d_transpose_bias") {
|
||||||
|
@ -71,6 +71,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h"
|
#include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
|
||||||
#include "tensorflow/lite/schema/schema_generated.h"
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
#include "tensorflow/lite/string_util.h"
|
#include "tensorflow/lite/string_util.h"
|
||||||
#include "tensorflow/lite/tools/versioning/op_version.h"
|
#include "tensorflow/lite/tools/versioning/op_version.h"
|
||||||
@ -218,6 +219,13 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
|
|||||||
auto qtype = type.cast<mlir::quant::UniformQuantizedPerAxisType>();
|
auto qtype = type.cast<mlir::quant::UniformQuantizedPerAxisType>();
|
||||||
return GetTFLiteType(qtype.getStorageType(), qtype.isSigned());
|
return GetTFLiteType(qtype.getStorageType(), qtype.isSigned());
|
||||||
}
|
}
|
||||||
|
case mlir::TF::TensorFlowTypes::RESOURCE: {
|
||||||
|
// Treat tf.resource values as integer values in flatbuffer.
|
||||||
|
// TODO(b/146131919): Maybe need to have a detailed design for supporting
|
||||||
|
// other resource types beyonds hash table resources and resource
|
||||||
|
// variables.
|
||||||
|
return tflite::TensorType_INT32;
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
// TFLite export fills FLOAT32 for unknown data types. Returning an error
|
// TFLite export fills FLOAT32 for unknown data types. Returning an error
|
||||||
// for now for safety and this could be revisited when required.
|
// for now for safety and this could be revisited when required.
|
||||||
@ -317,6 +325,48 @@ static std::unique_ptr<::tensorflow::NodeDef> getTensorFlowNodeDef(
|
|||||||
return std::move(status_or_node_def.ValueOrDie());
|
return std::move(status_or_node_def.ValueOrDie());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Converts a mlir padding StringRef to TfLitePadding.
|
||||||
|
// Returns llvm::None if conversion fails.
|
||||||
|
static Optional<TfLitePadding> GetTflitePadding(Operation* inst,
|
||||||
|
llvm::StringRef padding) {
|
||||||
|
const tflite::Padding padding_attr =
|
||||||
|
std::move(llvm::StringSwitch<tflite::Padding>(padding)
|
||||||
|
.Case("SAME", tflite::Padding_SAME)
|
||||||
|
.Case("VALID", tflite::Padding_VALID));
|
||||||
|
if (padding_attr == tflite::Padding_SAME) {
|
||||||
|
return kTfLitePaddingSame;
|
||||||
|
}
|
||||||
|
if (padding_attr == tflite::Padding_VALID) {
|
||||||
|
return kTfLitePaddingValid;
|
||||||
|
}
|
||||||
|
|
||||||
|
return inst->emitOpError() << "Invalid padding attribute: " << padding,
|
||||||
|
llvm::None;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extracts TfLitePoolParams from a TFL custom op.
|
||||||
|
// Template parameter, TFLOp, should be a TFL custom op containing attributes
|
||||||
|
// generated from TfLitePoolParams.
|
||||||
|
// Returns llvm::None if conversion fails.
|
||||||
|
template <typename TFLOp>
|
||||||
|
static Optional<TfLitePoolParams> GetTflitePoolParams(Operation* inst,
|
||||||
|
TFLOp op) {
|
||||||
|
TfLitePoolParams pool_params;
|
||||||
|
pool_params.stride_height = op.stride_h().getSExtValue();
|
||||||
|
pool_params.stride_width = op.stride_w().getSExtValue();
|
||||||
|
pool_params.filter_height = op.filter_h().getSExtValue();
|
||||||
|
pool_params.filter_width = op.filter_w().getSExtValue();
|
||||||
|
const auto padding = GetTflitePadding(inst, op.padding());
|
||||||
|
if (padding) {
|
||||||
|
pool_params.padding = *padding;
|
||||||
|
pool_params.activation = kTfLiteActNone;
|
||||||
|
pool_params.computed.padding = TfLitePaddingValues{0, 0, 0, 0};
|
||||||
|
return pool_params;
|
||||||
|
}
|
||||||
|
|
||||||
|
return llvm::None;
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// Translates an MLIR module in TFLite dialect to TFLite FlatBuffer.
|
// Translates an MLIR module in TFLite dialect to TFLite FlatBuffer.
|
||||||
@ -375,9 +425,31 @@ class Translator {
|
|||||||
mlir::TF::WhileOp op, const std::vector<int32_t>& operands,
|
mlir::TF::WhileOp op, const std::vector<int32_t>& operands,
|
||||||
const std::vector<int32_t>& results);
|
const std::vector<int32_t>& results);
|
||||||
|
|
||||||
|
// Builds custom operators.
|
||||||
|
// Templated on a) data type of custom_option to be stored into flatbuffer,
|
||||||
|
// and b) TFL custom op type.
|
||||||
|
template <typename CustomOptionType, typename TFLOp>
|
||||||
|
BufferOffset<tflite::Operator> BuildCustomOperator(
|
||||||
|
const CustomOptionType& custom_option, const std::string& opcode_name,
|
||||||
|
TFLOp op, const std::vector<int32_t>& operands,
|
||||||
|
const std::vector<int32_t>& results);
|
||||||
|
|
||||||
BufferOffset<tflite::Operator> BuildNumericVerifyOperator(
|
BufferOffset<tflite::Operator> BuildNumericVerifyOperator(
|
||||||
mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
|
mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
|
||||||
const std::vector<int32_t>& results);
|
const std::vector<int32_t>& results);
|
||||||
|
Optional<BufferOffset<tflite::Operator>>
|
||||||
|
BuildConvolution2DTransposeBiasOperator(
|
||||||
|
Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op,
|
||||||
|
const std::vector<int32_t>& operands,
|
||||||
|
const std::vector<int32_t>& results);
|
||||||
|
Optional<BufferOffset<tflite::Operator>> BuildMaxPoolingWithArgMax2DOperator(
|
||||||
|
Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op,
|
||||||
|
const std::vector<int32_t>& operands,
|
||||||
|
const std::vector<int32_t>& results);
|
||||||
|
Optional<BufferOffset<tflite::Operator>> BuildMaxUnpooling2DOperator(
|
||||||
|
Operation* inst, mlir::TFL::MaxUnpooling2DOp op,
|
||||||
|
const std::vector<int32_t>& operands,
|
||||||
|
const std::vector<int32_t>& results);
|
||||||
|
|
||||||
Optional<CustomOptionsOffset> CreateFlexOpCustomOptions(
|
Optional<CustomOptionsOffset> CreateFlexOpCustomOptions(
|
||||||
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);
|
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);
|
||||||
@ -615,19 +687,72 @@ BufferOffset<tflite::Operator> Translator::BuildWhileOperator(
|
|||||||
builtin_options);
|
builtin_options);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename CustomOptionType, typename TFLOp>
|
||||||
|
BufferOffset<tflite::Operator> Translator::BuildCustomOperator(
|
||||||
|
const CustomOptionType& custom_option, const std::string& opcode_name,
|
||||||
|
TFLOp op, const std::vector<int32_t>& operands,
|
||||||
|
const std::vector<int32_t>& results) {
|
||||||
|
std::vector<uint8_t> custom_option_vector(sizeof(CustomOptionType));
|
||||||
|
memcpy(custom_option_vector.data(), &custom_option, sizeof(CustomOptionType));
|
||||||
|
auto opcode_index =
|
||||||
|
GetOpcodeIndex(opcode_name, tflite::BuiltinOperator_CUSTOM);
|
||||||
|
return tflite::CreateOperator(
|
||||||
|
builder_, opcode_index, builder_.CreateVector(operands),
|
||||||
|
builder_.CreateVector(results), tflite::BuiltinOptions_NONE,
|
||||||
|
/*builtin_options=*/0,
|
||||||
|
builder_.CreateVector<uint8_t>(custom_option_vector),
|
||||||
|
tflite::CustomOptionsFormat_FLEXBUFFERS);
|
||||||
|
}
|
||||||
|
|
||||||
BufferOffset<tflite::Operator> Translator::BuildNumericVerifyOperator(
|
BufferOffset<tflite::Operator> Translator::BuildNumericVerifyOperator(
|
||||||
mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
|
mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
|
||||||
const std::vector<int32_t>& results) {
|
const std::vector<int32_t>& results) {
|
||||||
float tolerance = op.tolerance().convertToFloat();
|
float tolerance = op.tolerance().convertToFloat();
|
||||||
std::vector<uint8_t> custom_options(sizeof(float));
|
return BuildCustomOperator(tolerance, "NumericVerify", op, operands, results);
|
||||||
memcpy(custom_options.data(), &tolerance, sizeof(float));
|
}
|
||||||
auto opcode_index =
|
|
||||||
GetOpcodeIndex("NumericVerify", tflite::BuiltinOperator_CUSTOM);
|
Optional<BufferOffset<tflite::Operator>>
|
||||||
return tflite::CreateOperator(
|
Translator::BuildConvolution2DTransposeBiasOperator(
|
||||||
builder_, opcode_index, builder_.CreateVector(operands),
|
Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op,
|
||||||
builder_.CreateVector(results), tflite::BuiltinOptions_NONE,
|
const std::vector<int32_t>& operands, const std::vector<int32_t>& results) {
|
||||||
/*builtin_options=*/0, builder_.CreateVector<uint8_t>(custom_options),
|
TfLiteTransposeConvParams conv_params;
|
||||||
tflite::CustomOptionsFormat_FLEXBUFFERS);
|
conv_params.stride_height = op.stride_h().getSExtValue();
|
||||||
|
conv_params.stride_width = op.stride_w().getSExtValue();
|
||||||
|
const auto padding = GetTflitePadding(inst, op.padding());
|
||||||
|
if (padding) {
|
||||||
|
conv_params.padding = *padding;
|
||||||
|
return BuildCustomOperator(conv_params, "Convolution2DTransposeBias", op,
|
||||||
|
operands, results);
|
||||||
|
}
|
||||||
|
|
||||||
|
return llvm::None;
|
||||||
|
}
|
||||||
|
|
||||||
|
Optional<BufferOffset<tflite::Operator>>
|
||||||
|
Translator::BuildMaxPoolingWithArgMax2DOperator(
|
||||||
|
Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op,
|
||||||
|
const std::vector<int32_t>& operands, const std::vector<int32_t>& results) {
|
||||||
|
const auto pool_params = GetTflitePoolParams(inst, op);
|
||||||
|
if (pool_params) {
|
||||||
|
return BuildCustomOperator(*pool_params, "MaxPoolingWithArgmax2D", op,
|
||||||
|
operands, results);
|
||||||
|
}
|
||||||
|
|
||||||
|
return llvm::None;
|
||||||
|
}
|
||||||
|
|
||||||
|
Optional<BufferOffset<tflite::Operator>>
|
||||||
|
Translator::BuildMaxUnpooling2DOperator(Operation* inst,
|
||||||
|
mlir::TFL::MaxUnpooling2DOp op,
|
||||||
|
const std::vector<int32_t>& operands,
|
||||||
|
const std::vector<int32_t>& results) {
|
||||||
|
const auto pool_params = GetTflitePoolParams(inst, op);
|
||||||
|
if (pool_params) {
|
||||||
|
return BuildCustomOperator(*pool_params, "MaxUnpooling2D", op, operands,
|
||||||
|
results);
|
||||||
|
}
|
||||||
|
|
||||||
|
return llvm::None;
|
||||||
}
|
}
|
||||||
|
|
||||||
Optional<CustomOptionsOffset> Translator::CreateFlexOpCustomOptions(
|
Optional<CustomOptionsOffset> Translator::CreateFlexOpCustomOptions(
|
||||||
@ -769,6 +894,20 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
|
|||||||
if (auto verify_op = dyn_cast<mlir::TFL::NumericVerifyOp>(inst)) {
|
if (auto verify_op = dyn_cast<mlir::TFL::NumericVerifyOp>(inst)) {
|
||||||
return BuildNumericVerifyOperator(verify_op, operands, results);
|
return BuildNumericVerifyOperator(verify_op, operands, results);
|
||||||
}
|
}
|
||||||
|
if (auto conv_transpose_bias_op =
|
||||||
|
dyn_cast<mlir::TFL::Convolution2DTransposeBiasOp>(inst)) {
|
||||||
|
return BuildConvolution2DTransposeBiasOperator(
|
||||||
|
inst, conv_transpose_bias_op, operands, results);
|
||||||
|
}
|
||||||
|
if (auto max_pooling_with_arg_max_op =
|
||||||
|
dyn_cast<mlir::TFL::MaxPoolingWithArgMax2DOp>(inst)) {
|
||||||
|
return BuildMaxPoolingWithArgMax2DOperator(
|
||||||
|
inst, max_pooling_with_arg_max_op, operands, results);
|
||||||
|
}
|
||||||
|
if (auto max_unpooling_op = dyn_cast<mlir::TFL::MaxUnpooling2DOp>(inst)) {
|
||||||
|
return BuildMaxUnpooling2DOperator(inst, max_unpooling_op, operands,
|
||||||
|
results);
|
||||||
|
}
|
||||||
inst->emitOpError("is not a supported TFLite op");
|
inst->emitOpError("is not a supported TFLite op");
|
||||||
return llvm::None;
|
return llvm::None;
|
||||||
}
|
}
|
||||||
@ -904,11 +1043,6 @@ void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) {
|
|||||||
|
|
||||||
bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) {
|
bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) {
|
||||||
std::vector<int> operand_indices;
|
std::vector<int> operand_indices;
|
||||||
// TODO(b/138254427): When the bug is addressed, we'll be able to inspect
|
|
||||||
// for the presence of a specific OpTrait using mlir::Operation, without
|
|
||||||
// having to cast it to specific ops like below.
|
|
||||||
// Until then, when a new RNN/LSTM op is added to TFLite and has stateful
|
|
||||||
// tensors as operands, they will need to be added here as well.
|
|
||||||
if (!mlir::TFL::IsStatefulOp(op, &operand_indices)) return false;
|
if (!mlir::TFL::IsStatefulOp(op, &operand_indices)) return false;
|
||||||
return absl::c_find(operand_indices, operand_index) != operand_indices.end();
|
return absl::c_find(operand_indices, operand_index) != operand_indices.end();
|
||||||
}
|
}
|
||||||
|
@ -1728,6 +1728,7 @@ static LogicalResult Verify(TransposeOp op) {
|
|||||||
// TableGen'd op method definitions
|
// TableGen'd op method definitions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.cc.inc"
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
|
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
|
||||||
|
|
||||||
|
@ -27,7 +27,6 @@ limitations under the License.
|
|||||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||||
#include "mlir/Support/Functional.h" // TF:llvm-project
|
#include "mlir/Support/Functional.h" // TF:llvm-project
|
||||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
|
|
||||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
|
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
|
||||||
#include "tensorflow/lite/schema/schema_generated.h"
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
|
|
||||||
@ -44,6 +43,7 @@ class TensorFlowLiteDialect : public Dialect {
|
|||||||
Location loc) override;
|
Location loc) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.h.inc"
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc"
|
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc"
|
||||||
|
|
||||||
|
@ -249,14 +249,39 @@ def TFL_ComparisonBinaryBuilder : OpBuilder<
|
|||||||
}]>;
|
}]>;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TFL native op trait for stateful operands and channel indices.
|
// TFL op interface for stateful operands.
|
||||||
|
|
||||||
class StatefulOperands<list<int> operands>
|
def TFL_StatefulOp : OpInterface<"StatefulOpInterface"> {
|
||||||
: ParamNativeOpTrait<"TFL::StatefulOperands", StrJoinInt<operands>.result>;
|
let description = [{
|
||||||
|
Interface for ops that are stateful and need to identify stateful operands.
|
||||||
|
|
||||||
|
Stateful operands correspond to TF's variables semantics. An op that has 1
|
||||||
|
or more stateful operands is a stateful op.
|
||||||
|
}];
|
||||||
|
|
||||||
class ChannelDimIndex<int index>
|
let methods = [
|
||||||
: ParamNativeOpTrait<"TFL::ChannelDimIndex", !cast<string>(index)>;
|
InterfaceMethod<
|
||||||
|
[{Returns the indices of stateful operands.}],
|
||||||
|
"std::vector<int>", "GetStatefulOperands", (ins)
|
||||||
|
>,
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// TFL op interface for output channel index.
|
||||||
|
|
||||||
|
def TFL_ChannelDimIndexInterface : OpInterface<"ChannelDimIndexInterface"> {
|
||||||
|
let description = [{
|
||||||
|
Interface for defining the index of out channel index.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let methods = [
|
||||||
|
InterfaceMethod<
|
||||||
|
[{Returns the dimension index of the output channels.}],
|
||||||
|
"int", "GetChannelDimIndex", (ins)
|
||||||
|
>,
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TFL op base class.
|
// TFL op base class.
|
||||||
@ -285,7 +310,7 @@ class TFL_Op<string mnemonic, list<OpTrait> traits = []> :
|
|||||||
|
|
||||||
class TFL_ConvOp<string mnemonic, string opSummary, int index> :
|
class TFL_ConvOp<string mnemonic, string opSummary, int index> :
|
||||||
TFL_Op<mnemonic, [NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
|
TFL_Op<mnemonic, [NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
|
||||||
ChannelDimIndex<index>, AffineOpCoefficient<index, 1>]> {
|
TFL_ChannelDimIndexInterface, AffineOpCoefficient<index, 1>]> {
|
||||||
let summary = opSummary # " operator";
|
let summary = opSummary # " operator";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
@ -486,8 +511,7 @@ def TFL_ArgMaxOp : TFL_Op<"arg_max", [NoSideEffect]> {
|
|||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (
|
let arguments = (
|
||||||
// TODO: Add support for uint8.
|
ins TensorOf<[F32, I32, I8, TFL_Uint8, QI8, QUI8]>:$input,
|
||||||
ins TensorOf<[F32, I32, I8]>:$input,
|
|
||||||
TFL_I32OrI64Tensor:$dim
|
TFL_I32OrI64Tensor:$dim
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -515,8 +539,7 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> {
|
|||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (
|
let arguments = (
|
||||||
// TODO(pkanwar): Add support for uint8.
|
ins TensorOf<[F32, I32, I8, TFL_Uint8, QI8, QUI8]>:$input,
|
||||||
ins TensorOf<[F32, I32, I8]>:$input,
|
|
||||||
TFL_I32OrI64Tensor:$dim
|
TFL_I32OrI64Tensor:$dim
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -617,7 +640,12 @@ def TFL_ExternalConstOp : Op<TFL_Dialect, "external_const", [NoSideEffect]> {
|
|||||||
let results = (outs AnyTensor:$output);
|
let results = (outs AnyTensor:$output);
|
||||||
}
|
}
|
||||||
|
|
||||||
def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0>;
|
def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0> {
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
// StatefulOpInterface:
|
||||||
|
int GetChannelDimIndex() { return 0; }
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def TFL_CosOp: TFL_Op<"cos", [
|
def TFL_CosOp: TFL_Op<"cos", [
|
||||||
NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
|
NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
|
||||||
@ -637,6 +665,11 @@ def TFL_CosOp: TFL_Op<"cos", [
|
|||||||
def TFL_DepthwiseConv2DOp :
|
def TFL_DepthwiseConv2DOp :
|
||||||
TFL_ConvOp<"depthwise_conv_2d", "Depthwise-separable convolution", 3> {
|
TFL_ConvOp<"depthwise_conv_2d", "Depthwise-separable convolution", 3> {
|
||||||
let arguments = !con(TFL_Conv2DOp.arguments, (ins I32Attr:$depth_multiplier));
|
let arguments = !con(TFL_Conv2DOp.arguments, (ins I32Attr:$depth_multiplier));
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
// StatefulOpInterface:
|
||||||
|
int GetChannelDimIndex() { return 3; }
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def TFL_FCWO_Default : StrEnumAttrCase<"DEFAULT">;
|
def TFL_FCWO_Default : StrEnumAttrCase<"DEFAULT">;
|
||||||
@ -650,7 +683,8 @@ def TFL_FullyConnectedOptionsWeightFormatAttr :
|
|||||||
|
|
||||||
// TODO(jpienaar): Update post discussion on semantics of FC OP.
|
// TODO(jpienaar): Update post discussion on semantics of FC OP.
|
||||||
def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
|
def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
|
||||||
NoSideEffect, AccumulatorUniformScale<2, 0, 1>, ChannelDimIndex<0>,
|
NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
|
||||||
|
TFL_ChannelDimIndexInterface,
|
||||||
AffineOpCoefficient<-1, 1>]> {
|
AffineOpCoefficient<-1, 1>]> {
|
||||||
let summary = "Fully connected op";
|
let summary = "Fully connected op";
|
||||||
|
|
||||||
@ -672,6 +706,11 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
|
|||||||
let verifier = [{ return Verify(*this); }];
|
let verifier = [{ return Verify(*this); }];
|
||||||
|
|
||||||
let hasOptions = 1;
|
let hasOptions = 1;
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
// ChannelDimIndexInterface:
|
||||||
|
int GetChannelDimIndex() { return 0; }
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def TFL_GatherOp : TFL_Op<"gather", [
|
def TFL_GatherOp : TFL_Op<"gather", [
|
||||||
@ -1208,7 +1247,8 @@ def TFL_FloorModOp : TFL_Op<"floor_mod", [Broadcastable, NoSideEffect]> {
|
|||||||
let builders = [TFL_BroadcastableBinaryBuilder];
|
let builders = [TFL_BroadcastableBinaryBuilder];
|
||||||
}
|
}
|
||||||
|
|
||||||
def TFL_GreaterOp : TFL_Op<"greater", [NoSideEffect, NoQuantizableResult]> {
|
def TFL_GreaterOp : TFL_Op<"greater", [
|
||||||
|
Broadcastable, NoSideEffect, NoQuantizableResult]> {
|
||||||
let summary = "Greater operator";
|
let summary = "Greater operator";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
@ -1221,6 +1261,8 @@ def TFL_GreaterOp : TFL_Op<"greater", [NoSideEffect, NoQuantizableResult]> {
|
|||||||
|
|
||||||
let results = (outs AnyTensor:$output);
|
let results = (outs AnyTensor:$output);
|
||||||
|
|
||||||
|
let builders = [TFL_ComparisonBinaryBuilder];
|
||||||
|
|
||||||
let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }];
|
let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }];
|
||||||
|
|
||||||
let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
|
let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
|
||||||
@ -1287,7 +1329,8 @@ def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [NoSideEffect, SameOperandsAndResultTy
|
|||||||
let hasOptions = 0b1;
|
let hasOptions = 0b1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def TFL_LessOp : TFL_Op<"less", [NoSideEffect, NoQuantizableResult]> {
|
def TFL_LessOp : TFL_Op<"less", [
|
||||||
|
Broadcastable, NoSideEffect, NoQuantizableResult]> {
|
||||||
let summary = "Less operator";
|
let summary = "Less operator";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
@ -2123,7 +2166,7 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2",
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
tensor: A Tensor. Must be one of the following types:
|
tensor: A Tensor. Must be one of the following types:
|
||||||
int16, int32, int64, float32 Up to 8-D.
|
uint8, int16, int32, int64, float32, bool Up to 8-D.
|
||||||
|
|
||||||
axis: A Tensor. Must be one of the following types: int32, int64.
|
axis: A Tensor. Must be one of the following types: int32, int64.
|
||||||
with only 1 element which is the axis index.
|
with only 1 element which is the axis index.
|
||||||
@ -2132,12 +2175,12 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2",
|
|||||||
|
|
||||||
let arguments = (
|
let arguments = (
|
||||||
ins
|
ins
|
||||||
TensorOf<[F32, I16, I32, I64]>:$input,
|
TensorOf<[F32, I16, I32, I64, TFL_Uint8, I1]>:$input,
|
||||||
TensorOf<[I32, I64]>:$axis
|
TensorOf<[I32, I64]>:$axis
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
TensorOf<[F32, I16, I32, I64, I8]>:$output
|
TensorOf<[F32, I16, I32, I64, TFL_Uint8, I1]>:$output
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2341,9 +2384,9 @@ def TFL_TanhOp: TFL_Op<"tanh", [
|
|||||||
let results = (outs TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$y);
|
let results = (outs TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$y);
|
||||||
}
|
}
|
||||||
|
|
||||||
def TFL_TileOp: TFL_Op<"tile", [NoSideEffect,
|
def TFL_TileOp: TFL_Op<"tile", [NoSideEffect, SameOperandsAndResultsScale,
|
||||||
PredOpTrait<"resultant element type needs to match first operand type",
|
PredOpTrait<"resultant element type needs to match first operand type",
|
||||||
TCresVTEtIsSameAsOp<0,0>>]> {
|
TFL_TCresVTEtIsSameAsOp<0,0>>]> {
|
||||||
let summary = "Tile operator.";
|
let summary = "Tile operator.";
|
||||||
let description = [{
|
let description = [{
|
||||||
Constructs a tensor by tiling a given tensor.
|
Constructs a tensor by tiling a given tensor.
|
||||||
@ -2356,10 +2399,11 @@ def TFL_TileOp: TFL_Op<"tile", [NoSideEffect,
|
|||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
TensorOf<[F32, I1, I32, I64, TFL_Uint8]>:$input,
|
TensorOf<[F32, I1, I32, I64, TFL_Uint8, QUI8]>:$input,
|
||||||
TFL_I32OrI64Tensor:$multiples);
|
TFL_I32OrI64Tensor:$multiples);
|
||||||
|
|
||||||
let results = (outs TensorOf<[F32, I1, I32, I64, TFL_Uint8]>:$output);
|
let results = (outs
|
||||||
|
TensorOf<[F32, I1, I32, I64, TFL_Uint8, QUI8]>:$output);
|
||||||
|
|
||||||
let hasOptions = 0;
|
let hasOptions = 0;
|
||||||
}
|
}
|
||||||
@ -2369,7 +2413,7 @@ def TFL_TileOp: TFL_Op<"tile", [NoSideEffect,
|
|||||||
// TODO(jpienaar): Check that k is less or equal the internal dimension
|
// TODO(jpienaar): Check that k is less or equal the internal dimension
|
||||||
def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>,
|
def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>,
|
||||||
PredOpTrait<"result and input element type match",
|
PredOpTrait<"result and input element type match",
|
||||||
TCresVTEtIsSameAsOp<0,0>>]> {
|
TCresVTEtIsSameAsOp<0,0>>, SameOperandsAndResultsScale]> {
|
||||||
let summary = "TopK operator";
|
let summary = "TopK operator";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
@ -2379,11 +2423,11 @@ def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>,
|
|||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
TensorOf<[F32, I8, I32, I64, TFL_Uint8]>:$input,
|
TensorOf<[F32, I8, I32, I64, TFL_Uint8, QI8, QUI8]>:$input,
|
||||||
I32Tensor:$k);
|
I32Tensor:$k);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
AnyTensor:$values,
|
TensorOf<[F32, I8, I32, I64, TFL_Uint8, QI8, QUI8]>:$values,
|
||||||
I32Tensor:$indices);
|
I32Tensor:$indices);
|
||||||
|
|
||||||
let builders = [OpBuilder<"Builder *builder, OperationState &result, "
|
let builders = [OpBuilder<"Builder *builder, OperationState &result, "
|
||||||
@ -2907,6 +2951,20 @@ def TFL_QuantizeOp: TFL_Op<"quantize", [
|
|||||||
let results = (outs AnyTensor:$output);
|
let results = (outs AnyTensor:$output);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def TFL_DensifyOp: TFL_Op<"densify", [NoSideEffect,
|
||||||
|
SameOperandsAndResultType,
|
||||||
|
NoQuantizableResult]> {
|
||||||
|
let summary = "Densify operator";
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
Converts sparse tensor to dense format.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins AnyTensor:$input);
|
||||||
|
|
||||||
|
let results = (outs AnyTensor:$output);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// LSTM Ops
|
// LSTM Ops
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -2996,7 +3054,7 @@ def TFL_LSTMOp :
|
|||||||
LstmOptionalPeepholeWeightConstraint,
|
LstmOptionalPeepholeWeightConstraint,
|
||||||
LstmProjectionWeightBiasConstraint,
|
LstmProjectionWeightBiasConstraint,
|
||||||
LstmResultConstraint,
|
LstmResultConstraint,
|
||||||
StatefulOperands<[18, 19]>]> {
|
TFL_StatefulOp]> {
|
||||||
let summary = "The full lstm operator";
|
let summary = "The full lstm operator";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
@ -3080,6 +3138,11 @@ Ba et al. “Layer Normalization”
|
|||||||
let hasOptions = 1;
|
let hasOptions = 1;
|
||||||
|
|
||||||
let verifier = [{ return Verify(*this); }];
|
let verifier = [{ return Verify(*this); }];
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
// StatefulOpInterface:
|
||||||
|
std::vector<int> GetStatefulOperands() { return {18, 19}; }
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnidirectionalSequenceLstm op.
|
// UnidirectionalSequenceLstm op.
|
||||||
@ -3091,7 +3154,7 @@ def TFL_UnidirectionalSequenceLSTMOp :
|
|||||||
LstmOptionalPeepholeWeightConstraint,
|
LstmOptionalPeepholeWeightConstraint,
|
||||||
LstmProjectionWeightBiasConstraint,
|
LstmProjectionWeightBiasConstraint,
|
||||||
LstmResultConstraint,
|
LstmResultConstraint,
|
||||||
StatefulOperands<[18, 19]>]> {
|
TFL_StatefulOp]> {
|
||||||
let summary = "Unidirectional sequence lstm operator";
|
let summary = "Unidirectional sequence lstm operator";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
@ -3160,6 +3223,11 @@ def TFL_UnidirectionalSequenceLSTMOp :
|
|||||||
let hasOptions = 1;
|
let hasOptions = 1;
|
||||||
|
|
||||||
let verifier = [{ return Verify(*this); }];
|
let verifier = [{ return Verify(*this); }];
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
// StatefulOpInterface:
|
||||||
|
std::vector<int> GetStatefulOperands() { return {18, 19}; }
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def RnnResultConstraint : PredOpTrait<
|
def RnnResultConstraint : PredOpTrait<
|
||||||
@ -3169,7 +3237,7 @@ def RnnResultConstraint : PredOpTrait<
|
|||||||
// UnidirectionalSequenceRNN op.
|
// UnidirectionalSequenceRNN op.
|
||||||
def TFL_UnidirectionalSequenceRNNOp :
|
def TFL_UnidirectionalSequenceRNNOp :
|
||||||
TFL_Op<"unidirectional_sequence_rnn",
|
TFL_Op<"unidirectional_sequence_rnn",
|
||||||
[RnnResultConstraint, StatefulOperands<[4]>]> {
|
[RnnResultConstraint, TFL_StatefulOp]> {
|
||||||
|
|
||||||
let summary = "Unidirectional sequence rnn operator";
|
let summary = "Unidirectional sequence rnn operator";
|
||||||
|
|
||||||
@ -3213,6 +3281,11 @@ def TFL_UnidirectionalSequenceRNNOp :
|
|||||||
let customOption = "SequenceRNNOptions";
|
let customOption = "SequenceRNNOptions";
|
||||||
|
|
||||||
let verifier = [{ return Verify(*this); }];
|
let verifier = [{ return Verify(*this); }];
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
// StatefulOpInterface:
|
||||||
|
std::vector<int> GetStatefulOperands() { return {4}; }
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def TFL_WhereOp : TFL_Op<"where", [NoSideEffect]> {
|
def TFL_WhereOp : TFL_Op<"where", [NoSideEffect]> {
|
||||||
@ -3264,7 +3337,7 @@ def SVDFResultConstraint: PredOpTrait<
|
|||||||
// SVDF op.
|
// SVDF op.
|
||||||
def TFL_SVDFOp :
|
def TFL_SVDFOp :
|
||||||
TFL_Op<"svdf",
|
TFL_Op<"svdf",
|
||||||
[SVDFResultConstraint, StatefulOperands<[4]>]> {
|
[SVDFResultConstraint, TFL_StatefulOp]> {
|
||||||
|
|
||||||
let summary = "Single value decomposition filter operator";
|
let summary = "Single value decomposition filter operator";
|
||||||
|
|
||||||
@ -3300,6 +3373,25 @@ def TFL_SVDFOp :
|
|||||||
let hasOptions = 1;
|
let hasOptions = 1;
|
||||||
|
|
||||||
let verifier = [{ return Verify(*this); }];
|
let verifier = [{ return Verify(*this); }];
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
// StatefulOpInterface:
|
||||||
|
std::vector<int> GetStatefulOperands() { return {4}; }
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def TFL_SegmentSumOp: TFL_Op<"segment_sum", [NoSideEffect]> {
|
||||||
|
let summary = "SegmentSum operator";
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
Computes the sum along segments of a tensor.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
TensorOf<[F32, I32]>:$data,
|
||||||
|
I32Tensor:$segment_ids
|
||||||
|
);
|
||||||
|
let results = (outs TensorOf<[F32, I32]>:$output);
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // TFL_OPS
|
#endif // TFL_OPS
|
||||||
|
@ -1,67 +0,0 @@
|
|||||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
// This file defines the op traits used in the MLIR TensorFlow Lite dialect.
|
|
||||||
|
|
||||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
|
|
||||||
#define TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
|
|
||||||
|
|
||||||
#include "mlir/IR/OpDefinition.h"
|
|
||||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
|
||||||
|
|
||||||
namespace mlir {
|
|
||||||
namespace OpTrait {
|
|
||||||
namespace TFL {
|
|
||||||
|
|
||||||
// The trait to specify that the specified operands of the TFL op are stateful.
|
|
||||||
// This is used as a trait like this:
|
|
||||||
//
|
|
||||||
// class LSTMOp
|
|
||||||
// : public Op<LSTMOp, OpTrait::TFL::StatefulOperands<18, 19>::Impl> {
|
|
||||||
//
|
|
||||||
template <int... Operands>
|
|
||||||
class StatefulOperands {
|
|
||||||
public:
|
|
||||||
template <typename ConcreteType>
|
|
||||||
class Impl
|
|
||||||
: public TraitBase<ConcreteType, StatefulOperands<Operands...>::Impl> {
|
|
||||||
public:
|
|
||||||
static std::vector<int> GetStatefulOperands() {
|
|
||||||
return std::vector<int>({Operands...});
|
|
||||||
}
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
// The trait to specify the channel dimension index of the input (first operand)
|
|
||||||
// of an affine TFL op (Conv2D, DepthwiseConv2D, FullyConnected).
|
|
||||||
//
|
|
||||||
// class Conv2DOp
|
|
||||||
// : public Op<Conv2DOp, OpTrait::TFL::ChannelDimIndex<0>::Impl> {
|
|
||||||
//
|
|
||||||
template <int Index>
|
|
||||||
class ChannelDimIndex {
|
|
||||||
public:
|
|
||||||
template <typename ConcreteType>
|
|
||||||
class Impl : public TraitBase<ConcreteType, ChannelDimIndex<Index>::Impl> {
|
|
||||||
public:
|
|
||||||
static int GetChannelDimIndex() { return Index; }
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace TFL
|
|
||||||
} // namespace OpTrait
|
|
||||||
} // namespace mlir
|
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
|
|
@ -32,6 +32,6 @@ cc_library(
|
|||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
"@llvm-project//mlir:Support",
|
"@llvm-project//mlir:Support",
|
||||||
"@llvm-project//mlir:ViewOpGraph",
|
"@llvm-project//mlir:Transforms",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -107,9 +107,6 @@ void WarningUnusedFlags(const toco::ModelFlags& model_flags,
|
|||||||
if (toco_flags.output_format()) {
|
if (toco_flags.output_format()) {
|
||||||
LOG(WARNING) << "Ignored output_format.";
|
LOG(WARNING) << "Ignored output_format.";
|
||||||
}
|
}
|
||||||
if (toco_flags.default_ranges_min() || toco_flags.default_ranges_max()) {
|
|
||||||
LOG(WARNING) << "Ignored default_ranges_stats.";
|
|
||||||
}
|
|
||||||
if (toco_flags.drop_control_dependency()) {
|
if (toco_flags.drop_control_dependency()) {
|
||||||
LOG(WARNING) << "Ignored drop_control_dependency.";
|
LOG(WARNING) << "Ignored drop_control_dependency.";
|
||||||
}
|
}
|
||||||
@ -242,6 +239,13 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
|||||||
tensorflow::ParseOutputArrayInfo(output_arrays, &specs.outputs));
|
tensorflow::ParseOutputArrayInfo(output_arrays, &specs.outputs));
|
||||||
|
|
||||||
// Other flags.
|
// Other flags.
|
||||||
|
if (toco_flags.has_default_ranges_min()) {
|
||||||
|
quant_specs.default_ranges.first = toco_flags.default_ranges_min();
|
||||||
|
}
|
||||||
|
if (toco_flags.has_default_ranges_max()) {
|
||||||
|
quant_specs.default_ranges.second = toco_flags.default_ranges_max();
|
||||||
|
}
|
||||||
|
|
||||||
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
|
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
|
||||||
bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
|
bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
|
||||||
bool emit_custom_ops = toco_flags.allow_custom_ops();
|
bool emit_custom_ops = toco_flags.allow_custom_ops();
|
||||||
|
@ -206,10 +206,17 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
|
|||||||
std::unique_ptr<OpPassBase<FuncOp>>
|
std::unique_ptr<OpPassBase<FuncOp>>
|
||||||
CreateImportQuantStatsPassForTFControlDialect(const std::string &stats_str) {
|
CreateImportQuantStatsPassForTFControlDialect(const std::string &stats_str) {
|
||||||
auto get_name_func = [](Operation *op) {
|
auto get_name_func = [](Operation *op) {
|
||||||
if (auto name = op->getAttrOfType<StringAttr>("name"))
|
Location loc = op->getLoc();
|
||||||
return name.getValue();
|
if (auto name = loc.dyn_cast<NameLoc>()) {
|
||||||
else
|
return name.getName().strref();
|
||||||
return llvm::StringRef("");
|
} else if (auto fused_name = loc.dyn_cast<FusedLoc>()) {
|
||||||
|
for (auto sub_loc : fused_name.getLocations()) {
|
||||||
|
if (auto named_sub_loc = sub_loc.dyn_cast<NameLoc>()) {
|
||||||
|
return named_sub_loc.getName().strref();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return llvm::StringRef("");
|
||||||
};
|
};
|
||||||
|
|
||||||
return CreateImportQuantStatsPass(get_name_func, stats_str);
|
return CreateImportQuantStatsPass(get_name_func, stats_str);
|
||||||
|
@ -23,7 +23,6 @@ cc_library(
|
|||||||
],
|
],
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"quantize_model.h",
|
"quantize_model.h",
|
||||||
"//tensorflow/compiler/mlir/lite:transforms/passes.h",
|
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/compiler/mlir/lite:common",
|
"//tensorflow/compiler/mlir/lite:common",
|
||||||
|
@ -73,19 +73,19 @@ TfLiteStatus QuantizeModel(
|
|||||||
|
|
||||||
// Apply quantization passes
|
// Apply quantization passes
|
||||||
PassManager pm(module->getContext());
|
PassManager pm(module->getContext());
|
||||||
TFL::QuantizationSpecs pass_config;
|
TFL::QuantizationSpecs quant_specs;
|
||||||
pass_config.inference_type = tensorflow::DT_QINT8;
|
quant_specs.inference_type = tensorflow::DT_QINT8;
|
||||||
pass_config.post_training_quantization = true;
|
quant_specs.post_training_quantization = true;
|
||||||
|
|
||||||
bool emit_adaptor = false;
|
bool emit_adaptor = false;
|
||||||
auto input_tf_type = tflite::TflTypeToTfType(input_type);
|
auto input_tf_type = tflite::TflTypeToTfType(input_type);
|
||||||
if (input_tf_type == tensorflow::DT_FLOAT) {
|
if (input_tf_type == tensorflow::DT_FLOAT) {
|
||||||
emit_adaptor = true;
|
emit_adaptor = true;
|
||||||
} else if (input_tf_type == tensorflow::DT_UINT8) {
|
} else if (input_tf_type == tensorflow::DT_UINT8) {
|
||||||
pass_config.inference_type = tensorflow::DT_QUINT8;
|
quant_specs.inference_type = tensorflow::DT_QUINT8;
|
||||||
}
|
}
|
||||||
|
|
||||||
pm.addPass(TFL::CreatePrepareQuantizePass(pass_config));
|
pm.addPass(TFL::CreatePrepareQuantizePass(quant_specs));
|
||||||
pm.addPass(TFL::CreateQuantizePass());
|
pm.addPass(TFL::CreateQuantizePass());
|
||||||
pm.addPass(TFL::CreatePostQuantizePass(emit_adaptor));
|
pm.addPass(TFL::CreatePostQuantizePass(emit_adaptor));
|
||||||
|
|
||||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "llvm/ADT/Optional.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "tensorflow/core/framework/types.pb.h"
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
|
|
||||||
@ -64,6 +65,10 @@ struct QuantizationSpecs {
|
|||||||
// quantization aware training or calibration, for the remaining tensors.
|
// quantization aware training or calibration, for the remaining tensors.
|
||||||
std::vector<std::pair<double, double>> input_ranges;
|
std::vector<std::pair<double, double>> input_ranges;
|
||||||
|
|
||||||
|
// The default ranges can be used when a tensor doesn't have quantization
|
||||||
|
// parameters and couldn't be quantized. Used only for latency tests.
|
||||||
|
std::pair<llvm::Optional<double>, llvm::Optional<double>> default_ranges;
|
||||||
|
|
||||||
// A serialized "QuantizationInfo" object to specify value ranges for some of
|
// A serialized "QuantizationInfo" object to specify value ranges for some of
|
||||||
// the tensors with known names.
|
// the tensors with known names.
|
||||||
std::string serialized_quant_stats = "";
|
std::string serialized_quant_stats = "";
|
||||||
|
@ -35,7 +35,6 @@ limitations under the License.
|
|||||||
#include "mlir/IR/Value.h" // TF:llvm-project
|
#include "mlir/IR/Value.h" // TF:llvm-project
|
||||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
|
|
||||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
|
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
|
||||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
@ -3,7 +3,8 @@
|
|||||||
|
|
||||||
// CHECK-LABEL: import_stats_skip
|
// CHECK-LABEL: import_stats_skip
|
||||||
func @import_stats_skip(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
|
func @import_stats_skip(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
|
||||||
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "skip"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
||||||
|
loc(fused["skip1", "skip2.cc":10:8, callsite("op" at "skip3.cc":10:8)])
|
||||||
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
|
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
|
||||||
|
|
||||||
// CHECK-NEXT: "tfl.split"
|
// CHECK-NEXT: "tfl.split"
|
||||||
@ -12,7 +13,8 @@ func @import_stats_skip(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf3
|
|||||||
|
|
||||||
// CHECK-LABEL: import_stats_name
|
// CHECK-LABEL: import_stats_name
|
||||||
func @import_stats_name(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
|
func @import_stats_name(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
|
||||||
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "op"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
||||||
|
loc(fused["skip1.cc":10:8, "op", callsite("skip2" at "skip3.cc":10:8)])
|
||||||
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
|
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
|
||||||
|
|
||||||
// CHECK-NEXT: %[[split:.*]]:2 = "tfl.split"
|
// CHECK-NEXT: %[[split:.*]]:2 = "tfl.split"
|
||||||
@ -23,7 +25,8 @@ func @import_stats_name(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf3
|
|||||||
|
|
||||||
// CHECK-LABEL: import_stats_name_port
|
// CHECK-LABEL: import_stats_name_port
|
||||||
func @import_stats_name_port(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
|
func @import_stats_name_port(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
|
||||||
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "op_0"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
||||||
|
loc(fused["skip1.cc":10:8, "op_0", callsite("skip2" at "skip3.cc":10:8)])
|
||||||
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
|
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
|
||||||
|
|
||||||
// CHECK-NEXT: %[[split:.*]]:2 = "tfl.split"
|
// CHECK-NEXT: %[[split:.*]]:2 = "tfl.split"
|
||||||
@ -34,6 +37,7 @@ func @import_stats_name_port(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor
|
|||||||
// CHECK-LABEL: import_stats_name_regex
|
// CHECK-LABEL: import_stats_name_regex
|
||||||
func @import_stats_name_regex(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
|
func @import_stats_name_regex(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
|
||||||
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "op_regex"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "op_regex"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
||||||
|
loc(fused["skip1.cc":10:8, "op_regex", callsite("skip2" at "skip3.cc":10:8)])
|
||||||
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
|
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
|
||||||
|
|
||||||
// CHECK-NEXT: %[[split:.*]]:2 = "tfl.split"
|
// CHECK-NEXT: %[[split:.*]]:2 = "tfl.split"
|
||||||
|
@ -0,0 +1,89 @@
|
|||||||
|
// RUN: tf-opt %s --tfl-default-quant --tfl-quantize | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: hardcode_all
|
||||||
|
func @hardcode_all(%arg0: tensor<2x2xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x2xf32> {
|
||||||
|
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function="NONE"}: (tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32>
|
||||||
|
return %0 : tensor<2x2xf32>
|
||||||
|
|
||||||
|
// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<2x1x!quant.uniform<u8:f32, 0.0078431372549019607:128>>}
|
||||||
|
// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>}
|
||||||
|
// Quantized tfl.add
|
||||||
|
// CHECK: %[[add:.*]] = "tfl.add"(%[[q1]], %[[q0]]) {fused_activation_function = "NONE"} : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
|
||||||
|
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[add]]) : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>)
|
||||||
|
// CHECK: return %[[dq]]
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: hardcode_input
|
||||||
|
func @hardcode_input(%arg0: tensor<2x2xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x2xf32> {
|
||||||
|
%0 = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>}: (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>
|
||||||
|
%1 = "tfl.dequantize"(%0) : (tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>) -> tensor<2x2xf32>
|
||||||
|
%4 = "tfl.add"(%1, %arg1) {fused_activation_function="NONE"}: (tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32>
|
||||||
|
return %4 : tensor<2x2xf32>
|
||||||
|
|
||||||
|
// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<2x1x!quant.uniform<u8:f32, 0.0078431372549019607:128>>}
|
||||||
|
// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e+00:128>>}
|
||||||
|
// CHECK: %[[add:.*]] = "tfl.add"(%[[q1]], %[[q0]]) {fused_activation_function = "NONE"} : (tensor<2x2x!quant.uniform<u8:f32, 1.000000e+00:128>>
|
||||||
|
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[add]]) : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>)
|
||||||
|
// CHECK: return %[[dq]]
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: hardcode_input_deq
|
||||||
|
func @hardcode_input_deq(%arg0: tensor<2x2x!quant.uniform<u8:f32, 1.0>>, %arg1: tensor<2x1xf32>) -> tensor<2x2xf32> {
|
||||||
|
%1 = "tfl.dequantize"(%arg0) : (tensor<2x2x!quant.uniform<u8:f32, 1.0>>) -> tensor<2x2xf32>
|
||||||
|
%4 = "tfl.add"(%1, %arg1) {fused_activation_function="NONE"}: (tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32>
|
||||||
|
return %4 : tensor<2x2xf32>
|
||||||
|
|
||||||
|
// CHECK: %[[q:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<2x1x!quant.uniform<u8:f32, 0.0078431372549019607:128>>}
|
||||||
|
// CHECK: %[[add:.*]] = "tfl.add"(%arg0, %[[q]]) {fused_activation_function = "NONE"} : (tensor<2x2x!quant.uniform<u8:f32, 1.000000e+00>>
|
||||||
|
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[add]]) : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>)
|
||||||
|
// CHECK: return %[[dq]]
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: hardcode_output
|
||||||
|
func @hardcode_output(%arg0: tensor<2x2xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x2xf32> {
|
||||||
|
%0 = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>}: (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>
|
||||||
|
%1 = "tfl.quantize"(%arg1) {qtype = tensor<2x1x!quant.uniform<u8:f32, 1.0:128>>}: (tensor<2x1xf32>) -> tensor<2x1x!quant.uniform<u8:f32, 1.0:128>>
|
||||||
|
%2 = "tfl.dequantize"(%0) : (tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>) -> tensor<2x2xf32>
|
||||||
|
%3 = "tfl.dequantize"(%1) : (tensor<2x1x!quant.uniform<u8:f32, 1.0:128>>) -> tensor<2x1xf32>
|
||||||
|
%4 = "tfl.add"(%2, %3) {fused_activation_function="NONE"}: (tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32>
|
||||||
|
return %4 : tensor<2x2xf32>
|
||||||
|
|
||||||
|
// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e+00:128>>}
|
||||||
|
// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<2x1x!quant.uniform<u8:f32, 1.000000e+00:128>>}
|
||||||
|
// CHECK: %[[add:.*]] = "tfl.add"(%[[q0]], %[[q1]]) {fused_activation_function = "NONE"} : (tensor<2x2x!quant.uniform<u8:f32, 1.000000e+00:128>>
|
||||||
|
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[add]]) : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>)
|
||||||
|
// CHECK: return %[[dq]]
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_conv_2d_add
|
||||||
|
func @test_conv_2d_add(%arg0: tensor<1x224x224x3x!quant.uniform<u8:f32, 1.0>>, %arg1: tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 1.0>>, %arg2: tensor<32x!quant.uniform<i32:f32, 1.0>>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>> {
|
||||||
|
%0 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 1.0>>) -> tensor<1x224x224x3xf32>
|
||||||
|
%1 = "tfl.dequantize"(%arg1) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 1.0>>) -> tensor<32x3x3x3xf32>
|
||||||
|
%2 = "tfl.dequantize"(%arg2) : (tensor<32x!quant.uniform<i32:f32, 1.0>>) -> tensor<32xf32>
|
||||||
|
%3 = "tfl.conv_2d"(%0, %1, %2) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
|
||||||
|
%4 = "tfl.pseudo_qconst"() {qtype = tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>, value = dense<1> : tensor<1x112x112x32xi8>} : () -> tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>
|
||||||
|
%5 = "tfl.dequantize"(%4) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>) -> tensor<1x112x112x32xf32>
|
||||||
|
%6 = "tfl.add"(%3, %5) {fused_activation_function="NONE"}: (tensor<1x112x112x32xf32>, tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32>
|
||||||
|
%7 = "tfl.quantize"(%6) {qtype = tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>
|
||||||
|
return %7 : tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>
|
||||||
|
|
||||||
|
// CHECK: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %arg1, %arg2)
|
||||||
|
// CHECK-SAME: -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
|
||||||
|
// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"()
|
||||||
|
// CHECK: %[[add:.*]] = "tfl.add"(%[[conv]], %[[cst]])
|
||||||
|
// CHECK-SAME: -> tensor<1x112x112x32x!quant.uniform<u8:f32, 1.000000e+00>>
|
||||||
|
// CHECK: return %[[add]]
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_conv_2d_activation_and_bias
|
||||||
|
func @test_conv_2d_activation_and_bias(%arg0: tensor<1x224x224x3xf32>, %arg1: tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 1.0>>, %arg2: tensor<32xf32>) -> tensor<1x112x112x32xf32> {
|
||||||
|
%0 = "tfl.dequantize"(%arg1) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 1.0>>) -> tensor<32x3x3x3xf32>
|
||||||
|
%1 = "tfl.conv_2d"(%arg0, %0, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
|
||||||
|
return %1 : tensor<1x112x112x32xf32>
|
||||||
|
|
||||||
|
// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg2) {qtype = tensor<32x!quant.uniform<i32:f32, 0.0078431372549019607>>}
|
||||||
|
// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x224x224x3x!quant.uniform<u8:f32, 0.0078431372549019607:128>>}
|
||||||
|
// CHECK: %[[conv:.*]] = "tfl.conv_2d"(%[[q1]], %arg1, %[[q0]])
|
||||||
|
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[conv]]) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 0.0078431372549019607:128>>)
|
||||||
|
// CHECK: return %[[dq]]
|
||||||
|
}
|
@ -0,0 +1,76 @@
|
|||||||
|
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s
|
||||||
|
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --check-prefix=MLIR %s
|
||||||
|
|
||||||
|
|
||||||
|
func @main(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, %arg2: tensor<4xi32>) -> tensor<1x64x84x32xf32> {
|
||||||
|
|
||||||
|
// CHECK: {
|
||||||
|
// CHECK-NEXT: version: 3,
|
||||||
|
// CHECK-NEXT: operator_codes: [ {
|
||||||
|
// CHECK-NEXT: builtin_code: CUSTOM,
|
||||||
|
// CHECK-NEXT: custom_code: "Convolution2DTransposeBias"
|
||||||
|
// CHECK-NEXT: } ],
|
||||||
|
// CHECK-NEXT: subgraphs: [ {
|
||||||
|
// CHECK-NEXT: tensors: [ {
|
||||||
|
// CHECK-NEXT: shape: [ 32, 4, 4, 128 ],
|
||||||
|
// CHECK-NEXT: buffer: 1,
|
||||||
|
// CHECK-NEXT: name: "arg0",
|
||||||
|
// CHECK-NEXT: quantization: {
|
||||||
|
// CHECK-EMPTY:
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }, {
|
||||||
|
// CHECK-NEXT: shape: [ 1, 32, 42, 128 ],
|
||||||
|
// CHECK-NEXT: buffer: 2,
|
||||||
|
// CHECK-NEXT: name: "arg1",
|
||||||
|
// CHECK-NEXT: quantization: {
|
||||||
|
// CHECK-EMPTY:
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }, {
|
||||||
|
// CHECK-NEXT: shape: [ 4 ],
|
||||||
|
// CHECK-NEXT: type: INT32,
|
||||||
|
// CHECK-NEXT: buffer: 3,
|
||||||
|
// CHECK-NEXT: name: "arg2",
|
||||||
|
// CHECK-NEXT: quantization: {
|
||||||
|
// CHECK-EMPTY:
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }, {
|
||||||
|
// CHECK-NEXT: shape: [ 1, 64, 84, 32 ],
|
||||||
|
// CHECK-NEXT: buffer: 4,
|
||||||
|
// CHECK-NEXT: name: "tfl.convolution_2d_transpose_bias",
|
||||||
|
// CHECK-NEXT: quantization: {
|
||||||
|
// CHECK-EMPTY:
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: } ],
|
||||||
|
// CHECK-NEXT: inputs: [ 0, 1, 2 ],
|
||||||
|
// CHECK-NEXT: outputs: [ 3 ],
|
||||||
|
// CHECK-NEXT: operators: [ {
|
||||||
|
// CHECK-NEXT: inputs: [ 0, 1, 2 ],
|
||||||
|
// CHECK-NEXT: outputs: [ 3 ],
|
||||||
|
// CHECK-NEXT: custom_options: [ 1, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0 ]
|
||||||
|
// CHECK-NEXT: } ],
|
||||||
|
// CHECK-NEXT: name: "main"
|
||||||
|
// CHECK-NEXT: } ],
|
||||||
|
// CHECK-NEXT: description: "MLIR Converted.",
|
||||||
|
// CHECK-NEXT: buffers: [ {
|
||||||
|
// CHECK-EMPTY:
|
||||||
|
// CHECK-NEXT: }, {
|
||||||
|
// CHECK-EMPTY:
|
||||||
|
// CHECK-NEXT: }, {
|
||||||
|
// CHECK-EMPTY:
|
||||||
|
// CHECK-NEXT: }, {
|
||||||
|
// CHECK-EMPTY:
|
||||||
|
// CHECK-NEXT: }, {
|
||||||
|
// CHECK-EMPTY:
|
||||||
|
// CHECK-NEXT: } ]
|
||||||
|
// CHECK-NEXT:}
|
||||||
|
|
||||||
|
// MLIR-LABEL: func @main(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, %arg2: tensor<4xi32>)
|
||||||
|
// MLIR-SAME: -> tensor<1x64x84x32xf32>
|
||||||
|
// MLIR: %0 = "tfl.convolution_2d_transpose_bias"(%arg0, %arg1, %arg2)
|
||||||
|
// MLIR-SAME: {padding = "SAME", stride_h = 1 : i32, stride_w = 2 : i32}
|
||||||
|
// MLIR-SAME: (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32>
|
||||||
|
// MLIR-NEXT: return %0 : tensor<1x64x84x32xf32>
|
||||||
|
|
||||||
|
%0 = "tfl.convolution_2d_transpose_bias"(%arg0, %arg1, %arg2) {padding = "SAME", stride_h = 1 : i32, stride_w = 2 : i32} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32>
|
||||||
|
return %0 : tensor<1x64x84x32xf32>
|
||||||
|
}
|
@ -0,0 +1,39 @@
|
|||||||
|
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -emit-builtin-tflite-ops=false -o - | flatbuffer_to_string - | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK: {
|
||||||
|
// CHECK: version: 3,
|
||||||
|
// CHECK: operator_codes: [ {
|
||||||
|
// CHECK: builtin_code: CUSTOM,
|
||||||
|
// CHECK: custom_code: "HashTableV2"
|
||||||
|
// CHECK: } ],
|
||||||
|
// CHECK: subgraphs: [ {
|
||||||
|
// CHECK: tensors: [ {
|
||||||
|
// CHECK: shape: [ ],
|
||||||
|
// CHECK: type: INT32,
|
||||||
|
// CHECK: buffer: 1,
|
||||||
|
// CHECK: name: "tf.HashTableV2",
|
||||||
|
// CHECK: quantization: {
|
||||||
|
// CHECK-EMPTY
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: } ],
|
||||||
|
// CHECK: inputs: [ ],
|
||||||
|
// CHECK: outputs: [ 0 ],
|
||||||
|
// CHECK: operators: [ {
|
||||||
|
// CHECK: inputs: [ ],
|
||||||
|
// CHECK: outputs: [ 0 ],
|
||||||
|
// CHECK: custom_options:
|
||||||
|
// CHECK: name: "main"
|
||||||
|
// CHECK: } ],
|
||||||
|
// CHECK: description: "MLIR Converted.",
|
||||||
|
// CHECK: buffers: [ {
|
||||||
|
// CHECK-EMPTY
|
||||||
|
// CHECK: }, {
|
||||||
|
// CHECK-EMPTY
|
||||||
|
// CHECK: } ]
|
||||||
|
// CHECK: }
|
||||||
|
|
||||||
|
func @main() -> tensor<*x!tf.resource> {
|
||||||
|
%0 = "tf.HashTableV2"() {container = "" , shared_name= "table", use_node_name_sharing = false, key_dtype = i32, value_dtype = i32 } : () -> tensor<*x!tf.resource>
|
||||||
|
return %0 : tensor<*x!tf.resource>
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,65 @@
|
|||||||
|
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s
|
||||||
|
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --check-prefix=MLIR %s
|
||||||
|
|
||||||
|
func @main(%arg0: tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) {
|
||||||
|
|
||||||
|
// CHECK: {
|
||||||
|
// CHECK-NEXT: version: 3,
|
||||||
|
// CHECK-NEXT: operator_codes: [ {
|
||||||
|
// CHECK-NEXT: builtin_code: CUSTOM,
|
||||||
|
// CHECK-NEXT: custom_code: "MaxPoolingWithArgmax2D"
|
||||||
|
// CHECK-NEXT: } ],
|
||||||
|
// CHECK-NEXT: subgraphs: [ {
|
||||||
|
// CHECK-NEXT: tensors: [ {
|
||||||
|
// CHECK-NEXT: shape: [ 1, 64, 64, 32 ],
|
||||||
|
// CHECK-NEXT: buffer: 1,
|
||||||
|
// CHECK-NEXT: name: "arg0",
|
||||||
|
// CHECK-NEXT: quantization: {
|
||||||
|
// CHECK-EMPTY:
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }, {
|
||||||
|
// CHECK-NEXT: shape: [ 1, 32, 32, 32 ],
|
||||||
|
// CHECK-NEXT: buffer: 2,
|
||||||
|
// CHECK-NEXT: name: "tfl.max_pooling_with_argmax_2d",
|
||||||
|
// CHECK-NEXT: quantization: {
|
||||||
|
// CHECK-EMPTY:
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }, {
|
||||||
|
// CHECK-NEXT: shape: [ 1, 32, 32, 32 ],
|
||||||
|
// CHECK-NEXT: buffer: 3,
|
||||||
|
// CHECK-NEXT: name: "tfl.max_pooling_with_argmax_2d:1",
|
||||||
|
// CHECK-NEXT: quantization: {
|
||||||
|
// CHECK-EMPTY:
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: } ],
|
||||||
|
// CHECK-NEXT: inputs: [ 0 ],
|
||||||
|
// CHECK-NEXT: outputs: [ 1, 2 ],
|
||||||
|
// CHECK-NEXT: operators: [ {
|
||||||
|
// CHECK-NEXT: inputs: [ 0 ],
|
||||||
|
// CHECK-NEXT: outputs: [ 1, 2 ],
|
||||||
|
// CHECK-NEXT: custom_options: [ 1, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||||
|
// CHECK-NEXT: } ],
|
||||||
|
// CHECK-NEXT: name: "main"
|
||||||
|
// CHECK-NEXT: } ],
|
||||||
|
// CHECK-NEXT: description: "MLIR Converted.",
|
||||||
|
// CHECK-NEXT: buffers: [ {
|
||||||
|
// CHECK-EMPTY:
|
||||||
|
// CHECK-NEXT: }, {
|
||||||
|
// CHECK-EMPTY:
|
||||||
|
// CHECK-NEXT: }, {
|
||||||
|
// CHECK-EMPTY:
|
||||||
|
// CHECK-NEXT: }, {
|
||||||
|
// CHECK-EMPTY:
|
||||||
|
// CHECK-NEXT: } ]
|
||||||
|
// CHECK-NEXT:}
|
||||||
|
|
||||||
|
// MLIR-LABEL: func @main(%arg0: tensor<1x64x64x32xf32>)
|
||||||
|
// MLIR-SAME: -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
|
||||||
|
// MLIR: %value, %indices = "tfl.max_pooling_with_argmax_2d"(%arg0)
|
||||||
|
// MLIR-SAME: {filter_h = 4 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 1 : i32}
|
||||||
|
// MLIR-SAME: (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
|
||||||
|
// MLIR-NEXT: return %value, %indices : tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>
|
||||||
|
|
||||||
|
%0, %1 = "tfl.max_pooling_with_argmax_2d"(%arg0) {filter_h = 4 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 1 : i32} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
|
||||||
|
return %0, %1 : tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>
|
||||||
|
}
|
@ -0,0 +1,65 @@
|
|||||||
|
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s
|
||||||
|
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --check-prefix=MLIR %s
|
||||||
|
|
||||||
|
func @main(%arg0: tensor<1x8x8x128xf32>, %arg1: tensor<1x8x8x128xf32>) -> tensor<1x8x8x128xf32> {
|
||||||
|
|
||||||
|
// CHECK: {
|
||||||
|
// CHECK-NEXT: version: 3,
|
||||||
|
// CHECK-NEXT: operator_codes: [ {
|
||||||
|
// CHECK-NEXT: builtin_code: CUSTOM,
|
||||||
|
// CHECK-NEXT: custom_code: "MaxUnpooling2D"
|
||||||
|
// CHECK-NEXT: } ],
|
||||||
|
// CHECK-NEXT: subgraphs: [ {
|
||||||
|
// CHECK-NEXT: tensors: [ {
|
||||||
|
// CHECK-NEXT: shape: [ 1, 8, 8, 128 ],
|
||||||
|
// CHECK-NEXT: buffer: 1,
|
||||||
|
// CHECK-NEXT: name: "arg0",
|
||||||
|
// CHECK-NEXT: quantization: {
|
||||||
|
// CHECK-EMPTY:
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }, {
|
||||||
|
// CHECK-NEXT: shape: [ 1, 8, 8, 128 ],
|
||||||
|
// CHECK-NEXT: buffer: 2,
|
||||||
|
// CHECK-NEXT: name: "arg1",
|
||||||
|
// CHECK-NEXT: quantization: {
|
||||||
|
// CHECK-EMPTY:
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }, {
|
||||||
|
// CHECK-NEXT: shape: [ 1, 8, 8, 128 ],
|
||||||
|
// CHECK-NEXT: buffer: 3,
|
||||||
|
// CHECK-NEXT: name: "tfl.max_unpooling_2d",
|
||||||
|
// CHECK-NEXT: quantization: {
|
||||||
|
// CHECK-EMPTY:
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: } ],
|
||||||
|
// CHECK-NEXT: inputs: [ 0, 1 ],
|
||||||
|
// CHECK-NEXT: outputs: [ 2 ],
|
||||||
|
// CHECK-NEXT: operators: [ {
|
||||||
|
// CHECK-NEXT: inputs: [ 0, 1 ],
|
||||||
|
// CHECK-NEXT: outputs: [ 2 ],
|
||||||
|
// CHECK-NEXT: custom_options: [ 1, 0, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||||
|
// CHECK-NEXT: } ],
|
||||||
|
// CHECK-NEXT: name: "main"
|
||||||
|
// CHECK-NEXT: } ],
|
||||||
|
// CHECK-NEXT: description: "MLIR Converted.",
|
||||||
|
// CHECK-NEXT: buffers: [ {
|
||||||
|
// CHECK-EMPTY:
|
||||||
|
// CHECK-NEXT: }, {
|
||||||
|
// CHECK-EMPTY:
|
||||||
|
// CHECK-NEXT: }, {
|
||||||
|
// CHECK-EMPTY:
|
||||||
|
// CHECK-NEXT: }, {
|
||||||
|
// CHECK-EMPTY:
|
||||||
|
// CHECK-NEXT: } ]
|
||||||
|
// CHECK-NEXT:}
|
||||||
|
|
||||||
|
// MLIR-LABEL: func @main(%arg0: tensor<1x8x8x128xf32>, %arg1: tensor<1x8x8x128xf32>)
|
||||||
|
// MLIR-SAME: -> tensor<1x8x8x128xf32>
|
||||||
|
// MLIR: %0 = "tfl.max_unpooling_2d"(%arg0, %arg1)
|
||||||
|
// MLIR-SAME: {filter_h = 1 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 4 : i32, stride_w = 2 : i32}
|
||||||
|
// MLIR-SAME: (tensor<1x8x8x128xf32>, tensor<1x8x8x128xf32>) -> tensor<1x8x8x128xf32>
|
||||||
|
// MLIR-NEXT: return %0 : tensor<1x8x8x128xf32>
|
||||||
|
|
||||||
|
%0 = "tfl.max_unpooling_2d"(%arg0, %arg1) {filter_h = 1 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 4 : i32, stride_w = 2 : i32} : (tensor<1x8x8x128xf32>, tensor<1x8x8x128xf32>) -> (tensor<1x8x8x128xf32>)
|
||||||
|
return %0 : tensor<1x8x8x128xf32>
|
||||||
|
}
|
@ -1977,3 +1977,12 @@ func @testTransposeConvBadOutputShape(%arg1: tensor<32x4x4x128xf32>, %arg2: tens
|
|||||||
%0 = "tfl.transpose_conv"(%cst, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>) -> tensor<1x64x84x31xf32>
|
%0 = "tfl.transpose_conv"(%cst, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>) -> tensor<1x64x84x31xf32>
|
||||||
return %0 : tensor<1x64x84x31xf32>
|
return %0 : tensor<1x64x84x31xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: testDensify
|
||||||
|
func @testDensify(%arg0: tensor<? x f32>) -> tensor<? x f32> {
|
||||||
|
// CHECK: "tfl.densify"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||||
|
%0 = "tfl.densify"(%arg0): (tensor<? x f32>) -> tensor<? x f32>
|
||||||
|
return %0 : tensor<? x f32>
|
||||||
|
}
|
||||||
|
@ -1,4 +1,7 @@
|
|||||||
|
// Run optimize pass only and check the results.
|
||||||
// RUN: tf-opt %s -tfl-optimize | FileCheck %s
|
// RUN: tf-opt %s -tfl-optimize | FileCheck %s
|
||||||
|
// Run optimize pass and then canonicalize pass, and make sure some folding is applied.
|
||||||
|
// RUN: tf-opt %s -tfl-optimize -canonicalize | FileCheck --check-prefix=FOLD %s
|
||||||
|
|
||||||
// CHECK-LABEL: fusedConv2dRelu
|
// CHECK-LABEL: fusedConv2dRelu
|
||||||
func @fusedConv2dRelu(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<256x30x30x16xf32> {
|
func @fusedConv2dRelu(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<256x30x30x16xf32> {
|
||||||
@ -75,10 +78,10 @@ func @fuseSubIntoFollowingConv2d(%arg0: tensor<256x32x32x3xf32>) -> tensor<256x3
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @fuseAddIntoDepthwiseConv2d
|
// CHECK-LABEL: @fuseAddIntoDepthwiseConv2d
|
||||||
func @fuseAddIntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x30x30x16xf32> {
|
func @fuseAddIntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32> {
|
||||||
%cst = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
|
%cst = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
|
||||||
%cst_0 = constant dense<1.5> : tensor<16xf32>
|
%cst_0 = constant dense<1.5> : tensor<16xf32>
|
||||||
%0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
%0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||||
%1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
%1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||||
return %1 : tensor<256x30x30x16xf32>
|
return %1 : tensor<256x30x30x16xf32>
|
||||||
|
|
||||||
@ -87,10 +90,10 @@ func @fuseAddIntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<1
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: fuseSubIntoDepthwiseConv2d
|
// CHECK-LABEL: fuseSubIntoDepthwiseConv2d
|
||||||
func @fuseSubIntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x30x30x16xf32> {
|
func @fuseSubIntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32> {
|
||||||
%cst = constant dense<0.5> : tensor<16xf32>
|
%cst = constant dense<0.5> : tensor<16xf32>
|
||||||
%cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
|
%cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
|
||||||
%0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
%0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||||
%1 = "tfl.sub"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
%1 = "tfl.sub"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||||
return %1 : tensor<256x30x30x16xf32>
|
return %1 : tensor<256x30x30x16xf32>
|
||||||
|
|
||||||
@ -128,10 +131,10 @@ func @fuseAddWithRelu6IntoConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<1
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @fuseAddWithRelu6IntoDepthwiseConv2d
|
// CHECK-LABEL: @fuseAddWithRelu6IntoDepthwiseConv2d
|
||||||
func @fuseAddWithRelu6IntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x30x30x16xf32> {
|
func @fuseAddWithRelu6IntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32> {
|
||||||
%cst = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
|
%cst = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
|
||||||
%cst_0 = constant dense<1.5> : tensor<16xf32>
|
%cst_0 = constant dense<1.5> : tensor<16xf32>
|
||||||
%0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
%0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||||
%1 = "tfl.add"(%0, %cst) {fused_activation_function = "RELU6"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
%1 = "tfl.add"(%0, %cst) {fused_activation_function = "RELU6"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||||
return %1 : tensor<256x30x30x16xf32>
|
return %1 : tensor<256x30x30x16xf32>
|
||||||
|
|
||||||
@ -302,6 +305,58 @@ func @FuseFullyConnectedAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf
|
|||||||
// CHECK: return %[[fc]]
|
// CHECK: return %[[fc]]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @FuseFullyConnectedReshapeAddConst
|
||||||
|
// FOLD-LABEL: @FuseFullyConnectedReshapeAddConst
|
||||||
|
func @FuseFullyConnectedReshapeAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
|
||||||
|
%cst = constant dense<3.0> : tensor<40x40xf32>
|
||||||
|
%cst2 = constant dense<2.0> : tensor<40xf32>
|
||||||
|
%shape1 = constant dense<[1, 40, 40]> : tensor<3xi32>
|
||||||
|
%shape2 = constant dense<[40, 40]> : tensor<2xi32>
|
||||||
|
|
||||||
|
%0 = "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, tensor<40x40xf32>) -> (tensor<40x40xf32>)
|
||||||
|
%1 = "tfl.reshape"(%0, %shape1) : (tensor<40x40xf32>, tensor<3xi32>) -> tensor<1x40x40xf32>
|
||||||
|
%2 = "tfl.add"(%1, %cst2) {fused_activation_function = "NONE"} : (tensor<1x40x40xf32>, tensor<40xf32>) -> tensor<1x40x40xf32>
|
||||||
|
%3 = "tfl.reshape"(%2, %shape2) : (tensor<1x40x40xf32>, tensor<2xi32>) -> tensor<40x40xf32>
|
||||||
|
|
||||||
|
return %3 : tensor<40x40xf32>
|
||||||
|
|
||||||
|
// CHECK: %[[cst:.*]] = constant dense<5.000000e+00> : tensor<40x40xf32>
|
||||||
|
// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]])
|
||||||
|
// CHECK: %[[rs1:.*]] = "tfl.reshape"(%[[fc]]
|
||||||
|
// CHECK: %[[rs2:.*]] = "tfl.reshape"(%[[rs1]]
|
||||||
|
// CHECK: return %[[rs2]]
|
||||||
|
|
||||||
|
// FOLD: %[[cst:.*]] = constant dense<5.000000e+00> : tensor<40x40xf32>
|
||||||
|
// FOLD: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]])
|
||||||
|
// FOLD: return %[[fc]]
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @NotReorderReshapeAddIfNotBroadcastable
|
||||||
|
func @NotReorderReshapeAddIfNotBroadcastable(%arg0: tensor<40x10x4xf32>) -> tensor<40x40xf32> {
|
||||||
|
%cst = constant dense<2.0> : tensor<40xf32>
|
||||||
|
%shape = constant dense<[40, 40]> : tensor<2xi32>
|
||||||
|
%1 = "tfl.reshape"(%arg0, %shape) : (tensor<40x10x4xf32>, tensor<2xi32>) -> tensor<40x40xf32>
|
||||||
|
%2 = "tfl.add"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<40xf32>) -> tensor<40x40xf32>
|
||||||
|
return %2 : tensor<40x40xf32>
|
||||||
|
|
||||||
|
// CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0
|
||||||
|
// CHECK: %[[rs2:.*]] = "tfl.add"(%[[rs1]]
|
||||||
|
// CHECK: return %[[rs2]]
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @NotReorderReshapeAddIfNotTailingDim
|
||||||
|
func @NotReorderReshapeAddIfNotTailingDim(%arg0: tensor<40x40x1xf32>) -> tensor<40x40xf32> {
|
||||||
|
%cst = constant dense<2.0> : tensor<1x40xf32>
|
||||||
|
%shape = constant dense<[40, 40]> : tensor<2xi32>
|
||||||
|
%1 = "tfl.reshape"(%arg0, %shape) : (tensor<40x40x1xf32>, tensor<2xi32>) -> tensor<40x40xf32>
|
||||||
|
%2 = "tfl.add"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<1x40xf32>) -> tensor<40x40xf32>
|
||||||
|
return %2 : tensor<40x40xf32>
|
||||||
|
|
||||||
|
// CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0
|
||||||
|
// CHECK: %[[rs2:.*]] = "tfl.add"(%[[rs1]]
|
||||||
|
// CHECK: return %[[rs2]]
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @FuseFullyConnectedRelu
|
// CHECK-LABEL: @FuseFullyConnectedRelu
|
||||||
func @FuseFullyConnectedRelu(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> {
|
func @FuseFullyConnectedRelu(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> {
|
||||||
%0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32>
|
%0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32>
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
// RUN: tf-opt -tfl-prepare-composite-funcs-tf %s | FileCheck %s --dump-input-on-failure
|
// RUN: tf-opt -tfl-prepare-composite-funcs-tf %s -split-input-file | FileCheck %s --dump-input-on-failure
|
||||||
|
|
||||||
|
module{
|
||||||
func @embedding(%arg0: tensor<*xf32>, %arg1: tensor<*xi32>) -> tensor<*xf32> attributes {tf._implements = "embedding_matmul", tf._reference = "mlir"} {
|
func @embedding(%arg0: tensor<*xf32>, %arg1: tensor<*xi32>) -> tensor<*xf32> attributes {tf._implements = "embedding_matmul", tf._reference = "mlir"} {
|
||||||
%0 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
%0 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||||
%1 = "tf.ExpandDims"(%arg1, %0) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
%1 = "tf.ExpandDims"(%arg1, %0) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
||||||
@ -148,3 +149,39 @@ func @layernormalizedlstmcellsimple(%arg0: tensor<1x?xf32>, %arg1: tensor<3x4xf3
|
|||||||
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x?xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<3x1xf32>, tensor<3xf32>, tensor<1x3xf32>, tensor<1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x3xf32>
|
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x?xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<3x1xf32>, tensor<3xf32>, tensor<1x3xf32>, tensor<1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x3xf32>
|
||||||
// CHECK: [[VAL_104:%.*]] = tensor_cast [[VAL_105:%.*]] : tensor<1x3xf32> to tensor<1x?xf32>
|
// CHECK: [[VAL_104:%.*]] = tensor_cast [[VAL_105:%.*]] : tensor<1x3xf32> to tensor<1x?xf32>
|
||||||
// CHECK: return [[VAL_104]] : tensor<1x?xf32>
|
// CHECK: return [[VAL_104]] : tensor<1x?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
module {
|
||||||
|
func @inference_standard_lstm_7410(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.signature.is_stateful} {
|
||||||
|
%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
|
||||||
|
%1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
|
||||||
|
%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
|
||||||
|
%3 = "tf.Add"(%2, %arg1) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
|
||||||
|
%4 = "tf.Add"(%2, %arg2) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x?x10xf32>
|
||||||
|
%5 = "tf.Add"(%arg1, %arg2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
|
||||||
|
%6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||||
|
return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: func @inference_standard_lstm_7410([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<?x8x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.signature.is_stateful} {
|
||||||
|
// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
|
||||||
|
// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32>
|
||||||
|
// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
|
||||||
|
// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<10x40xf32>, tensor<2xi64>) -> tensor<40x10xf32>
|
||||||
|
// CHECK: [[VAL_10:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||||
|
// CHECK: [[VAL_11:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||||
|
// CHECK: [[VAL_12:%.*]]:4 = "tf.SplitV"([[VAL_7]], [[VAL_10]], [[VAL_11]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
|
||||||
|
// CHECK: [[VAL_13:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||||
|
// CHECK: [[VAL_14:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||||
|
// CHECK: [[VAL_15:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_13]], [[VAL_14]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
|
||||||
|
// CHECK: [[VAL_16:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||||
|
// CHECK: [[VAL_17:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||||
|
// CHECK: [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
|
||||||
|
// CHECK: [[VAL_19:%.*]] = constant unit
|
||||||
|
// CHECK: [[VAL_20:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) ( {
|
||||||
|
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<?x8x10xf32>
|
||||||
|
// CHECK: return [[VAL_21:%.*]] : tensor<?x8x10xf32>
|
||||||
|
|
||||||
|
}
|
||||||
|
@ -414,6 +414,14 @@ func @CheckNumerics(%arg0: tensor<3xf32>) -> tensor<3xf32> {
|
|||||||
// CHECK: return %arg0 : tensor<3xf32>
|
// CHECK: return %arg0 : tensor<3xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func @placeholder_with_default(%arg0: tensor<3xf32>) -> tensor<3xf32> {
|
||||||
|
%0 = "tf.PlaceholderWithDefault"(%arg0): (tensor<3xf32>) -> tensor<3xf32>
|
||||||
|
return %0 : tensor<3xf32>
|
||||||
|
// Should be converted to Identity and then from Identity to value
|
||||||
|
// CHECK-LABEL: placeholder_with_default
|
||||||
|
// CHECK: return %arg0 : tensor<3xf32>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @NoPadStridedSliceNonNewAxisMask
|
// CHECK-LABEL: @NoPadStridedSliceNonNewAxisMask
|
||||||
func @NoPadStridedSliceNonNewAxisMask(%arg0: tensor<1x2x3x1xf32>) -> tensor<1x2x3x1xf32> {
|
func @NoPadStridedSliceNonNewAxisMask(%arg0: tensor<1x2x3x1xf32>) -> tensor<1x2x3x1xf32> {
|
||||||
%cst = constant dense<0> : tensor<4xi32>
|
%cst = constant dense<0> : tensor<4xi32>
|
||||||
@ -426,8 +434,8 @@ func @NoPadStridedSliceNonNewAxisMask(%arg0: tensor<1x2x3x1xf32>) -> tensor<1x2x
|
|||||||
// CHECK: %0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) {begin_mask = 15 : i64, ellipsis_mask = 0 : i64, end_mask = 15 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1x2x3x1xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
|
// CHECK: %0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) {begin_mask = 15 : i64, ellipsis_mask = 0 : i64, end_mask = 15 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1x2x3x1xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @PadStridedSliceNewAxisMask
|
// CHECK-LABEL: @PadStridedSliceNewAxisMask1
|
||||||
func @PadStridedSliceNewAxisMask(%arg0: tensor<2x3xf32>) -> tensor<1x2x3x1xf32> {
|
func @PadStridedSliceNewAxisMask1(%arg0: tensor<2x3xf32>) -> tensor<1x2x3x1xf32> {
|
||||||
%cst = constant dense<0> : tensor<4xi32>
|
%cst = constant dense<0> : tensor<4xi32>
|
||||||
%cst_0 = constant dense<1> : tensor<4xi32>
|
%cst_0 = constant dense<1> : tensor<4xi32>
|
||||||
%0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 9 : i64, shrink_axis_mask = 0 : i64} : (tensor<2x3xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
|
%0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 9 : i64, shrink_axis_mask = 0 : i64} : (tensor<2x3xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
|
||||||
@ -439,3 +447,12 @@ func @PadStridedSliceNewAxisMask(%arg0: tensor<2x3xf32>) -> tensor<1x2x3x1xf32>
|
|||||||
// CHECK: %0 = "tf.Reshape"(%arg0, %[[cst_1]]) : (tensor<2x3xf32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
|
// CHECK: %0 = "tf.Reshape"(%arg0, %[[cst_1]]) : (tensor<2x3xf32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
|
||||||
// CHECK: %1 = "tf.StridedSlice"(%0, %cst, %cst, %cst_0) {begin_mask = 15 : i64, ellipsis_mask = 0 : i64, end_mask = 15 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1x2x3x1xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
|
// CHECK: %1 = "tf.StridedSlice"(%0, %cst, %cst, %cst_0) {begin_mask = 15 : i64, ellipsis_mask = 0 : i64, end_mask = 15 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1x2x3x1xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @PadStridedSliceNewAxisMask2
|
||||||
|
func @PadStridedSliceNewAxisMask2(%arg0: tensor<4x64x64x1xf32>) -> tensor<1x4x64x64xf32> {
|
||||||
|
%cst = constant dense<0> : tensor<3xi32>
|
||||||
|
%cst_0 = constant dense<1> : tensor<3xi32>
|
||||||
|
%0 = "tf.Squeeze"(%arg0) {T = f32, _output_shapes = ["tfshape$dim { size: 4 } dim { size: 64 } dim { size: 64 }"], device = "", squeeze_dims = []} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32>
|
||||||
|
%1 = "tf.StridedSlice"(%0, %cst, %cst, %cst_0) {Index = i32, T = f32, _output_shapes = ["tfshape$dim { size: 1 } dim { size: 4 } dim { size: 64 } dim { size: 64 }"], begin_mask = 6 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 1 : i64, shrink_axis_mask = 0 : i64} : (tensor<4x64x64xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x4x64x64xf32>
|
||||||
|
return %1 : tensor<1x4x64x64xf32>
|
||||||
|
}
|
||||||
|
@ -43,6 +43,16 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
|
|||||||
quant_specs.inference_type != quant_specs.inference_input_type;
|
quant_specs.inference_type != quant_specs.inference_input_type;
|
||||||
pass_manager->addPass(
|
pass_manager->addPass(
|
||||||
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
|
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
|
||||||
|
|
||||||
|
if (quant_specs.default_ranges.first.hasValue() ||
|
||||||
|
quant_specs.default_ranges.second.hasValue()) {
|
||||||
|
pass_manager->addPass(mlir::TFL::CreateDefaultQuantParamsPass(
|
||||||
|
quant_specs.default_ranges.first.getValueOr(0.0),
|
||||||
|
quant_specs.default_ranges.second.getValueOr(0.0)));
|
||||||
|
pass_manager->addPass(mlir::TFL::CreateQuantizePass());
|
||||||
|
pass_manager->addPass(
|
||||||
|
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||||
@ -115,7 +125,8 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
|||||||
if (pass_config.emit_builtin_tflite_ops) {
|
if (pass_config.emit_builtin_tflite_ops) {
|
||||||
// Prepare for TFLite dialect, rerun canonicalization, and then legalize to
|
// Prepare for TFLite dialect, rerun canonicalization, and then legalize to
|
||||||
// the TFLite dialect.
|
// the TFLite dialect.
|
||||||
pass_manager->addPass(mlir::TFL::CreatePrepareTFPass());
|
pass_manager->addPass(
|
||||||
|
mlir::TFL::CreatePrepareTFPass(pass_config.unfold_batch_matmul));
|
||||||
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
||||||
pass_manager->addPass(mlir::TFL::CreateLegalizeTFPass());
|
pass_manager->addPass(mlir::TFL::CreateLegalizeTFPass());
|
||||||
pass_manager->addPass(mlir::TFL::CreateOptimizePass());
|
pass_manager->addPass(mlir::TFL::CreateOptimizePass());
|
||||||
|
@ -86,15 +86,15 @@ StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
|
|||||||
if (use_splatted_constant) {
|
if (use_splatted_constant) {
|
||||||
return tensorflow::GraphdefToSplattedMlirTranslateFunction(
|
return tensorflow::GraphdefToSplattedMlirTranslateFunction(
|
||||||
file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
|
file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
|
||||||
input_shapes, output_arrays, prune_unused_nodes,
|
input_shapes, output_arrays, /*control_output_arrays=*/"",
|
||||||
/*convert_legacy_fed_inputs=*/true,
|
prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
|
||||||
/*graph_as_function=*/false, /*upgrade_legacy=*/true, context);
|
/*graph_as_function=*/false, /*upgrade_legacy=*/true, context);
|
||||||
}
|
}
|
||||||
return tensorflow::GraphdefToMlirTranslateFunction(
|
return tensorflow::GraphdefToMlirTranslateFunction(
|
||||||
file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
|
file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
|
||||||
input_shapes, output_arrays, prune_unused_nodes,
|
input_shapes, output_arrays, /*control_output_arrays=*/"",
|
||||||
/*convert_legacy_fed_inputs=*/true, /*graph_as_function=*/false,
|
prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
|
||||||
/*upgrade_legacy=*/true, context);
|
/*graph_as_function=*/false, /*upgrade_legacy=*/true, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ConvertTFExecutorToTFLOrFlatbuffer(
|
Status ConvertTFExecutorToTFLOrFlatbuffer(
|
||||||
|
234
tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc
Normal file
234
tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc
Normal file
@ -0,0 +1,234 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||||
|
#include "mlir/IR/AffineExpr.h"
|
||||||
|
#include "mlir/IR/AffineMap.h"
|
||||||
|
#include "mlir/IR/Attributes.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/IR/StandardTypes.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Support/Functional.h"
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
#include "absl/memory/memory.h"
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/StringSwitch.h"
|
||||||
|
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
|
||||||
|
#include "mlir/IR/Location.h" // TF:llvm-project
|
||||||
|
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||||
|
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// The Pass to add default quantization parameters for the activations which
|
||||||
|
// don't have quantization information. These default parameters are usually
|
||||||
|
// not from real measurement, so this pass is only for test purpose.
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace TFL {
|
||||||
|
// Includs an auto-generated function, which can retrieve the quantization
|
||||||
|
// specification for an TFL operation. The signature of the function is
|
||||||
|
// std::unique_pointer<OpQuantSpec> TFL::GetOpQuantSpec(Operation *)
|
||||||
|
#include "tensorflow/compiler/mlir/lite/utils/generated_op_quant_spec_getters.inc"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Pass that assigns fallback (default) quantization parameters to every
// activation or bias value left without quantization information. The
// defaults come from a configured [min, max] range rather than real
// measurements, so the pass is meant for testing only.
class DefaultQuantParamsPass : public FunctionPass<DefaultQuantParamsPass> {
 public:
  explicit DefaultQuantParamsPass(double default_min, double default_max)
      : default_min_(default_min), default_max_(default_max) {}

  void runOnFunction() override;

 private:
  // Returns true if `value` feeds a bias operand of one of its users. Biases
  // are assumed to be consumed directly by the user op, which holds after
  // constant folding.
  bool UsedAsBias(Value value) {
    return llvm::any_of(value.getUses(), [](OpOperand &use) {
      auto biases = TFL::GetOpQuantSpec(use.getOwner())->biases_params;
      return biases.find(use.getOperandNumber()) != biases.end();
    });
  }

  // Quantizes `value` with `quant_params` by inserting a matching
  // tfl.quantize / tfl.dequantize pair right after it.
  void QuantizeValue(OpBuilder builder, Value value,
                     TFL::QuantParams quant_params);

  // Appends `value` to `values` unless it has already been quantized.
  void AddToWorkListIfUnquantized(Value value, std::vector<Value> *values);

  // Converts the configured default min/max into default quantization
  // parameters (cached after the first call).
  TFL::QuantParams GetDefaultQuantParams(Builder builder);

  // Derives quantization parameters for the bias operand `bias` of `op` from
  // the already-quantized operands listed in `non_biases`, combined by `func`.
  TFL::QuantParams GetQuantParamsForBias(Operation *op, int bias,
                                         const std::vector<int> &non_biases,
                                         TFL::AccumulatorScaleFunc func);

  double default_min_;
  double default_max_;
  TFL::QuantParams default_quant_params_;
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void DefaultQuantParamsPass::runOnFunction() {
|
||||||
|
FuncOp func = getFunction();
|
||||||
|
OpBuilder builder(func);
|
||||||
|
|
||||||
|
std::vector<Value> activation_values;
|
||||||
|
std::vector<Value> bias_values;
|
||||||
|
|
||||||
|
// First of all, collect all the values (block arguments and op results) which
|
||||||
|
// are required to be quantized.
|
||||||
|
for (auto arg : func.getBody().begin()->getArguments()) {
|
||||||
|
if (UsedAsBias(arg)) {
|
||||||
|
AddToWorkListIfUnquantized(arg, &bias_values);
|
||||||
|
} else {
|
||||||
|
AddToWorkListIfUnquantized(arg, &activation_values);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func.walk([&](Operation *op) {
|
||||||
|
if (op->isKnownTerminator() ||
|
||||||
|
op->hasTrait<OpTrait::quant::NoQuantizableResult>())
|
||||||
|
return;
|
||||||
|
|
||||||
|
for (auto res : op->getResults()) {
|
||||||
|
if (UsedAsBias(res)) {
|
||||||
|
AddToWorkListIfUnquantized(res, &bias_values);
|
||||||
|
} else {
|
||||||
|
AddToWorkListIfUnquantized(res, &activation_values);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Apply the default quantization parameters for these activation values.
|
||||||
|
TFL::QuantParams default_params = GetDefaultQuantParams(builder);
|
||||||
|
for (Value value : activation_values) {
|
||||||
|
QuantizeValue(builder, value, default_params);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Since all the non-biases operands have quantization parameters now, we
|
||||||
|
// should be able to propagate them to the bias operand.
|
||||||
|
for (Value bias : bias_values) {
|
||||||
|
Operation *op = *bias.user_begin();
|
||||||
|
auto spec = TFL::GetOpQuantSpec(op);
|
||||||
|
for (auto &it : spec->biases_params) {
|
||||||
|
TFL::QuantParams bias_params = GetQuantParamsForBias(
|
||||||
|
op, it.first, it.second.first, it.second.second);
|
||||||
|
if (!bias_params) continue;
|
||||||
|
QuantizeValue(builder, bias, bias_params);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Appends `value` to `values` unless quantization is unnecessary or has
// already happened. Written as a sequence of guard clauses.
void DefaultQuantParamsPass::AddToWorkListIfUnquantized(
    Value value, std::vector<Value> *values) {
  // Non-tensor values (e.g. none-typed) never need quantization.
  auto tensor_type = value.getType().dyn_cast<TensorType>();
  if (!tensor_type) return;

  // Only f32 tensors are quantizable; integer tensors are left untouched.
  if (!tensor_type.getElementType().isF32()) return;

  // A value whose single consumer is a tfl.quantize op is already quantized.
  if (value.hasOneUse() &&
      llvm::isa<TFL::QuantizeOp>(*value.getUsers().begin()))
    return;

  // Still unquantized: record it so the default parameters get applied.
  values->push_back(value);
}
|
||||||
|
|
||||||
|
// Quantizes `value` with `quant_params` by inserting a tfl.quantize /
// tfl.dequantize pair directly after it and rerouting all existing uses of
// `value` through the dequantize result. Values that are not of the expressed
// (float) type are left untouched.
void DefaultQuantParamsPass::QuantizeValue(OpBuilder builder, Value value,
                                           TFL::QuantParams quant_params) {
  Type expressed_type = value.getType();
  Type new_type = quant_params.castFromExpressedType(expressed_type);
  // This value isn't an expressed type (float), skip.
  if (!new_type) return;

  // Insert right after the defining op, or at the start of the block when the
  // value is a block argument (no defining op).
  Block &block = value.getParentRegion()->front();
  Operation *op = value.getDefiningOp();
  if (op) {
    builder.setInsertionPoint(&block, ++Block::iterator(op));
  } else {
    builder.setInsertionPointToStart(&block);
  }
  TypeAttr type_attr = TypeAttr::get(new_type);
  auto quantize = builder.create<TFL::QuantizeOp>(value.getLoc(), new_type,
                                                  value, type_attr);
  auto dequantize = builder.create<TFL::DequantizeOp>(
      value.getLoc(), expressed_type, quantize.output());
  // Reroute every use of `value` through the dequantize result...
  value.replaceAllUsesWith(dequantize);

  // `quantize` is using `dequantize` now, so we should set its operand to
  // `value`.
  quantize.getOperation()->replaceUsesOfWith(dequantize, value);
}
|
||||||
|
|
||||||
|
// Derives the quantization parameters for the `bias` operand of `op` from the
// quantized element types of the `non_biases` operands, combined by `func`
// (typically an accumulator-scale computation). Returns an empty QuantParams
// when any non-bias operand has not been quantized yet, signalling the caller
// to skip this bias.
TFL::QuantParams DefaultQuantParamsPass::GetQuantParamsForBias(
    Operation *op, int bias, const std::vector<int> &non_biases,
    TFL::AccumulatorScaleFunc func) {
  std::vector<quant::QuantizedType> non_bias_types;
  non_bias_types.reserve(non_biases.size());
  for (int non_bias : non_biases) {
    Operation *non_bias_define = op->getOperand(non_bias).getDefiningOp();
    // The operand may be a block argument, in which case getDefiningOp()
    // returns null; dyn_cast_or_null guards against that (a plain dyn_cast
    // on a null pointer is undefined behavior).
    if (auto dequant =
            llvm::dyn_cast_or_null<TFL::DequantizeOp>(non_bias_define)) {
      auto non_bias_type = dequant.input().getType().cast<TensorType>();
      auto non_bias_ele_type =
          non_bias_type.getElementType().cast<quant::QuantizedType>();
      non_bias_types.push_back(non_bias_ele_type);
    } else {
      // The non-bias hasn't been quantized, let's skip this bias.
      break;
    }
  }
  // The non-bias hasn't been quantized, let's skip this bias.
  if (non_bias_types.size() != non_biases.size()) return {};

  return func(non_bias_types);
}
|
||||||
|
|
||||||
|
// Lazily builds and caches the default quantization parameters from the
// configured min/max range: 8 bits, non-narrow range, f32 expressed type.
TFL::QuantParams DefaultQuantParamsPass::GetDefaultQuantParams(
    Builder builder) {
  if (default_quant_params_) return default_quant_params_;
  default_quant_params_ = quant::fakeQuantAttrsToType(
      builder.getUnknownLoc(),
      /*numBits=*/8, default_min_, default_max_, /*narrowRange=*/false,
      builder.getF32Type());
  return default_quant_params_;
}
|
||||||
|
|
||||||
|
// Creates an instance of the default quant parameters pass.
//
// `default_min`/`default_max` define the floating-point range used to derive
// the fallback quantization parameters for unannotated values.
std::unique_ptr<OpPassBase<FuncOp>> CreateDefaultQuantParamsPass(
    double default_min, double default_max) {
  return absl::make_unique<DefaultQuantParamsPass>(default_min, default_max);
}
|
||||||
|
|
||||||
|
// Registers this pass under -tfl-default-quant with a default [-1.0, 1.0]
// range; only for test use.
static PassRegistration<DefaultQuantParamsPass> pass(
    "tfl-default-quant",
    "Apply quantization with default quantization parameter", [] {
      return CreateDefaultQuantParamsPass(/*default_min=*/-1.0,
                                          /*default_max=*/1.0);
    });
|
||||||
|
|
||||||
|
} // namespace TFL
|
||||||
|
} // namespace mlir
|
@ -150,6 +150,7 @@ def : Pat<(TF_RoundOp $arg), (TFL_RoundOp $arg)>;
|
|||||||
def : Pat<(TF_RsqrtOp $arg), (TFL_RsqrtOp $arg)>;
|
def : Pat<(TF_RsqrtOp $arg), (TFL_RsqrtOp $arg)>;
|
||||||
def : Pat<(TF_SqrtOp $arg), (TFL_SqrtOp $arg)>;
|
def : Pat<(TF_SqrtOp $arg), (TFL_SqrtOp $arg)>;
|
||||||
def : Pat<(TF_SquareOp $arg), (TFL_SquareOp $arg)>;
|
def : Pat<(TF_SquareOp $arg), (TFL_SquareOp $arg)>;
|
||||||
|
def : Pat<(TF_SegmentSumOp $data, I32Tensor:$segment_ids), (TFL_SegmentSumOp $data, $segment_ids)>;
|
||||||
def : Pat<(TF_SelectOp $cond, $x, $y), (TFL_SelectOp $cond, $x, $y)>;
|
def : Pat<(TF_SelectOp $cond, $x, $y), (TFL_SelectOp $cond, $x, $y)>;
|
||||||
def : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y), (TFL_SelectOp $cond, $x, $y), [(HasSameStaticShapes $src_op)]>;
|
def : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y), (TFL_SelectOp $cond, $x, $y), [(HasSameStaticShapes $src_op)]>;
|
||||||
def : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y), (TFL_SelectV2Op $cond, $x, $y), [(HasNotSameStaticShapes $src_op)]>;
|
def : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y), (TFL_SelectV2Op $cond, $x, $y), [(HasNotSameStaticShapes $src_op)]>;
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user