Merge branch 'master' into google_upstream_training_ops

Commit 9760afc119

.bazelrc (10 changed lines)
@@ -69,6 +69,7 @@
# rbe_linux_py3: Linux Python 3 RBE config
#
# rbe_win_py37: Windows Python 3.7 RBE config
# rbe_win_py38: Windows Python 3.8 RBE config
#
# tensorflow_testing_rbe_linux: RBE options to use RBE with tensorflow-testing project on linux
# tensorflow_testing_rbe_win: RBE options to use RBE with tensorflow-testing project on windows
@@ -279,7 +280,6 @@ build:windows --host_linkopt=/OPT:REF
build:windows --linkopt=/OPT:ICF
build:windows --host_linkopt=/OPT:ICF
build:windows --experimental_strict_action_env=true
build:windows --incompatible_windows_native_test_wrapper

# Verbose failure logs when something goes wrong
build:windows --verbose_failures
@@ -344,6 +344,7 @@ build:rbe_linux --config=avx_linux
build:rbe_linux --config=short_logs
# TODO(gunan): Check why we need this specified in rbe, but not in other builds.
build:rbe_linux --linkopt=-lrt
build:rbe_linux --linkopt=-lm

build:rbe_cpu_linux --config=rbe_linux
build:rbe_cpu_linux --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain"
@@ -392,6 +393,7 @@ build:rbe_win --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe

# TODO(gunan): Remove once we use MSVC 2019 with latest patches.
build:rbe_win --define=override_eigen_strong_inline=true
build:rbe_win --jobs=500

build:rbe_win_py37 --config=rbe
build:rbe_win_py37 --repo_env=PYTHON_BIN_PATH=C:\\Python37\\python.exe
@@ -399,6 +401,12 @@ build:rbe_win_py37 --repo_env=PYTHON_LIB_PATH=C:\\Python37\\lib\\site-packages
build:rbe_win_py37 --repo_env=TF_PYTHON_CONFIG_REPO=@org_tensorflow//third_party/toolchains/preconfig/win_1803/py37
build:rbe_win_py37 --python_path=C:\\Python37\\python.exe

build:rbe_win_py38 --config=rbe
build:rbe_win_py38 --repo_env=PYTHON_BIN_PATH=C:\\Python38\\python.exe
build:rbe_win_py38 --repo_env=PYTHON_LIB_PATH=C:\\Python38\\lib\\site-packages
build:rbe_win_py38 --repo_env=TF_PYTHON_CONFIG_REPO=@org_tensorflow//third_party/toolchains/preconfig/win_1803/py38
build:rbe_win_py38 --python_path=C:\\Python38\\python.exe

# These you may need to change for your own GCP project.
build:tensorflow_testing_rbe --project_id=tensorflow-testing
common:tensorflow_testing_rbe_linux --remote_instance_name=projects/tensorflow-testing/instances/default_instance
RELEASE.md (194 changed lines)
File diff suppressed because one or more lines are too long
@@ -245,4 +245,4 @@ v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc=
### Known Vulnerabilities

For a list of known vulnerabilities and security advisories for TensorFlow,
[click here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/index.md).
[click here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/README.md).
WORKSPACE (38 changed lines)
@@ -1,11 +1,13 @@
workspace(name = "org_tensorflow")

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_file")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("//third_party:repo.bzl", "tf_http_archive")

http_archive(
tf_http_archive(
name = "io_bazel_rules_closure",
sha256 = "5b00383d08dd71f28503736db0500b6fb4dda47489ff5fc6bed42557c07c6ba9",
strip_prefix = "rules_closure-308b05b2419edb5c8ee0471b67a40403df940149",
patch_file = "@org_tensorflow//third_party:rules_closure.patch",
urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz",
"https://github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz", # 2019-06-13
@@ -48,38 +50,6 @@ load("//third_party/toolchains/preconfig/generate:workspace.bzl",

remote_config_workspace()

# Apple and Swift rules.
http_archive(
name = "build_bazel_rules_apple",
sha256 = "a045a436b642c70fb0c10ca84ff0fd2dcbd59cc89100d597a61e8374afafb366",
urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.18.0/rules_apple.0.18.0.tar.gz"],
) # https://github.com/bazelbuild/rules_apple/releases
http_archive(
name = "build_bazel_rules_swift",
sha256 = "18cd4df4e410b0439a4935f9ca035bd979993d42372ba79e7f2d4fafe9596ef0",
urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.12.1/rules_swift.0.12.1.tar.gz"],
) # https://github.com/bazelbuild/rules_swift/releases
http_archive(
name = "build_bazel_apple_support",
sha256 = "122ebf7fe7d1c8e938af6aeaee0efe788a3a2449ece5a8d6a428cb18d6f88033",
urls = ["https://github.com/bazelbuild/apple_support/releases/download/0.7.1/apple_support.0.7.1.tar.gz"],
) # https://github.com/bazelbuild/apple_support/releases
http_archive(
name = "bazel_skylib",
sha256 = "1dde365491125a3db70731e25658dfdd3bc5dbdfd11b840b3e987ecf043c7ca0",
urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/0.9.0/bazel-skylib.0.9.0.tar.gz"],
) # https://github.com/bazelbuild/bazel-skylib/releases
http_archive(
name = "com_github_apple_swift_swift_protobuf",
type = "zip",
strip_prefix = "swift-protobuf-1.6.0/",
urls = ["https://github.com/apple/swift-protobuf/archive/1.6.0.zip"],
) # https://github.com/apple/swift-protobuf/releases
http_file(
name = "xctestrunner",
executable = 1,
urls = ["https://github.com/google/xctestrunner/releases/download/0.2.9/ios_test_runner.par"],
) # https://github.com/google/xctestrunner/releases
# Use `swift_rules_dependencies` to fetch the toolchains. With the
# `git_repository` rules above, the following call will skip redefining them.
load("@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies")
@@ -1221,7 +1221,7 @@ def is_reduced_optimize_huge_functions_available(environ_cp):
only, as of 2019-11-19). TensorFlow needs this flag to massively reduce
compile times, but until 16.4 is officially released, we can't depend on it.

See also https://groups.google.com/a/tensorflow.org/g/build/c/SsW98Eo7l3o
See also https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion

Because it's very annoying to check this manually (to check the MSVC installed
versions, you need to use the registry, and it's not clear if Bazel will be
@@ -2,6 +2,7 @@
# TensorFlow is a computational framework, primarily for use in machine
# learning applications.

load("@bazel_skylib//lib:selects.bzl", "selects")
load("//tensorflow:tensorflow.bzl", "VERSION", "tf_cc_shared_object", "tf_custom_op_library_additional_deps_impl", "tf_native_cc_binary")
load(
"//tensorflow/core/platform:build_config.bzl",
@@ -478,6 +479,7 @@ bzl_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core/platform:build_config_root_bzl",
"//tensorflow/core/platform:rules_cc_bzl",
"//tensorflow/core/platform/default:cuda_build_defs_bzl",
"//third_party/mkl:build_defs_bzl",
"//third_party/mkl_dnn:build_defs_bzl",
@@ -23,10 +23,6 @@ from __future__ import print_function
# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import

from tensorflow.python.util.lazy_loader import LazyLoader
contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')
del LazyLoader

from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
from tensorflow.python.platform import app # pylint: disable=g-import-not-at-top
app.flags = flags
@@ -54,9 +54,10 @@ filegroup(
)

filegroup(
name = "pywrap_eager_hdrs",
name = "pywrap_required_hdrs",
srcs = [
"c_api_internal.h",
"python_api.h",
"tf_status_helper.h",
"tf_status_internal.h",
"tf_tensor_internal.h",
@@ -98,6 +99,17 @@ tf_cuda_library(
],
)

filegroup(
name = "pywrap_tf_session_hdrs",
srcs = [
"python_api.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)

cc_library(
name = "tf_attrtype",
hdrs = ["tf_attrtype.h"],
@@ -302,6 +314,7 @@ tf_cuda_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:attr_builder",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/platform",
"@com_google_absl//absl/strings",
@@ -639,7 +652,7 @@ tf_cuda_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/kernels:ops_testutil",
"//third_party/eigen3",
"@com_google_absl//absl/container:inlined_vector",
],
)
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
@@ -549,7 +550,7 @@ TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(TFE_Op* op,
TF_Status* status) {
TFE_ExecuteOpNotification* n = new TFE_ExecuteOpNotification;

n->thread.reset(op->operation.EagerContext()->TFEnv()->StartThread(
n->thread.reset(op->operation.EagerContext().TFEnv()->StartThread(
tensorflow::ThreadOptions(), "ExecuteOpThread",
[op, retvals, num_retvals, n]() {
TFE_Execute(op, retvals, num_retvals, n->status.get());
@@ -767,8 +768,9 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
} while (0);

// New server created for new server_def. Unused if updating server_def.
tensorflow::EagerContext* context = ctx->context;
tensorflow::GrpcServer* grpc_server =
dynamic_cast<tensorflow::GrpcServer*>(ctx->context->GetServer());
dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
if (grpc_server == nullptr) {
std::unique_ptr<tensorflow::ServerInterface> new_server;
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
@@ -779,12 +781,12 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
}
LOG_AND_RETURN_IF_ERROR(grpc_server->Start());

LOG_AND_RETURN_IF_ERROR(ctx->context->StoreCollectiveOpsServer(
LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
std::move(new_server), grpc_server->worker_env()->device_mgr,
grpc_server->worker_env()->collective_executor_mgr));
} else {
LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
LOG_AND_RETURN_IF_ERROR(ctx->context->StoreCollectiveOpsServer(
LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
/*new_server=*/nullptr, grpc_server->worker_env()->device_mgr,
grpc_server->worker_env()->collective_executor_mgr));
}
@@ -1260,11 +1260,10 @@ TEST_F(CApiFunctionTest, GraphToFunctionDefWithPlaceholderAttr) {
NodeWithPlaceholderAttrHelper(func_graph.get(), s.get(), "node3", "v2",
&node3);

TF_Output inputs[] = {};
TF_Output outputs[] = {{node1, 0}, {node2, 0}, {node3, 0}};
func_ = TF_GraphToFunction(
func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
/*opers=*/nullptr, 0, inputs, 3, outputs,
/*opers=*/nullptr, 0, nullptr, 3, outputs,
/*output_names=*/nullptr,
/*opts=*/nullptr, /*description=*/nullptr, s.get());
ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
@@ -1300,10 +1299,9 @@ TEST_F(CApiFunctionTest, GraphToFunctionDefWithArgAttr) {
&node);

TF_Output inputs[] = {{node, 0}};
TF_Output outputs[] = {};
func_ = TF_GraphToFunction(
func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
/*opers=*/nullptr, 1, inputs, 0, outputs,
/*opers=*/nullptr, 1, inputs, 0, nullptr,
/*output_names=*/nullptr,
/*opts=*/nullptr, /*description=*/nullptr, s.get());
ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
@@ -1603,11 +1601,10 @@ void DefineStatefulFunction(const char* name, TF_Function** func) {
TF_Operation* random =
RandomUniform(shape, TF_FLOAT, func_graph.get(), s.get());

TF_Output inputs[] = {};
TF_Output outputs[] = {{random, 0}};
*func = TF_GraphToFunction(func_graph.get(), name,
/*append_hash_to_fn_name=*/false, -1,
/*opers=*/nullptr, 0, inputs, 1, outputs,
/*opers=*/nullptr, 0, nullptr, 1, outputs,
/*output_names=*/nullptr,
/*opts=*/nullptr, "", s.get());
ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
@@ -17,7 +17,7 @@ limitations under the License.
#include <memory.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/time.h>
#include <time.h>
#include <unistd.h>

#include "tensorflow/c/c_api.h"
@@ -58,12 +58,8 @@ int main(int argc, char** argv) {
}

char file_name[100];
struct timeval t;
if (gettimeofday(&t, NULL)) {
perror("gettimeofday failed");
return 1;
}
snprintf(file_name, sizeof(file_name), "test-%d-%ld.txt", getpid(), t.tv_sec);
time_t t = time(NULL);
snprintf(file_name, sizeof(file_name), "test-%d-%ld.txt", getpid(), t);

size_t length = 2 + strlen(path) + strlen(file_name);
char* full_path = malloc(length);
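The hunk above swaps gettimeofday() for time(): a seconds-resolution timestamp is enough to make a scratch-file name unique per test run, and it removes the error-handling branch that gettimeofday() needed. A minimal standalone sketch of the new naming scheme (the "test-" prefix mirrors the hunk; everything else is illustrative):

#include <cstdio>
#include <ctime>
#include <unistd.h>

int main() {
  // One name per process per second: pid plus seconds since the epoch.
  char file_name[100];
  time_t t = time(NULL);
  snprintf(file_name, sizeof(file_name), "test-%d-%ld.txt", getpid(),
           static_cast<long>(t));
  printf("%s\n", file_name);
  return 0;
}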
@@ -26,8 +26,8 @@ tf_cuda_library(
"c_api.cc",
"c_api_debug.cc",
"c_api_experimental.h",
"c_api_internal.cc",
"c_api_internal.h",
"tensor_handle_interface.h",
],
hdrs = ["c_api.h"],
copts = tf_copts() + tfe_xla_copts(),
@@ -89,10 +89,11 @@ tf_cuda_library(
)

filegroup(
name = "pywrap_eager_hdrs",
name = "pywrap_required_hdrs",
srcs = [
"c_api_experimental.h",
"c_api_internal.h",
"tensor_handle_interface.h",
],
visibility = [
"//tensorflow/core:__pkg__",
@@ -102,7 +103,10 @@ filegroup(

tf_cuda_library(
name = "c_api_internal",
srcs = ["c_api_experimental.h"],
srcs = [
"c_api_experimental.h",
"tensor_handle_interface.h",
],
hdrs = ["c_api_internal.h"],
visibility = [
"//learning/deepmind/courier:__subpackages__",
@@ -125,18 +129,6 @@ tf_cuda_library(
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/distributed_runtime:remote_device",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime/eager:eager_client",
"//tensorflow/core/distributed_runtime/eager:remote_tensor_handle",
"//tensorflow/core/distributed_runtime/rpc:grpc_channel",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
"//tensorflow/core/profiler/lib:profiler_lib",
"//tensorflow/core/profiler/lib:profiler_session",
],
)
@@ -31,6 +31,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
@@ -43,6 +44,7 @@ limitations under the License.
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/platform.h" // NOLINT
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/device_filters.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -81,6 +83,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/map_util.h"

#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
@@ -93,10 +96,8 @@ using tensorflow::string;
namespace {

const tensorflow::OpDef* GetOpDef(TFE_Op* op, TF_Status* status) {
if (op->inference_ctx) {
return op->inference_ctx->op_def;
}
const tensorflow::OpDef* op_def;
const tensorflow::OpDef* op_def = op->operation.OpDef();
if (op_def) return op_def;
status->status =
tensorflow::OpDefForOp(op->operation.Name().c_str(), &op_def);
return op_def;
@@ -265,9 +266,9 @@ tensorflow::Status GetReplacedFromExistingWorkers(
}

tensorflow::Status CreateRemoteContexts(
const std::vector<string>& remote_workers, tensorflow::uint64 context_id,
tensorflow::uint64 context_view_id, int keep_alive_secs,
const tensorflow::ServerDef& server_def,
TFE_Context* ctx, const std::vector<string>& remote_workers,
tensorflow::uint64 context_id, tensorflow::uint64 context_view_id,
int keep_alive_secs, const tensorflow::ServerDef& server_def,
tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
const bool lazy_copy_remote_function_inputs,
const tensorflow::eager::CreateContextRequest& base_request) {
@@ -296,7 +297,7 @@ tensorflow::Status CreateRemoteContexts(
continue;
}

tensorflow::eager::CreateContextRequest request(base_request);
tensorflow::eager::CreateContextRequest request;
tensorflow::eager::CreateContextResponse* response =
new tensorflow::eager::CreateContextResponse();
request.set_context_id(context_id);
@@ -304,6 +305,21 @@ tensorflow::Status CreateRemoteContexts(
*request.mutable_server_def() = server_def;
request.mutable_server_def()->set_job_name(parsed_name.job);
request.mutable_server_def()->set_task_index(parsed_name.task);
request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
server_def.default_session_config());

std::vector<bool> filtered_device_mask;
ctx->context->FilterDevicesForRemoteWorkers(
remote_worker, base_request.cluster_device_attributes(),
&filtered_device_mask);
DCHECK_EQ(filtered_device_mask.size(),
base_request.cluster_device_attributes_size());
for (int i = 0; i < filtered_device_mask.size(); i++) {
if (filtered_device_mask[i]) {
const auto& da = base_request.cluster_device_attributes(i);
*request.add_cluster_device_attributes() = da;
}
}
request.set_async(async);
request.set_keep_alive_secs(keep_alive_secs);
request.set_lazy_copy_remote_function_inputs(
@@ -325,13 +341,34 @@ tensorflow::Status CreateRemoteContexts(
}

tensorflow::Status UpdateRemoteContexts(
const std::vector<string>& remote_workers, tensorflow::uint64 context_id,
TFE_Context* ctx, const std::vector<string>& remote_workers,
const std::vector<string>& added_workers,
const std::vector<string>& removed_workers, tensorflow::uint64 context_id,
tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
tensorflow::eager::EagerClientCache* remote_eager_workers,
const tensorflow::eager::CreateContextRequest& base_request) {
int num_remote_workers = remote_workers.size();
tensorflow::BlockingCounter counter(num_remote_workers);
std::vector<tensorflow::Status> statuses(num_remote_workers);

int cluster_device_count = base_request.cluster_device_attributes_size();
std::unordered_set<string> added_or_removed(added_workers.begin(),
added_workers.end());
std::copy(removed_workers.begin(), removed_workers.end(),
std::inserter(added_or_removed, added_or_removed.end()));
// Whether each device is in the updated (added or removed) workers
std::vector<bool> device_added_or_removed(cluster_device_count);
for (int i = 0; i < base_request.cluster_device_attributes_size(); i++) {
const auto& da = base_request.cluster_device_attributes().at(i);
tensorflow::DeviceNameUtils::ParsedName pn;
tensorflow::DeviceNameUtils::ParseFullName(da.name(), &pn);
string task_name;
tensorflow::DeviceNameUtils::GetTaskName(pn, &task_name);
if (added_or_removed.find(task_name) != added_or_removed.end()) {
device_added_or_removed[i] = true;
}
}

for (int i = 0; i < num_remote_workers; i++) {
const string& remote_worker = remote_workers[i];
tensorflow::DeviceNameUtils::ParsedName parsed_name;
@@ -354,17 +391,42 @@ tensorflow::Status UpdateRemoteContexts(
continue;
}

std::vector<bool> filtered_device_mask;
ctx->context->FilterDevicesForRemoteWorkers(
remote_worker, base_request.cluster_device_attributes(),
&filtered_device_mask);
DCHECK_EQ(filtered_device_mask.size(), cluster_device_count);

// If any of the devices that match the device filters are in the set of
// added or removed workers, we must send a complete UpdateContextRequest.
// Otherwise, only send a simple request to increment context view ID.
std::vector<bool> added_or_removed_filtered_devices(cluster_device_count);
std::transform(device_added_or_removed.begin(),
device_added_or_removed.end(), filtered_device_mask.begin(),
added_or_removed_filtered_devices.begin(),
std::logical_and<bool>());
const bool full_update_request =
std::accumulate(added_or_removed_filtered_devices.begin(),
added_or_removed_filtered_devices.end(), false,
std::logical_or<bool>());

tensorflow::eager::UpdateContextRequest request;
auto* response = new tensorflow::eager::UpdateContextResponse();

*request.mutable_server_def() = server_def;
request.mutable_server_def()->set_job_name(parsed_name.job);
request.mutable_server_def()->set_task_index(parsed_name.task);
for (const auto& da : base_request.cluster_device_attributes()) {
*request.add_cluster_device_attributes() = da;
}
request.set_context_id(context_id);
request.set_context_view_id(context_view_id);
if (full_update_request) {
*request.mutable_server_def() = server_def;
request.mutable_server_def()->set_job_name(parsed_name.job);
request.mutable_server_def()->set_task_index(parsed_name.task);
request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
server_def.default_session_config());
for (int i = 0; i < cluster_device_count; i++) {
if (filtered_device_mask[i]) {
const auto& da = base_request.cluster_device_attributes(i);
*request.add_cluster_device_attributes() = da;
}
}
}

eager_client->UpdateContextAsync(
&request, response,
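The UpdateRemoteContexts hunk above decides per worker whether a complete UpdateContextRequest is needed: it ANDs the per-device "task added or removed" mask with the worker's device-filter mask, then OR-reduces the result. A self-contained sketch of just that mask arithmetic (the sample inputs are made up):

#include <algorithm>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // Per-device flag: does the device belong to an added or removed task?
  std::vector<bool> device_added_or_removed = {true, false, false, true};
  // Per-device flag: does the device pass this worker's device filters?
  std::vector<bool> filtered_device_mask = {false, true, false, true};

  // Elementwise AND: devices that are both visible to this worker and changed.
  std::vector<bool> changed_and_visible(device_added_or_removed.size());
  std::transform(device_added_or_removed.begin(), device_added_or_removed.end(),
                 filtered_device_mask.begin(), changed_and_visible.begin(),
                 std::logical_and<bool>());

  // OR-reduction: any such device forces a full update request.
  const bool full_update_request =
      std::accumulate(changed_and_visible.begin(), changed_and_visible.end(),
                      false, std::logical_or<bool>());
  std::cout << (full_update_request ? "full update" : "view-id bump only")
            << "\n";
  return 0;
}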
@@ -409,6 +471,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(

// New server created for new server_def. Unused if updating server_def.
std::unique_ptr<tensorflow::ServerInterface> new_server;
tensorflow::EagerContext* context = ctx->context;
tensorflow::GrpcServer* grpc_server;
if (reset_context) {
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
@@ -416,26 +479,25 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
LOG_AND_RETURN_IF_ERROR(
ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
} else {
LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(
ctx->context->GetServer(), worker_name, &curr_remote_workers));
LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name,
&curr_remote_workers));
// No need to check the cast here, since `ListRemoteWorkers` already checks
// if the server is a GRPC server or not.
grpc_server =
dynamic_cast<tensorflow::GrpcServer*>(ctx->context->GetServer());
grpc_server = dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
LOG_AND_RETURN_IF_ERROR(
ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
}

tensorflow::uint64 context_id = ctx->context->GetContextId();
tensorflow::uint64 context_view_id = ctx->context->GetContextViewId();
tensorflow::uint64 context_id = context->GetContextId();
tensorflow::uint64 context_view_id = context->GetContextViewId();
if (reset_context) {
context_id = tensorflow::EagerContext::NewContextId();
context_view_id = 0;
// Make master eager context accessible by local eager service, which might
// receive send tensor requests from remote workers.
LOG_AND_RETURN_IF_ERROR(grpc_server->AddMasterEagerContextToEagerService(
context_id, ctx->context));
LOG_AND_RETURN_IF_ERROR(
grpc_server->AddMasterEagerContextToEagerService(context_id, context));
}

std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
@@ -464,11 +526,11 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
&new_remote_device_mgr));
remote_device_mgr = new_remote_device_mgr.get();
} else {
ctx->context->ClearCachesAndDefaultExecutor();
context->ClearCachesAndDefaultExecutor();
// TODO(b/143914772): Potential memory leak if rendezvous has pending
// tensors for removed / replaced workers.

remote_device_mgr = ctx->context->GetOwnedRemoteDeviceMgr();
remote_device_mgr = context->GetOwnedRemoteDeviceMgr();
if (remote_device_mgr == nullptr) {
LOG_AND_RETURN_IF_ERROR(tensorflow::errors::InvalidArgument(
"Updating context with an invalid set of remote devices."));
@@ -479,8 +541,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
&added_workers, &removed_workers,
&existing_workers);
LOG_AND_RETURN_IF_ERROR(GetReplacedFromExistingWorkers(
&existing_workers, context_id, ctx->context->GetContextViewId(),
server_def, remote_eager_workers.get(), &replaced_workers));
&existing_workers, context_id, context->GetContextViewId(), server_def,
remote_eager_workers.get(), &replaced_workers));
if (VLOG_IS_ON(1)) {
VLOG(1) << "Updating cluster with following changes";
for (const string& w : added_workers) VLOG(1) << " Added worker " << w;
@@ -516,7 +578,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
grpc_server->worker_env()->device_mgr->ListDeviceAttributes(
&local_device_attributes);

// This request make sure that we can create Rendevzous properly between
// This request make sure that we can create Rendezvous properly between
// Local and Remote context.
tensorflow::eager::CreateContextRequest base_request;
for (const auto& da : cluster_device_attributes) {
@@ -525,18 +587,14 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
for (const auto& da : local_device_attributes) {
*base_request.add_cluster_device_attributes() = da;
}
base_request.mutable_server_def()
->mutable_default_session_config()
->MergeFrom(server_def.default_session_config());

// Initialize remote eager workers.
// TODO(b/138847548) Create remote eager contexts in async mode by default.
if (reset_context) {
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
remote_workers, context_id, context_view_id, keep_alive_secs,
server_def, remote_eager_workers.get(),
ctx->context->Executor().Async(),
ctx->context->LazyCopyFunctionRemoteInputs(), base_request));
ctx, remote_workers, context_id, context_view_id, keep_alive_secs,
server_def, remote_eager_workers.get(), context->Executor().Async(),
context->LazyCopyFunctionRemoteInputs(), base_request));
} else {
// The master's context_view_id will be incremented by one
// the UpdateRemoteMaster call later. We want all new workers and
@@ -544,10 +602,9 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
// we must set their context_view_id to the existing master's
// context_view_id + 1.
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
added_workers, context_id, context_view_id + 1, keep_alive_secs,
server_def, remote_eager_workers.get(),
ctx->context->Executor().Async(),
ctx->context->LazyCopyFunctionRemoteInputs(), base_request));
ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs,
server_def, remote_eager_workers.get(), context->Executor().Async(),
context->LazyCopyFunctionRemoteInputs(), base_request));
if (!existing_workers.empty()) {
if (VLOG_IS_ON(1)) {
for (const string& w : existing_workers) {
@@ -555,8 +612,9 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
}
}
LOG_AND_RETURN_IF_ERROR(UpdateRemoteContexts(
existing_workers, context_id, context_view_id + 1, server_def,
remote_eager_workers.get(), base_request));
ctx, existing_workers, added_workers, removed_workers, context_id,
context_view_id + 1, server_def, remote_eager_workers.get(),
base_request));
}
}
@@ -578,12 +636,12 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));

tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
tensorflow::eager::CreateClusterFLR(context_id, ctx->context,
tensorflow::eager::CreateClusterFLR(context_id, context,
worker_session.get());
auto remote_mgr = absl::make_unique<tensorflow::eager::RemoteMgr>(
/*is_master=*/true, ctx->context);
/*is_master=*/true, context);

LOG_AND_RETURN_IF_ERROR(ctx->context->InitializeRemoteMaster(
LOG_AND_RETURN_IF_ERROR(context->InitializeRemoteMaster(
std::move(new_server), grpc_server->worker_env(), worker_session,
std::move(remote_eager_workers), std::move(new_remote_device_mgr),
remote_workers, context_id, r, device_mgr, keep_alive_secs, cluster_flr,
@@ -601,9 +659,9 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
session_name, &worker_session));
tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
tensorflow::eager::CreateClusterFLR(context_id, ctx->context,
tensorflow::eager::CreateClusterFLR(context_id, context,
worker_session.get());
LOG_AND_RETURN_IF_ERROR(ctx->context->UpdateRemoteMaster(
LOG_AND_RETURN_IF_ERROR(context->UpdateRemoteMaster(
grpc_server->worker_env(), std::move(remote_eager_workers),
added_workers, removed_workers, context_id, r, device_mgr,
keep_alive_secs, cluster_flr));
@@ -614,76 +672,6 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
}
#endif // !IS_MOBILE_PLATFORM

tensorflow::Status OpInferSingleInputAttrs(TFE_Op* op,
TFE_TensorHandle* input) {
TFE_OpInferenceContext* ictx = op->inference_ctx.get();
const auto& input_def = ictx->op_def->input_arg(ictx->input_arg_idx++);
if (!input_def.number_attr().empty() || !input_def.type_list_attr().empty()) {
// Some clients that are still setting their input attributes manually are
// adding input list to their op by calling `TFE_OpAddInput` for each of
// its elements instead of calling `TFE_OpAddInputList`. When this happens,
// we cannot detect the end of such list, thus lose track of the input
// arguments in the op definition. To guarantee backward compatibility with
// those clients, disable automatic inference in this case.
op->inference_ctx.reset(nullptr);
return tensorflow::Status::OK();
}
const std::string& type_attr = input_def.type_attr();
if (!type_attr.empty() && ictx->attrs.find(type_attr) == ictx->attrs.end()) {
op->operation.MutableAttrs()->Set(type_attr, input->handle->dtype);
ictx->attrs.insert(type_attr);
}
return tensorflow::Status::OK();
}

void OpInferSingleTypeInputListAttrs(TFE_Op* op,
const tensorflow::OpDef::ArgDef& input_def,
const tensorflow::DataType dtype,
int num_inputs) {
TFE_OpInferenceContext* ictx = op->inference_ctx.get();
if (ictx->attrs.find(input_def.number_attr()) == ictx->attrs.end()) {
op->operation.MutableAttrs()->Set(input_def.number_attr(), num_inputs);
ictx->attrs.insert(input_def.number_attr());
}
if (ictx->attrs.find(input_def.type_attr()) == ictx->attrs.end()) {
op->operation.MutableAttrs()->Set(input_def.type_attr(), dtype);
ictx->attrs.insert(input_def.type_attr());
}
}

void OpInferMixedTypeInputListAttrs(
TFE_Op* op, const tensorflow::OpDef::ArgDef& input_def,
const std::vector<tensorflow::DataType>& dtypes) {
TFE_OpInferenceContext* ictx = op->inference_ctx.get();
if (ictx->attrs.find(input_def.type_list_attr()) == ictx->attrs.end()) {
op->operation.MutableAttrs()->Set(
input_def.type_list_attr(),
tensorflow::gtl::ArraySlice<const tensorflow::DataType>(dtypes.data(),
dtypes.size()));
ictx->attrs.insert(input_def.type_list_attr());
}
}

tensorflow::Status OpInferInputListAttrs(TFE_Op* op, TFE_TensorHandle** inputs,
int num_inputs) {
TFE_OpInferenceContext* ictx = op->inference_ctx.get();
const auto& input_def = ictx->op_def->input_arg(ictx->input_arg_idx++);
if (!input_def.type_list_attr().empty()) {
std::vector<tensorflow::DataType> dtypes(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
dtypes[i] = inputs[i]->handle->dtype;
}
OpInferMixedTypeInputListAttrs(op, input_def, dtypes);
} else if (!input_def.type_attr().empty() &&
!input_def.number_attr().empty()) {
OpInferSingleTypeInputListAttrs(op, input_def, inputs[0]->handle->dtype,
num_inputs);
} else {
return tensorflow::errors::InvalidArgument("Invalid input list definition");
}
return tensorflow::Status::OK();
}

} // namespace
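The four OpInfer* helpers deleted above (this merge moves their logic behind EagerOperation methods such as MaybeInferSingleInputAttrs and InferInputListAttrs) share one invariant: an attribute inferred from an earlier input must never be overwritten by a later one, which they enforce by tracking already-inferred attribute names in a set. A toy sketch of that guard, with a plain map standing in for the op's attribute storage (all names here are illustrative, not TensorFlow API):

#include <iostream>
#include <map>
#include <set>
#include <string>

std::map<std::string, int> attrs_storage;   // toy stand-in for op attrs
std::set<std::string> inferred_attrs;       // mirrors ictx->attrs above

// Set `attr` to `dtype` only the first time it is seen, mirroring the
// "skip if already in ictx->attrs" pattern in the removed helpers.
void MaybeSetTypeAttr(const std::string& attr, int dtype) {
  if (attr.empty() || inferred_attrs.count(attr) > 0) return;
  attrs_storage[attr] = dtype;
  inferred_attrs.insert(attr);
}

int main() {
  MaybeSetTypeAttr("T", /*dtype=*/1);  // first input sets T
  MaybeSetTypeAttr("T", /*dtype=*/2);  // later inputs must not override it
  std::cout << "T = " << attrs_storage["T"] << "\n";  // prints 1
  return 0;
}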
extern "C" {
@@ -719,12 +707,14 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr.get());

return new TFE_Context(opts->session_options.options,
opts->device_placement_policy, opts->mirroring_policy,
opts->async, opts->lazy_remote_inputs_copy,
device_mgr.release(),
/*device_mgr_owned*/ true, r,
tensorflow::GetDefaultCustomKernelCreator());
return new TFE_Context{new tensorflow::EagerContext(
opts->session_options.options,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
opts->device_placement_policy),
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
/*device_mgr_owned*/ true, r,
tensorflow::GetDefaultCustomKernelCreator())};
}

TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
@@ -735,22 +725,28 @@ TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr);

return new TFE_Context(opts->session_options.options,
opts->device_placement_policy, opts->mirroring_policy,
opts->async, opts->lazy_remote_inputs_copy, device_mgr,
/*device_mgr_owned*/ false, r,
tensorflow::GetDefaultCustomKernelCreator());
return new TFE_Context{new tensorflow::EagerContext(
opts->session_options.options,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
opts->device_placement_policy),
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
opts->async, opts->lazy_remote_inputs_copy, device_mgr,
/*device_mgr_owned*/ false, r,
tensorflow::GetDefaultCustomKernelCreator())};
}

void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; }
void TFE_DeleteContext(TFE_Context* ctx) {
// context->RefCountIsOne() should be true here.
// TODO(iga): Remove EagerContext refcounting.
ctx->context->Unref();

delete ctx;
}

TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
TF_DeviceList* list = new TF_DeviceList;
ctx->context->local_device_mgr()->ListDeviceAttributes(&list->response);
if (ctx->context->remote_device_mgr()) {
ctx->context->remote_device_mgr()->ListDeviceAttributes(&list->response);
}
return list;
TF_DeviceList* l = new TF_DeviceList;
ctx->context->ListDevices(&l->response);
return l;
}

void TFE_ContextClearCaches(TFE_Context* ctx) {
@@ -773,6 +769,22 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
"Invalid tensorflow.ServerDef protocol buffer");
return;
}
if (server_def.has_cluster_device_filters()) {
const auto& cdf = server_def.cluster_device_filters();
for (const auto& jdf : cdf.jobs()) {
const string& remote_prefix = "/job:" + jdf.name() + "/task:";
for (const auto& tdf : jdf.tasks()) {
const int32_t task_index = tdf.first;
std::vector<string> device_filters(tdf.second.device_filters_size());
for (int i = 0; i < tdf.second.device_filters_size(); i++) {
device_filters[i] = tdf.second.device_filters(i);
}
const string remote_worker = remote_prefix + std::to_string(task_index);
status->status =
ctx->context->SetRemoteDeviceFilters(remote_worker, device_filters);
}
}
}
status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
ctx, /*reset_context=*/true);
#endif // !IS_MOBILE_PLATFORM
@@ -797,6 +809,11 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
status->status = tensorflow::errors::InvalidArgument(
"Trying to update a context with invalid context id.");
}
if (server_def.has_cluster_device_filters()) {
LOG(WARNING) << "Device filters can only be specified when initializing "
"the cluster. Any changes in device filters are ignored "
"when updating the server def.";
}
// TODO(haoyuzhang): Check server_def compatibility before the update
status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
ctx, /*reset_context=*/false);
@@ -811,8 +828,9 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
"TFE_ContextSetServerDef not supported on mobile");
return false;
#else // !defined(IS_MOBILE_PLATFORM)
tensorflow::EagerContext* context = ctx->context;
tensorflow::GrpcServer* grpc_server =
static_cast<tensorflow::GrpcServer*>(ctx->context->GetServer());
static_cast<tensorflow::GrpcServer*>(context->GetServer());

std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
status->status = grpc_server->master_env()->worker_cache->GetEagerClientCache(
@@ -831,7 +849,7 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,

// Send a rpc request to the worker to check aliveness.
tensorflow::eager::KeepAliveRequest request;
request.set_context_id(ctx->context->GetContextId());
request.set_context_id(context->GetContextId());
tensorflow::eager::KeepAliveResponse response;

tensorflow::Status keep_alive_status;
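TFE_ContextSetServerDef above flattens the ServerDef's cluster_device_filters job by job and task by task into per-worker filter lists keyed "/job:<name>/task:<index>". A protobuf-free sketch of that flattening; the struct names are hypothetical stand-ins for the proto messages:

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Hypothetical plain-struct mirrors of the ClusterDeviceFilters messages.
struct TaskDeviceFilters { std::vector<std::string> device_filters; };
struct JobDeviceFilters {
  std::string name;
  std::map<int, TaskDeviceFilters> tasks;  // task index -> filters
};

int main() {
  JobDeviceFilters job;
  job.name = "worker";
  job.tasks[0].device_filters = {"/job:ps"};
  job.tasks[1].device_filters = {"/job:ps", "/job:worker/task:1"};
  std::vector<JobDeviceFilters> jobs = {job};

  // Flatten to one entry per remote task, keyed the way the C API does.
  std::map<std::string, std::vector<std::string>> per_worker_filters;
  for (const auto& jdf : jobs) {
    const std::string remote_prefix = "/job:" + jdf.name + "/task:";
    for (const auto& tdf : jdf.tasks) {
      per_worker_filters[remote_prefix + std::to_string(tdf.first)] =
          tdf.second.device_filters;
    }
  }
  for (const auto& kv : per_worker_filters)
    std::cout << kv.first << " has " << kv.second.size() << " filter(s)\n";
  return 0;
}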
@@ -886,138 +904,212 @@ void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
if (h == nullptr) return;
tensorflow::profiler::TraceMe activity(
"TFE_DeleteTensorHandle", tensorflow::profiler::TraceMeLevel::kInfo);
VLOG(1) << "Deleting tensor handle " << h << " with internal handle "
<< h->handle;
if (h->handle) {
h->handle->Unref();
}
delete h;
}

tensorflow::TensorHandleInterface::~TensorHandleInterface() {
VLOG(1) << "Deleting tensor handle " << this << " with internal handle "
<< handle_;
if (handle_) {
handle_->Unref();
}
}

bool tensorflow::TensorHandleInterface::IsValid(Status* status) const {
if (handle_ == nullptr) {
*status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return false;
}

return true;
}

TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
return static_cast<TF_DataType>(h->handle->dtype);
return h->handle->DataType();
}

TF_DataType tensorflow::TensorHandleInterface::DataType() const {
return static_cast<TF_DataType>(handle_->dtype);
}

int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return -1;
}

return h->handle->NumDims(&status->status);
}

int tensorflow::TensorHandleInterface::NumDims(Status* status) const {
if (!IsValid(status)) {
return -1;
}

int result;
status->status = h->handle->NumDims(&result);
*status = handle_->NumDims(&result);
return result;
}

int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return -1;
}

return h->handle->NumElements(&status->status);
}

int64_t tensorflow::TensorHandleInterface::NumElements(Status* status) const {
if (!IsValid(status)) {
return -1;
}

tensorflow::int64 result;
status->status = h->handle->NumElements(&result);
*status = handle_->NumElements(&result);
return result;
}

int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return -1;
}

return h->handle->Dim(dim_index, &status->status);
}

int64_t tensorflow::TensorHandleInterface::Dim(int dim_index,
Status* status) const {
if (!IsValid(status)) {
return -1;
}

tensorflow::int64 result;
status->status = h->handle->Dim(dim_index, &result);
*status = handle_->Dim(dim_index, &result);
return result;
}

const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
tensorflow::Device* d = h->handle->op_device();
return h->handle->DeviceName(&status->status);
}

const char* tensorflow::TensorHandleInterface::DeviceName(
Status* status) const {
if (!IsValid(status)) {
return nullptr;
}
tensorflow::Device* d = handle_->op_device();
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
: d->name().c_str();
}

const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
tensorflow::Device* d = h->handle->device();
return h->handle->BackingDeviceName(&status->status);
}

const char* tensorflow::TensorHandleInterface::BackingDeviceName(
Status* status) const {
if (!IsValid(status)) {
return nullptr;
}
tensorflow::Device* d = handle_->device();
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
: d->name().c_str();
}

TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr || !h->handle->IsValid(&status->status)) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}

h->handle->Ref();
return new TFE_TensorHandle{
std::unique_ptr<AbstractTensorHandleInterface>(h->handle->Copy())};
}

return new TFE_TensorHandle(h->handle);
AbstractTensorHandleInterface* tensorflow::TensorHandleInterface::Copy() {
handle_->Ref();
return new TensorHandleInterface(handle_);
}

TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
tensorflow::TensorHandle* handle = h->handle;

return h->handle->Resolve(&status->status);
}

TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
if (!IsValid(status)) {
return nullptr;
}

// TODO(agarwal): move this implementation inside TFE_TensorHandle.
if (handle->IsRemote()) {
if (handle_->IsRemote()) {
const tensorflow::Tensor* t = nullptr;
tensorflow::TensorHandle* h_cpu = nullptr;
status->status = EagerCopyToDevice(
handle, handle->Context(), &handle->Context()->Executor(),
handle->Context()->HostCPU(), false, &h_cpu);
if (!status->status.ok()) {
*status = EagerCopyToDevice(handle_, handle_->Context(),
&handle_->Context()->Executor(),
handle_->Context()->HostCPU(), false, &h_cpu);
if (!status->ok()) {
return nullptr;
}
status->status = h_cpu->Tensor(&t);
if (!status->status.ok()) {
*status = h_cpu->Tensor(&t);
if (!status->ok()) {
h_cpu->Unref();
return nullptr;
}
TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, &status->status);
TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, status);
h_cpu->Unref();
return retval;
} else {
tensorflow::Tensor tensor;
if (IsCPU(handle->device())) {
if (IsCPU(handle_->device())) {
const tensorflow::Tensor* src = nullptr;
status->status = handle->Tensor(&src);
if (!status->status.ok()) return nullptr;
*status = handle_->Tensor(&src);
if (!status->ok()) return nullptr;
tensor = *src;
} else {
tensorflow::EagerContext* ctx = handle->Context();
tensorflow::EagerContext* ctx = handle_->Context();
CHECK_NE(ctx, nullptr);
status->status = h->handle->CopyToDevice(ctx, ctx->HostCPU(), &tensor);
if (!status->status.ok()) return nullptr;
*status = handle_->CopyToDevice(*ctx, ctx->HostCPU(), &tensor);
if (!status->ok()) return nullptr;
}
return tensorflow::TF_TensorFromTensor(tensor, &status->status);
return tensorflow::TF_TensorFromTensor(tensor, status);
}
}
void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr || !h->handle->IsValid(&status->status)) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
tensorflow::TensorHandle* handle = h->handle;
tensorflow::TensorHandle* handle =
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle();

if (handle->IsRemote()) {
status->status = tensorflow::errors::InvalidArgument(
@@ -1046,7 +1138,8 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
void (*deallocator)(void* data, size_t len, void* arg),
void* deallocator_arg, TF_Status* status) {
tensorflow::Device* device;
status->status = ctx->context->FindDeviceFromName(device_name, &device);
tensorflow::EagerContext* context = ctx->context;
status->status = context->FindDeviceFromName(device_name, &device);
if (!status->status.ok()) {
deallocator(data, len, deallocator_arg);
return nullptr;
@@ -1074,11 +1167,12 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
buf->Unref();
tensorflow::TensorHandle* ret_handle;
status->status = tensorflow::TensorHandle::CreateLocalHandle(
t, device, ctx->context, &ret_handle);
t, device, context, &ret_handle);
if (!status->status.ok()) {
return nullptr;
}
return new TFE_TensorHandle(ret_handle);
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(ret_handle)};
}

// This function will block till the operation that produces `h` has
@@ -1086,12 +1180,14 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
// bytes of the memory pointed to by the device pointer returned above.
size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr || !h->handle->IsValid(&status->status)) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return 0;
}
tensorflow::TensorHandle* handle = h->handle;
tensorflow::TensorHandle* handle =
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle();

if (handle->IsRemote()) {
status->status = tensorflow::errors::InvalidArgument(
@@ -1109,8 +1205,14 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,

TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status) {
return NewOrResetOp(ctx, op_or_function_name, nullptr, status,
/* op_to_reset= */ nullptr);
std::unique_ptr<TFE_Op> new_op(
new TFE_Op{tensorflow::EagerOperation(ctx->context)});
status->status =
new_op->operation.Reset(op_or_function_name, nullptr, false, nullptr);
if (!status->status.ok()) {
new_op.reset();
}
return new_op.release();
}
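The new TFE_NewOp above builds the op inside a std::unique_ptr and only releases it to the caller once Reset() succeeds, so every error path frees the allocation automatically. A generic sketch of this construct-then-release idiom with a toy Op type (not the TensorFlow classes):

#include <iostream>
#include <memory>

struct Op {
  bool Reset(const char* name) { return name != nullptr; }  // toy validation
};

// Returns a heap-allocated Op on success, nullptr on failure; the
// unique_ptr guarantees cleanup on every early-error path.
Op* NewOp(const char* name) {
  std::unique_ptr<Op> new_op(new Op);
  if (!new_op->Reset(name)) {
    return nullptr;  // new_op is freed automatically here
  }
  return new_op.release();  // hand ownership to the caller
}

int main() {
  std::unique_ptr<Op> op(NewOp("MatMul"));
  std::cout << (op ? "created" : "failed") << "\n";
  return 0;
}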
void TFE_DeleteOp(TFE_Op* op) { delete op; }
@@ -1121,7 +1223,7 @@ void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {

const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
tensorflow::Device* device = (op->operation.Device() == nullptr)
? op->operation.EagerContext()->HostCPU()
? op->operation.EagerContext().HostCPU()
: op->operation.Device();
return device->name().c_str();
}
@@ -1135,20 +1237,23 @@ void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
}

void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
op->operation.AddInput(input->handle);
if (op->inference_ctx) {
status->status = OpInferSingleInputAttrs(op, input);
}
tensorflow::TensorHandle* h =
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
input->handle.get())
->Handle();
op->operation.AddInput(h);
status->status = op->operation.MaybeInferSingleInputAttrs(h);
}

void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
TF_Status* status) {
for (int i = 0; i < num_inputs; ++i) {
op->operation.AddInput(inputs[i]->handle);
}
if (op->inference_ctx) {
status->status = OpInferInputListAttrs(op, inputs, num_inputs);
op->operation.AddInput(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
inputs[i]->handle.get())
->Handle());
}
status->status = op->operation.InferInputListAttrs(num_inputs);
}

TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
@@ -1381,15 +1486,16 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,

void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
VLOG(1) << "Calling TFE_Execute() on op " << op;
absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
VLOG(1) << "Calling TFE_Execute() on op " << op;
status->status = tensorflow::EagerExecute(&op->operation,
handle_retvals.data(), num_retvals);
if (!status->status.ok()) {
return;
}
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = new TFE_TensorHandle(handle_retvals[i]);
retvals[i] = new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle_retvals[i])};
}
}
@ -1399,15 +1505,18 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
|
||||
TF_Status* status) {
|
||||
tensorflow::TensorHandle* handle = nullptr;
|
||||
tensorflow::Device* device;
|
||||
status->status = ctx->context->FindDeviceFromName(device_name, &device);
|
||||
tensorflow::EagerContext* context = ctx->context;
|
||||
status->status = context->FindDeviceFromName(device_name, &device);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context,
|
||||
&ctx->context->Executor(),
|
||||
device, false, &handle);
|
||||
status->status = tensorflow::EagerCopyToDevice(
|
||||
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
|
||||
->Handle(),
|
||||
context, &context->Executor(), device, false, &handle);
|
||||
if (status->status.ok()) {
|
||||
return new TFE_TensorHandle(handle);
|
||||
return new TFE_TensorHandle{
|
||||
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
|
||||
}
|
||||
return nullptr;
|
||||
}
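From the caller's side, the function above is used like this (an illustrative fragment; `h`, `ctx`, and `status` are assumed to exist, and the device name is only an example):

  TFE_TensorHandle* copied = TFE_TensorHandleCopyToDevice(
      h, ctx, "/job:localhost/replica:0/task:0/device:GPU:0", status);
  if (TF_GetCode(status) != TF_OK) {
    // `copied` is nullptr on failure; TF_Message(status) has the details.
  }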
@ -1455,11 +1564,12 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,

void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
                                  TF_Status* status) {
  status->status = ctx->context->Executor().WaitForAllPendingNodes();
  tensorflow::EagerContext* context = ctx->context;
  status->status = context->Executor().WaitForAllPendingNodes();
  if (!status->status.ok()) return;
  tensorflow::mutex_lock ml(*ctx->context->MetadataMu());
  status->status = MessageToBuffer(*ctx->context->RunMetadataProto(), buf);
  ctx->context->ClearRunMetadata();
  tensorflow::mutex_lock ml(*context->MetadataMu());
  status->status = MessageToBuffer(*context->RunMetadataProto(), buf);
  context->ClearRunMetadata();
}

namespace {

@ -213,7 +213,7 @@ TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo(
    TFE_TensorDebugInfo* debug_info);

// Returns the number of dimensions used to represent the tensor on its device.
// The number of dimensions used to reprensent the tensor on device can be
// The number of dimensions used to represent the tensor on device can be
// different from the number returned by TFE_TensorHandleNumDims.
// The return value was current at the time of TFE_TensorDebugInfo creation.
TF_CAPI_EXPORT extern int TFE_TensorDebugInfoOnDeviceNumDims(

@ -28,19 +28,22 @@ using tensorflow::string;

namespace {

std::vector<int64> TensorShapeAsVector(TFE_TensorHandle* handle,
                                       TF_Status* status) {
std::vector<int64> TensorShapeAsVector(const tensorflow::TensorHandle& handle,
                                       tensorflow::Status* status) {
  std::vector<int64> shape;
  int rank = TFE_TensorHandleNumDims(handle, status);
  if (TF_GetCode(status) != TF_OK) {
  int rank = -1;
  *status = handle.NumDims(&rank);
  if (!status->ok()) {
    return shape;
  }
  shape.reserve(rank);
  for (int i = 0; i < rank; ++i) {
    shape.push_back(TFE_TensorHandleDim(handle, i, status));
    if (TF_GetCode(status) != TF_OK) {
    tensorflow::int64 dim;
    *status = handle.Dim(i, &dim);
    if (!status->ok()) {
      return shape;
    }
    shape.push_back(dim);
  }
  return shape;
}
@ -51,14 +54,19 @@ extern "C" {

TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
    TFE_TensorHandle* h, TF_Status* status) {
  return h->handle->TensorDebugInfo(&status->status);
}

TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
    Status* status) {
  const tensorflow::Tensor* tensor;
  status->status = h->handle->Tensor(&tensor);
  if (TF_GetCode(status) != TF_OK) {
  *status = handle_->Tensor(&tensor);
  if (!status->ok()) {
    return nullptr;
  }

#ifdef TENSORFLOW_EAGER_USE_XLA
  tensorflow::Device* device = h->handle->device();
  tensorflow::Device* device = handle_->device();

  // If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
  tensorflow::XlaDevice* xla_device =
@ -67,15 +75,15 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
    tensorflow::XlaDevice::PaddedShapeFn shape_fn =
        xla_device->metadata().padded_shape_fn();
    xla::Shape padded_shape;
    status->status = shape_fn(*tensor, &padded_shape);
    if (!status->status.ok()) {
    *status = shape_fn(*tensor, &padded_shape);
    if (!status->ok()) {
      return nullptr;
    }
    if (VLOG_IS_ON(3)) {
      std::vector<int64> shape_to_log = TensorShapeAsVector(h, status);
      if (!status->status.ok()) {
      std::vector<int64> shape_to_log = TensorShapeAsVector(*handle_, status);
      if (!status->ok()) {
        // Ignore the status here as we are simply logging.
        status->status = tensorflow::Status::OK();
        *status = tensorflow::Status::OK();
      } else {
        VLOG(3) << "Fully padded shape of ["
                << absl::StrJoin(shape_to_log, ", ") << "] is "
@ -88,7 +96,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
      // Currently, the only case of XlaTensor containing a tuple shape is to
      // represent 64 bit ints, doubles, and complex numbers (we don't support
      // 64bit complex numbers).
      status->status = tensorflow::errors::InvalidArgument(
      *status = tensorflow::errors::InvalidArgument(
          "XlaTensors should only contain tuples of size 2. Shape: ",
          padded_shape.DebugString());
      return nullptr;
@ -100,13 +108,13 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
    const xla::Shape& shape1 =
        xla::ShapeUtil::GetTupleElementShape(padded_shape, 1);
    if (shape0.IsTuple() || shape1.IsTuple()) {
      status->status = tensorflow::errors::InvalidArgument(
      *status = tensorflow::errors::InvalidArgument(
          "XlaTensors should not contain nested tuples. Shape: ",
          padded_shape.DebugString());
      return nullptr;
    }
    if (!xla::ShapeUtil::Equal(shape0, shape1)) {
      status->status = tensorflow::errors::InvalidArgument(
      *status = tensorflow::errors::InvalidArgument(
          "Subshapes of XlaTensors should be the same. Shape: ",
          padded_shape.DebugString());
      return nullptr;
@ -131,15 +139,15 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
        dev_dims.push_back(padded_shape.dimensions(dim_index));
      }
    }
    status->status = tensorflow::Status::OK();
    *status = tensorflow::Status::OK();
    return new TFE_TensorDebugInfo(dev_dims);
  }
#endif  // TENSORFLOW_EAGER_USE_XLA

  // If the tensor is not an XLA tensor, the device shape is
  // the same as regular tensor shape.
  std::vector<int64> dev_dims = TensorShapeAsVector(h, status);
  if (TF_GetCode(status) != TF_OK) {
  std::vector<int64> dev_dims = TensorShapeAsVector(*handle_, status);
  if (!status->ok()) {
    return nullptr;
  }
  return new TFE_TensorDebugInfo(dev_dims);

@ -18,22 +18,23 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/profiler/rpc/client/capture_profile.h"
#include "tensorflow/core/profiler/rpc/profiler_server.h"

using tensorflow::string;

void TFE_OpReset(TFE_Context* ctx, const char* op_or_function_name,
                 const char* raw_device_name, TF_Status* status,
                 TFE_Op* op_to_reset) {
void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
                 const char* raw_device_name, TF_Status* status) {
  if (op_to_reset) {
    NewOrResetOp(ctx, op_or_function_name, raw_device_name, status,
                 op_to_reset);
    status->status = op_to_reset->operation.Reset(
        op_or_function_name, raw_device_name, false, nullptr);
  } else {
    TF_SetStatus(status, TF_INVALID_ARGUMENT,
                 "op_to_reset should not be nullptr");
@ -41,7 +42,9 @@ void TFE_OpReset(TFE_Context* ctx, const char* op_or_function_name,
}
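The point of the new signature is cheap op reuse. A sketch of the intended pattern, assuming `ctx`, `input`, and `status` already exist (the op name and device string are placeholders):

  TFE_Op* op = TFE_NewOp(ctx, "Identity", status);
  for (int step = 0; step < 100; ++step) {
    // Same raw_device_name on every pass, so device-name parsing is skipped.
    TFE_OpReset(op, "Identity", "/device:CPU:0", status);
    TFE_OpAddInput(op, input, status);
    int num_retvals = 1;
    TFE_TensorHandle* retval;
    TFE_Execute(op, &retval, &num_retvals, status);
    TFE_DeleteTensorHandle(retval);
  }
  TFE_DeleteOp(op);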
void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
  op->operation.ConsumeInput(h->handle);
  op->operation.ConsumeInput(
      tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
          ->Handle());
}

TFE_Profiler* TFE_NewProfiler() { return new TFE_Profiler(); }
@ -85,14 +88,14 @@ bool TFE_ProfilerClientStartTracing(const char* service_addr,
                                    int num_tracing_attempts,
                                    TF_Status* status) {
  tensorflow::Status s =
      tensorflow::profiler::client::ValidateHostPortPair(service_addr);
      tensorflow::profiler::ValidateHostPortPair(service_addr);
  if (!s.ok()) {
    Set_TF_Status_from_Status(status, s);
    return false;
  }
  s = tensorflow::profiler::client::StartTracing(
      service_addr, logdir, worker_list, include_dataset_ops, duration_ms,
      num_tracing_attempts);
  s = tensorflow::profiler::Trace(service_addr, logdir, worker_list,
                                  include_dataset_ops, duration_ms,
                                  num_tracing_attempts);
  tensorflow::Set_TF_Status_from_Status(status, s);
  return s.ok();
}
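A hedged sketch of driving this tracing client; the argument order mirrors the parameters visible in this hunk, while the address, log directory, and durations are made-up values:

  TF_Status* status = TF_NewStatus();
  bool started = TFE_ProfilerClientStartTracing(
      "localhost:6009",  // service_addr of an already-running profiler server
      "/tmp/profile",    // logdir that receives the trace
      "",                // worker_list; empty traces only service_addr
      /*include_dataset_ops=*/true,
      /*duration_ms=*/2000,
      /*num_tracing_attempts=*/3, status);
  if (!started) {
    // TF_Message(status) carries the error set via Set_TF_Status_from_Status.
  }
  TF_DeleteStatus(status);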
@ -101,14 +104,14 @@ void TFE_ProfilerClientMonitor(const char* service_addr, int duration_ms,
                               int monitoring_level, bool display_timestamp,
                               TF_Buffer* result, TF_Status* status) {
  tensorflow::Status s =
      tensorflow::profiler::client::ValidateHostPortPair(service_addr);
      tensorflow::profiler::ValidateHostPortPair(service_addr);
  if (!s.ok()) {
    Set_TF_Status_from_Status(status, s);
    return;
  }
  string content;
  s = tensorflow::profiler::client::Monitor(
      service_addr, duration_ms, monitoring_level, display_timestamp, &content);
  s = tensorflow::profiler::Monitor(service_addr, duration_ms, monitoring_level,
                                    display_timestamp, &content);
  void* data = tensorflow::port::Malloc(content.length());
  content.copy(static_cast<char*>(data), content.length(), 0);
  result->data = data;
@ -616,3 +619,16 @@ void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
  return new TFE_Executor(&ctx->context->Executor());
}

void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
  auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
      ctx->context->HostCPU()->parsed_name());
  auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
  void* data = tensorflow::port::Malloc(str.length());
  str.copy(static_cast<char*>(data), str.length(), 0);
  buf->data = data;
  buf->length = str.length();
  buf->data_deallocator = [](void* data, size_t length) {
    tensorflow::port::Free(data);
  };
}
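Since the buffer installs its own data_deallocator, a caller only has to delete the buffer. An illustrative fragment (with `ctx` assumed to exist):

  TF_Buffer* buf = TF_NewBuffer();
  TFE_HostAddressSpace(ctx, buf);
  std::string address_space(static_cast<const char*>(buf->data), buf->length);
  // address_space now holds something like "/job:localhost/replica:0/task:0".
  TF_DeleteBuffer(buf);  // runs the data_deallocator installed above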
@ -29,10 +29,10 @@ extern "C" {
// and set the device name. It's effectively `TFE_OpSetDevice`, but it is faster
// than separately calling it because if the existing op has the same
// `raw_device_name`, it skips parsing and just leaves it as is.
TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Context* ctx,
TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Op* op_to_reset,
                                       const char* op_or_function_name,
                                       const char* raw_device_name,
                                       TF_Status* status, TFE_Op* op_to_reset);
                                       TF_Status* status);

TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,
                                              TF_Status* status);
@ -458,6 +458,11 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
    void (*deallocator)(void* data, size_t len, void* arg),
    void* deallocator_arg, TF_Status* status);

// Retrieves the address space (i.e. job, replica, task) of the local host and
// saves it in the buffer.
TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
                                                TF_Buffer* buf);

#ifdef __cplusplus
} /* end extern "C" */
#endif
@ -1,66 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_internal.h"

#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/host_info.h"

TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
                     const char* raw_device_name, TF_Status* status,
                     TFE_Op* op_to_reset) {
  const char* name = op_or_function_name;  // Shorthand
  const tensorflow::AttrTypeMap* types;
  bool is_function = false;
  status->status = tensorflow::AttrTypeMapForOp(name, &types, &is_function);
  if (!status->status.ok()) {
    return nullptr;
  }

  if (op_to_reset && op_to_reset->ctx != ctx) {
    status->status = tensorflow::errors::Internal(
        "Cannot reset a TFE_Op from another TFE_Context");
    return nullptr;
  }

  std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
  if (!is_function) {
    const tensorflow::OpDef* op_def;
    status->status = tensorflow::OpDefForOp(op_or_function_name, &op_def);
    if (!status->status.ok()) {
      return nullptr;
    }
    inference_ctx.reset(new TFE_OpInferenceContext(op_def));
  } else if (!ctx->context->FindFunctionByName(name)) {
    status->status = tensorflow::errors::NotFound(
        "'", name,
        "' is neither a type of a primitive operation nor a name "
        "of a function registered in binary running on ",
        tensorflow::port::Hostname(),
        ". Make sure the operation or function is "
        "registered in the binary running in this process.");
    return nullptr;
  }

  if (op_to_reset) {
    status->status = op_to_reset->Reset(
        name, is_function, types, raw_device_name, std::move(inference_ctx));
    return op_to_reset;
  }

  TFE_Op* new_op =
      new TFE_Op(ctx, name, is_function, types, std::move(inference_ctx));
  status->status = new_op->operation.SetDeviceName(raw_device_name);
  return new_op;
}

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
@ -62,36 +63,10 @@ struct TFE_ContextOptions {
};

struct TFE_Context {
  TFE_Context(const tensorflow::SessionOptions& opts,
              TFE_ContextDevicePlacementPolicy default_device_placement_policy,
              TFE_ContextMirroringPolicy default_mirroring_policy, bool async,
              const bool lazy_remote_inputs_copy,
              const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned,
              tensorflow::Rendezvous* rendezvous,
              const tensorflow::CustomKernelCreator* custom_kernel_creator)
      : context(new tensorflow::EagerContext(
            opts,
            static_cast<tensorflow::ContextDevicePlacementPolicy>(
                default_device_placement_policy),
            static_cast<tensorflow::ContextMirroringPolicy>(
                default_mirroring_policy),
            async, lazy_remote_inputs_copy, device_mgr, device_mgr_owned,
            rendezvous, custom_kernel_creator)) {}

  ~TFE_Context() {
    // TODO(iga): Add a separate API method to shutdown TFE_Context so that we
    // don't send RPCs and block in destructor.
    context->WaitForAndCloseRemoteContexts();
    // context->RefCountIsOne() should be true here.
    // TODO(iga): Remove EagerContext refcounting.
    context->Unref();
  }

  tensorflow::EagerContext* context;
};

struct TFE_TensorHandle {
  explicit TFE_TensorHandle(tensorflow::TensorHandle* h) : handle(h) {}
  static TFE_TensorHandle* CreateLocalHandle(const class tensorflow::Tensor& t,
                                             TF_Status* s) {
    tensorflow::TensorHandle* handle;
@ -99,10 +74,11 @@ struct TFE_TensorHandle {
    if (!s->status.ok()) {
      return nullptr;
    }
    return new TFE_TensorHandle(handle);
    return new TFE_TensorHandle{
        std::make_unique<tensorflow::TensorHandleInterface>(handle)};
  }

  tensorflow::TensorHandle* handle;
  std::unique_ptr<AbstractTensorHandleInterface> handle;
};

struct TFE_TensorDebugInfo {
@ -113,46 +89,10 @@ struct TFE_TensorDebugInfo {
  std::vector<tensorflow::int64> dev_dims;
};

struct TFE_OpInferenceContext {
  explicit TFE_OpInferenceContext(const tensorflow::OpDef* op_def)
      : op_def(op_def) {}

  const tensorflow::OpDef* op_def;  // op definition from protobuf
  int input_arg_idx = 0;  // arg definition index for the next input to be added
  tensorflow::gtl::FlatSet<std::string> attrs;  // attributes inferred so far
};

struct TFE_Op {
  TFE_Op(TFE_Context* ctx, const char* op, bool is_function,
         const tensorflow::AttrTypeMap* t,
         std::unique_ptr<TFE_OpInferenceContext> inference_ctx)
      : ctx(ctx),
        operation(ctx->context, op, is_function, t),
        inference_ctx(std::move(inference_ctx)) {}

  void Clear() {
    operation.Clear();
    inference_ctx.reset();
  }

  tensorflow::Status Reset(const char* op, bool is_function,
                           const tensorflow::AttrTypeMap* t,
                           const char* raw_device_name,
                           std::unique_ptr<TFE_OpInferenceContext> infer_ctx) {
    inference_ctx = std::move(infer_ctx);
    return operation.Reset(ctx->context, op, is_function, t, raw_device_name,
                           nullptr);
  }

  TFE_Context* ctx;
  tensorflow::EagerOperation operation;
  std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
};

TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
                     const char* raw_device_name, TF_Status* status,
                     TFE_Op* op_to_reset = nullptr);

struct TFE_Profiler {
  explicit TFE_Profiler() { profiler = tensorflow::ProfilerSession::Create(); }

@ -1362,10 +1362,11 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) {
  TFE_TensorHandle* inputs[] = {input1, input2};
  TFE_OpAddInput(concatOp, dim, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  CHECK(concatOp->inference_ctx);
  CHECK(concatOp->operation.OpDef());
  TFE_OpAddInput(concatOp, inputs[0], status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  EXPECT_FALSE(concatOp->inference_ctx) << "Inference context is still present";
  EXPECT_FALSE(concatOp->operation.OpDef())
      << "Inference context is still present";
  TFE_OpAddInput(concatOp, inputs[1], status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

@ -284,7 +284,7 @@ class ForwardAccumulator {
  // Temporarily push or pop transient state for this accumulator.
  //
  // Allows an accumulator which is currently processing an operation to
  // temporarily reset its state. Without pushing and poping, accumulators
  // temporarily reset its state. Without pushing and popping, accumulators
  // ignore operations executed as a direct result of their own jvp
  // computations.
  void PushState() { call_state_.emplace(nullptr, false); }

90
tensorflow/c/eager/tensor_handle_interface.h
Normal file
@ -0,0 +1,90 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
#define TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"

// Abstract interface to a TensorHandle.
//
// A TensorHandle is a management class around a Tensor which may track
// additional metadata and synchronization.
//
// This allows us to hide concrete implementations of TensorHandle from header
// files. The interface lists the common functionality that must be provided by
// any concrete implementation. However, in cases where the true concrete class
// is needed a static_cast can be applied.
class AbstractTensorHandleInterface {
 public:
  virtual ~AbstractTensorHandleInterface() {}

  // Check if the handle is in a valid initialized state.
  virtual bool IsValid(tensorflow::Status* status) const = 0;
  // Returns tensor dtype.
  virtual TF_DataType DataType() const = 0;
  // Returns number of dimensions.
  virtual int NumDims(tensorflow::Status* status) const = 0;
  // Returns number of elements across all dimensions.
  virtual int64_t NumElements(tensorflow::Status* status) const = 0;
  // Returns the size of the specified dimension.
  virtual int64_t Dim(int dim_index, tensorflow::Status* status) const = 0;

  // Returns the device which created the handle.
  virtual const char* DeviceName(tensorflow::Status* status) const = 0;
  // Returns the device where the tensor was placed.
  virtual const char* BackingDeviceName(tensorflow::Status* status) const = 0;
  // Returns a tensor for the handle. If tensor is remote, it will be copied.
  virtual TF_Tensor* Resolve(tensorflow::Status* status) = 0;
  // Returns debug information about the tensor.
  virtual TFE_TensorDebugInfo* TensorDebugInfo(tensorflow::Status* status) = 0;

  // Return a copy of the handle.
  virtual AbstractTensorHandleInterface* Copy() = 0;
};

namespace tensorflow {

class TensorHandleInterface : public AbstractTensorHandleInterface {
 public:
  explicit TensorHandleInterface(TensorHandle* h) : handle_(h) {}
  ~TensorHandleInterface() override;

  bool IsValid(Status* status) const override;
  TF_DataType DataType() const override;
  int NumDims(Status* status) const override;
  int64_t NumElements(Status* status) const override;
  int64_t Dim(int dim_index, Status* status) const override;

  const char* DeviceName(Status* status) const override;
  const char* BackingDeviceName(Status* status) const override;
  TF_Tensor* Resolve(Status* status) override;
  TFE_TensorDebugInfo* TensorDebugInfo(Status* status) override;

  AbstractTensorHandleInterface* Copy() override;

  // TODO(gjn): This is not a very generic interface, but is needed for
  // specific use cases.
  TensorHandle* Handle() { return handle_; }

 private:
  TensorHandle* handle_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
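The rest of this diff reaches the concrete handle through exactly one pattern: down_cast the abstract interface stored in TFE_TensorHandle, then call Handle(). A hypothetical helper (not part of this commit) that names that pattern:

  // Unwraps a C API handle to the concrete tensorflow::TensorHandle*.
  // Only valid while `h` is known to wrap a tensorflow::TensorHandleInterface.
  tensorflow::TensorHandle* UnwrapHandle(TFE_TensorHandle* h) {
    return tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
               h->handle.get())
        ->Handle();
  }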
@ -18,37 +18,23 @@ cc_library(
    ],
)

# Core TensorFlow depends on this, this will be included in main library
cc_library(
    name = "filesystem_interface_impl",
    srcs = ["filesystem_interface.cc"],
    hdrs = ["filesystem_interface.h"],
    deps = [
        ":modular_filesystem",
        "//tensorflow/c:tf_file_statistics",
        "//tensorflow/c:tf_status",
        "//tensorflow/c:tf_status_internal",
        "//tensorflow/core:ptr_util",
        "//tensorflow/core/platform:env",
        "//tensorflow/core/platform:logging",
        "//tensorflow/core/platform:strcat",
        "//tensorflow/core/platform:stringpiece",
    ],
    alwayslink = 1,
)

# Core TensorFlow depends on this, will be included in main library
cc_library(
    name = "modular_filesystem",
    srcs = ["modular_filesystem.cc"],
    srcs = [
        "modular_filesystem.cc",
        "modular_filesystem_registration.cc",
        "modular_filesystem_registration.h",
    ],
    hdrs = ["modular_filesystem.h"],
    deps = [
        ":filesystem_interface",
        "//tensorflow/c:tf_status_helper",
        "//tensorflow/core:lib",
        "//tensorflow/c:tf_status_internal",
        "//tensorflow/core:ptr_util",
        "//tensorflow/core/platform:env",
        "//tensorflow/core/platform:strcat",
        "//tensorflow/core/platform:errors",
        "//tensorflow/core/platform:status",
    ],
)

@ -63,16 +49,12 @@ tf_cc_test(
        "notap",  # b/139060984, requires implementing modular support for Google filesystem
    ],
    deps = [
        ":filesystem_interface_impl",
        "//tensorflow/c:tf_status",
        "//tensorflow/c:tf_status_internal",
        ":modular_filesystem",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core/lib/io:path",
        "//tensorflow/core/platform:env",
        "//tensorflow/core/platform:error",
        "//tensorflow/core/platform:stacktrace_handler",
        "//tensorflow/core/platform:str_util",
        "//tensorflow/core/platform:strcat",
        "//tensorflow/core/platform:test",
    ],
)

@ -1,366 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"

#include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/util/ptr_util.h"

/// This translation unit is linked in core TensorFlow and provides the
/// functionality needed for plugin registration to check ABI/API compatibility,
/// to ensure required methods are present, to ensure plugins are not allowed to
/// change functionality after being loaded and to register the filesystems
/// provided by a plugin. Consult the header file for more information about
/// how this is achieved.

namespace tensorflow {
namespace {

// Checks if the plugin and core ABI numbers match, filling in `status`.
//
// If the numbers don't match, plugin cannot be loaded.
static bool CheckABIHelper(int pluginABI, int coreABI, StringPiece where,
                           TF_Status* status) {
  if (pluginABI != coreABI) {
    TF_SetStatus(
        status, TF_FAILED_PRECONDITION,
        strings::StrCat("Plugin ABI (", pluginABI, ") for ", where,
                        " operations doesn't match expected core ABI (",
                        coreABI, "). Plugin cannot be loaded.")
            .c_str());
    return false;
  }

  return true;
}

// Checks if the plugin and core ABI numbers match, for all operations.
//
// If the numbers don't match, plugin cannot be loaded.
//
// Uses the simpler `CheckABIHelper(int, int, StringPiece, TF_Status*)`
static bool CheckABI(
    int plugin_filesystem_ops_ABI,
    const TF_RandomAccessFileOps* plugin_random_access_file_ops,
    int plugin_random_access_file_ops_ABI,
    const TF_WritableFileOps* plugin_writable_file_ops,
    int plugin_writable_file_ops_ABI,
    const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
    int plugin_read_only_memory_region_ops_ABI, TF_Status* status) {
  if (!CheckABIHelper(plugin_filesystem_ops_ABI, TF_FILESYSTEM_OPS_ABI,
                      "filesystem", status))
    return false;

  if (plugin_random_access_file_ops != nullptr &&
      !CheckABIHelper(plugin_random_access_file_ops_ABI,
                      TF_RANDOM_ACCESS_FILE_OPS_ABI, "random access file",
                      status))
    return false;

  if (plugin_writable_file_ops != nullptr &&
      !CheckABIHelper(plugin_writable_file_ops_ABI, TF_WRITABLE_FILE_OPS_ABI,
                      "writable file", status))
    return false;

  if (plugin_read_only_memory_region_ops != nullptr &&
      !CheckABIHelper(plugin_read_only_memory_region_ops_ABI,
                      TF_READ_ONLY_MEMORY_REGION_OPS_ABI,
                      "read only memory region", status))
    return false;

  return true;
}

// Checks if the plugin and core API numbers match, logging mismatches.
static void CheckAPIHelper(int plugin_API, int core_API, StringPiece where) {
  if (plugin_API != core_API) {
    VLOG(0) << "Plugin API (" << plugin_API << ") for " << where
            << " operations doesn't match expected core API (" << core_API
            << "). Plugin will be loaded but functionality might be missing.";
  }
}

// Checks if the plugin and core API numbers match, for all operations.
//
// Uses the simpler `CheckAPIHelper(int, int, StringPiece)`.
static void CheckAPI(
    int plugin_filesystem_ops_API,
    const TF_RandomAccessFileOps* plugin_random_access_file_ops,
    int plugin_random_access_file_ops_API,
    const TF_WritableFileOps* plugin_writable_file_ops,
    int plugin_writable_file_ops_API,
    const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
    int plugin_read_only_memory_region_ops_API) {
  CheckAPIHelper(plugin_filesystem_ops_API, TF_FILESYSTEM_OPS_API,
                 "filesystem");

  if (plugin_random_access_file_ops != nullptr)
    CheckAPIHelper(plugin_random_access_file_ops_API,
                   TF_RANDOM_ACCESS_FILE_OPS_API, "random access file");

  if (plugin_writable_file_ops != nullptr)
    CheckAPIHelper(plugin_writable_file_ops_API, TF_WRITABLE_FILE_OPS_API,
                   "writable file");

  if (plugin_read_only_memory_region_ops != nullptr)
    CheckAPIHelper(plugin_read_only_memory_region_ops_API,
                   TF_READ_ONLY_MEMORY_REGION_OPS_API,
                   "read only memory region");
}

// Validates the filesystem operations supplied by the plugin.
static bool ValidateHelper(const TF_FilesystemOps* ops, TF_Status* status) {
  if (ops == nullptr) {
    TF_SetStatus(status, TF_FAILED_PRECONDITION,
                 "Trying to register filesystem without operations");
    return false;
  }

  if (ops->init == nullptr) {
    TF_SetStatus(status, TF_FAILED_PRECONDITION,
                 "Trying to register filesystem without `init` operation");
    return false;
  }

  if (ops->cleanup == nullptr) {
    TF_SetStatus(status, TF_FAILED_PRECONDITION,
                 "Trying to register filesystem without `cleanup` operation");
    return false;
  }

  return true;
}

// Validates the random access file operations supplied by the plugin.
static bool ValidateHelper(const TF_RandomAccessFileOps* ops,
                           TF_Status* status) {
  if (ops == nullptr) {
    // We allow filesystems where files can only be written to (from TF code)
    return true;
  }

  if (ops->cleanup == nullptr) {
    TF_SetStatus(status, TF_FAILED_PRECONDITION,
                 "Trying to register filesystem without `cleanup` operation on "
                 "random access files");
    return false;
  }

  return true;
}

// Validates the writable file operations supplied by the plugin.
static bool ValidateHelper(const TF_WritableFileOps* ops, TF_Status* status) {
  if (ops == nullptr) {
    // We allow read-only filesystems
    return true;
  }

  if (ops->cleanup == nullptr) {
    TF_SetStatus(status, TF_FAILED_PRECONDITION,
                 "Trying to register filesystem without `cleanup` operation on "
                 "writable files");
    return false;
  }

  return true;
}

// Validates the read only memory region operations given by the plugin.
static bool ValidateHelper(const TF_ReadOnlyMemoryRegionOps* ops,
                           TF_Status* status) {
  if (ops == nullptr) {
    // read only memory region support is always optional
    return true;
  }

  if (ops->cleanup == nullptr) {
    TF_SetStatus(status, TF_FAILED_PRECONDITION,
                 "Trying to register filesystem without `cleanup` operation on "
                 "read only memory regions");
    return false;
  }

  if (ops->data == nullptr) {
    TF_SetStatus(status, TF_FAILED_PRECONDITION,
                 "Trying to register filesystem without `data` operation on "
                 "read only memory regions");
    return false;
  }

  if (ops->length == nullptr) {
    TF_SetStatus(status, TF_FAILED_PRECONDITION,
                 "Trying to register filesystem without `length` operation on "
                 "read only memory regions");
    return false;
  }

  return true;
}

// Validates the operations supplied by the plugin.
//
// Uses the 4 simpler `ValidateHelper(const TF_..., TF_Status*)` to validate
// each individual function table and then checks that the function table for a
// specific file type exists if the plugin offers support for creating that
// type of files.
static bool Validate(
    const TF_FilesystemOps* plugin_filesystem_ops,
    const TF_RandomAccessFileOps* plugin_random_access_file_ops,
    const TF_WritableFileOps* plugin_writable_file_ops,
    const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
    TF_Status* status) {
  if (!ValidateHelper(plugin_filesystem_ops, status)) return false;
  if (!ValidateHelper(plugin_random_access_file_ops, status)) return false;
  if (!ValidateHelper(plugin_writable_file_ops, status)) return false;
  if (!ValidateHelper(plugin_read_only_memory_region_ops, status)) return false;

  if (plugin_filesystem_ops->new_random_access_file != nullptr &&
      plugin_random_access_file_ops == nullptr) {
    TF_SetStatus(status, TF_FAILED_PRECONDITION,
                 "Filesystem allows creation of random access files but no "
                 "operations on them have been supplied.");
    return false;
  }

  if ((plugin_filesystem_ops->new_writable_file != nullptr ||
       plugin_filesystem_ops->new_appendable_file != nullptr) &&
      plugin_writable_file_ops == nullptr) {
    TF_SetStatus(status, TF_FAILED_PRECONDITION,
                 "Filesystem allows creation of writable files but no "
                 "operations on them have been supplied.");
    return false;
  }

  if (plugin_filesystem_ops->new_read_only_memory_region_from_file != nullptr &&
      plugin_read_only_memory_region_ops == nullptr) {
    TF_SetStatus(status, TF_FAILED_PRECONDITION,
                 "Filesystem allows creation of readonly memory regions but no "
                 "operations on them have been supplied.");
    return false;
  }

  return true;
}

// Copies a function table from plugin memory space to core memory space.
//
// This has three benefits:
//   * allows having newer plugins than the current core TensorFlow: the
//     additional entries in the plugin's table are just discarded;
//   * allows having older plugins than the current core TensorFlow (though
//     we are still warning users): the entries that core TensorFlow expects
//     but plugins didn't provide will be set to `nullptr` values and core
//     TensorFlow will know to not call these on behalf of users;
//   * increased security as plugins will not be able to alter function table
//     after loading up. Thus, malicious plugins can't alter functionality to
//     probe for gadgets inside core TensorFlow. We can even protect the area
//     of memory where the copies reside to not allow any more writes to it
//     after all copies are created.
template <typename T>
static std::unique_ptr<const T> CopyToCore(const T* plugin_ops,
                                           size_t plugin_size) {
  if (plugin_ops == nullptr) return nullptr;

  size_t copy_size = sizeof(T);
  if (plugin_size < copy_size) {
    copy_size = plugin_size;
  }

  auto core_ops = tensorflow::MakeUnique<T>();
  memcpy(const_cast<T*>(core_ops.get()), plugin_ops, copy_size);
  return core_ops;
}

}  // namespace
}  // namespace tensorflow

void RegisterFilesystemPlugin(
    int plugin_filesystem_ops_ABI, int plugin_filesystem_ops_API,
    size_t plugin_filesystem_ops_size, int plugin_random_access_file_ops_ABI,
    int plugin_random_access_file_ops_API,
    size_t plugin_random_access_file_ops_size, int plugin_writable_file_ops_ABI,
    int plugin_writable_file_ops_API, size_t plugin_writable_file_ops_size,
    int plugin_read_only_memory_region_ops_ABI,
    int plugin_read_only_memory_region_ops_API,
    size_t plugin_read_only_memory_region_ops_size, const char* scheme,
    const TF_FilesystemOps* plugin_filesystem_ops,
    const TF_RandomAccessFileOps* plugin_random_access_file_ops,
    const TF_WritableFileOps* plugin_writable_file_ops,
    const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
    TF_Status* status) {
  if (scheme == nullptr) {
    TF_SetStatus(status, TF_INVALID_ARGUMENT,
                 "`scheme` argument must not be `nullptr`.");
    return;
  }

  // ABI numbers must match exactly for plugin to be loaded
  if (!tensorflow::CheckABI(
          plugin_filesystem_ops_ABI, plugin_random_access_file_ops,
          plugin_random_access_file_ops_ABI, plugin_writable_file_ops,
          plugin_writable_file_ops_ABI, plugin_read_only_memory_region_ops,
          plugin_read_only_memory_region_ops_ABI, status)) {
    return;
  }

  // API numbers should match but mismatch doesn't block plugin load
  tensorflow::CheckAPI(plugin_filesystem_ops_API, plugin_random_access_file_ops,
                       plugin_random_access_file_ops_API,
                       plugin_writable_file_ops, plugin_writable_file_ops_API,
                       plugin_read_only_memory_region_ops,
                       plugin_read_only_memory_region_ops_API);

  // Plugin can only be loaded if all supplied ops are valid
  if (!tensorflow::Validate(plugin_filesystem_ops,
                            plugin_random_access_file_ops,
                            plugin_writable_file_ops,
                            plugin_read_only_memory_region_ops, status)) {
    return;
  }

  // Copy all the function tables to core TensorFlow memory space
  auto core_filesystem_ops = tensorflow::CopyToCore<TF_FilesystemOps>(
      plugin_filesystem_ops, plugin_filesystem_ops_size);
  auto core_random_access_file_ops =
      tensorflow::CopyToCore<TF_RandomAccessFileOps>(
          plugin_random_access_file_ops, plugin_random_access_file_ops_size);
  auto core_writable_file_ops = tensorflow::CopyToCore<TF_WritableFileOps>(
      plugin_writable_file_ops, plugin_writable_file_ops_size);
  auto core_read_only_memory_region_ops =
      tensorflow::CopyToCore<TF_ReadOnlyMemoryRegionOps>(
          plugin_read_only_memory_region_ops,
          plugin_read_only_memory_region_ops_size);

  // Initialize the opaque filesystem structure
  auto filesystem = tensorflow::MakeUnique<TF_Filesystem>();
  core_filesystem_ops->init(filesystem.get(), status);
  if (!status->status.ok()) {
    core_filesystem_ops->cleanup(filesystem.get());
    return;
  }

  // Register new filesystem
  status->status = tensorflow::Env::Default()->RegisterFileSystem(
      scheme, tensorflow::MakeUnique<tensorflow::ModularFileSystem>(
                  std::move(filesystem), std::move(core_filesystem_ops),
                  std::move(core_random_access_file_ops),
                  std::move(core_writable_file_ops),
                  std::move(core_read_only_memory_region_ops)));
}
@ -56,7 +56,7 @@ extern "C" {
|
||||
/// Lifetime: The wrapper data structures are owned by core TensorFlow. The data
|
||||
/// pointed to by the `void*` members is always owned by the plugin. The plugin
|
||||
/// will provide functions to call to allocate and deallocate this data (see
|
||||
/// next section) and core TensorFlow ensures to call these at the proper time.
|
||||
/// next sections) and core TensorFlow ensures to call these at the proper time.
|
||||
///
|
||||
/// Plugins will never receive a `TF_*` pointer that is `nullptr`. Core
|
||||
/// TensorFlow will never touch the `void*` wrapped by these structures, except
|
||||
@ -529,7 +529,7 @@ typedef struct TF_FilesystemOps {
|
||||
/// If `statuses` is not null, plugins must fill each element with detailed
|
||||
/// status for each file, as if calling `path_exists` on each one. Core
|
||||
/// TensorFlow initializes the `statuses` array and plugins must use
|
||||
/// `TF_SetStatus` to set each element instead of dirrectly assigning.
|
||||
/// `TF_SetStatus` to set each element instead of directly assigning.
|
||||
///
|
||||
/// DEFAULT IMPLEMENTATION: Checks existence of every file. Needs
|
||||
/// `path_exists`.
|
||||
@ -601,6 +601,10 @@ typedef struct TF_FilesystemOps {
|
||||
///
|
||||
/// Plugins must not return `nullptr`. Returning empty strings is allowed.
|
||||
///
|
||||
/// The allocation and freeing of memory must happen via the functions sent to
|
||||
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
|
||||
/// structure in Section 4).
|
||||
///
|
||||
/// This function will be called by core TensorFlow to clean up all path
|
||||
/// arguments for all other methods in the filesystem API.
|
||||
///
|
||||
@ -618,6 +622,10 @@ typedef struct TF_FilesystemOps {
|
||||
/// In case of error, plugins must set `status` to a value different than
|
||||
/// `TF_OK`, free memory allocated for `entries` and return -1.
|
||||
///
|
||||
/// The allocation and freeing of memory must happen via the functions sent to
|
||||
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
|
||||
/// structure in Section 4).
|
||||
///
|
||||
/// Plugins:
|
||||
/// * Must set `status` to `TF_OK` if all children were returned.
|
||||
/// * Must set `status` to `TF_NOT_FOUND` if `path` doesn't point to a
|
||||
@ -654,6 +662,10 @@ typedef struct TF_FilesystemOps {
|
||||
/// different than `TF_OK`, free any memory that might have been allocated for
|
||||
/// `entries` and return -1.
|
||||
///
|
||||
/// The allocation and freeing of memory must happen via the functions sent to
|
||||
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
|
||||
/// structure in Section 4).
|
||||
///
|
||||
/// Plugins:
|
||||
/// * Must set `status` to `TF_OK` if all matches were returned.
|
||||
/// * Might use any other error value for `status` to signal other errors.
|
||||
@ -736,95 +748,132 @@ constexpr size_t TF_FILESYSTEM_OPS_SIZE = sizeof(TF_FilesystemOps);
|
||||
/// SECTION 4. Plugin registration and initialization
|
||||
/// ----------------------------------------------------------------------------
|
||||
///
|
||||
/// In this section we define two functions:
|
||||
/// * `TF_InitPlugin`: must be present in the plugin shared object as it will
|
||||
/// be called by core TensorFlow when the filesystem plugin is loaded;
|
||||
/// * `RegisterFilesystemPlugin`: it is implemented by core TensorFlow but
|
||||
/// plugins must call it in their `TF_InitPlugin`, usually using the macro
|
||||
/// `TF_REGISTER_FILESYSTEM_PLUGIN`.
|
||||
/// In this section we define the API used by core TensorFlow to initialize a
|
||||
/// filesystem provided by a plugin. That is, we define the following:
|
||||
/// * `TF_InitPlugin` function: must be present in the plugin shared object as
|
||||
/// it will be called by core TensorFlow when the filesystem plugin is
|
||||
/// loaded;
|
||||
/// * `TF_FilesystemPluginOps` struct: used to transfer information between
|
||||
/// plugins and core TensorFlow about the operations provided and metadata;
|
||||
/// * `TF_FilesystemPluginInfo` struct: similar to the above structure, but
|
||||
/// collects information about all the file schemes that the plugin provides
|
||||
/// support for, as well as about the plugin's memory handling routines;
|
||||
/// * `TF_SetFilesystemVersionMetadata` function: must be called by plugins in
|
||||
/// their `TF_InitPlugin` to record the versioning information the plugins
|
||||
/// are compiled against.
|
||||
///
|
||||
/// The `TF_InitPlugin` function is used by plugins to set up the data
|
||||
/// structures that implement this interface, as presented in Section 2.
|
||||
///
|
||||
/// The `RegisterFilesystemPlugin` is used by core TensorFlow to check that
|
||||
/// plugins satisfy the requirements expected by core TensorFlow, as follows:
|
||||
/// 1. If ABI numbers don't match we don't load the plugin, else we continue.
|
||||
/// 2. If the API numbers are mismatched, we warn the user and continue
|
||||
/// loading the plugin.
|
||||
/// 3. If any required operation is missing, we stop loading the plugin.
|
||||
///
|
||||
/// If all these checks succeed, we copy the plugin operations to a different
|
||||
/// memory location so that core TensorFlow has the guarantee that they won't be
|
||||
/// changed by plugins at a later time. Finally, we initialize the opaque
|
||||
/// pointer of `TF_Filesystem` by calling the required `init` function of
|
||||
/// `TF_FilesystemOps` and if that succeeds we register the filesystem.
|
||||
/// structures that implement this interface, as presented in Section 2. In
|
||||
/// order to not have plugin shared objects call back symbols defined in core
|
||||
/// TensorFlow, `TF_InitPlugin` has a `TF_FilesystemPluginInfo` argument which
|
||||
/// the plugin must fill (using the `TF_SetFilesystemVersionMetadata` for the
|
||||
/// metadata and setting up all the supported operations and the URI schemes
|
||||
/// that are supported).
|
||||
|
||||
// Initializes a TensorFlow plugin.
|
||||
//
|
||||
// Must be implemented by the plugin DSO. It is called by TensorFlow runtime.
|
||||
//
|
||||
// Filesystem plugins can be loaded on demand by users via
|
||||
// `Env::LoadLibrary` or during TensorFlow's startup if they are on certain
|
||||
// paths (although this has a security risk if two plugins register for the
|
||||
// same filesystem and the malicious one loads before the legimitate one -
|
||||
// but we consider this to be something that users should care about and
|
||||
// manage themselves). In both of these cases, core TensorFlow looks for
|
||||
// the `TF_InitPlugin` symbol and calls that function.
|
||||
//
|
||||
// A plugin is loaded only if this `status` is `TF_OK` after the call.
|
||||
TF_CAPI_EXPORT extern void TF_InitPlugin(TF_Status* status);
|
||||
/// This structure incorporates the operations defined in Section 2 and the
|
||||
/// metadata defined in section 3, allowing plugins to define different ops
|
||||
/// for different URI schemes.
|
||||
///
|
||||
/// Every URI scheme is of the form "fs" for URIs of form "fs:///path/to/file".
|
||||
/// For local filesystems (i.e., when the URI is "/path/to/file"), the scheme
|
||||
/// must be "". The scheme must never be `nullptr`.
|
||||
///
|
||||
/// Every plugin fills this in `TF_InitPlugin`, using the alocator passed as
|
||||
/// argument to allocate memory. After `TF_InitPlugin` finishes, core
|
||||
/// TensorFlow uses the information present in this to initialize filesystems
|
||||
/// for the URI schemes that the plugin requests.
|
||||
///
|
||||
/// All pointers defined in this structure point to memory allocated by the DSO
|
||||
/// using an allocator provided by core TensorFlow when calling `TF_InitPlugin`.
|
||||
///
|
||||
/// IMPORTANT: To maintain binary compatibility, the layout of this structure
|
||||
/// must not change! In the unlikely case that a new type of file needs to be
|
||||
/// supported, add the new ops and metadata at the end of the structure.
|
||||
typedef struct TF_FilesystemPluginOps {
|
||||
char* scheme;
|
||||
int filesystem_ops_abi;
|
||||
int filesystem_ops_api;
|
||||
size_t filesystem_ops_size;
|
||||
TF_FilesystemOps* filesystem_ops;
|
||||
int random_access_file_ops_abi;
|
||||
int random_access_file_ops_api;
|
||||
size_t random_access_file_ops_size;
|
||||
TF_RandomAccessFileOps* random_access_file_ops;
|
||||
int writable_file_ops_abi;
|
||||
int writable_file_ops_api;
|
||||
size_t writable_file_ops_size;
|
||||
TF_WritableFileOps* writable_file_ops;
|
||||
int read_only_memory_region_ops_abi;
|
||||
int read_only_memory_region_ops_api;
|
||||
size_t read_only_memory_region_ops_size;
|
||||
TF_ReadOnlyMemoryRegionOps* read_only_memory_region_ops;
|
||||
} TF_FilesystemPluginOps;
|
||||
|
||||
/// Registers a filesystem plugin so that core TensorFlow can use it.
|
||||
/// This structure gathers together all the operations provided by the plugin.
|
||||
///
|
||||
/// Must be called by the plugin during `TF_InitPlugin`, usually by using the
|
||||
/// convenience `TF_REGISTER_FILESYSTEM_PLUGIN` macro.
|
||||
/// Plugins must provide exactly `num_schemes` elements in the `ops` array.
|
||||
///
|
||||
/// Arguments (grouped by category):
|
||||
/// * `..ABI`: ABI compatibility numbers (see Section 3.).
|
||||
/// * `..API`: API compatibility numbers (see Section 3.).
|
||||
/// * `..Size`: Sizes of the operation tables (see Section 3.).
|
||||
/// * `scheme`: The URI scheme that plugin is registering filesystems for.
|
||||
/// Must be of the form "fs" for URIs of form "fs:///path/to/file". For
|
||||
/// local filesystems (i.e., when the URI is "/path/to/file"), `scheme`
|
||||
/// must be "". Must never be `nullptr`.
|
||||
/// * `..Ops`: The function tables provided by the plugin. Owned by the
|
||||
/// plugin, but core TensorFlow makes a copy of these.
|
||||
/// * `status`: The output variable for representing success/failure.
|
||||
/// Since memory that is allocated by the DSO gets transferred to core
|
||||
/// TensorFlow, we need to provide a way for the allocation and deallocation to
|
||||
/// match. This is why this structure also defines `plugin_memory_allocate` and
|
||||
/// `plugin_memory_free` members.
|
||||
///
|
||||
/// Sets `status` to `TF_OK` if plugin was registered and filesystem operations
|
||||
/// can be invoked from anywhere during TensorFlow's runtime. Any other value of
|
||||
/// `status` means that plugin failed to load properly and as such the
|
||||
/// operations it provides cannot be used at all (i.e., core TensorFlow will
|
||||
/// never run them, returning early with `TF_UNIMPLEMENTED` or similar error
|
||||
/// values).
|
||||
TF_CAPI_EXPORT extern void RegisterFilesystemPlugin(
|
||||
int pluginFilesystemOpsABI, int pluginFilesystemOpsAPI,
|
||||
size_t pluginFilesystemOpsSize, int pluginRandomAccessFileOpsABI,
|
||||
int pluginRandomAccessFileOpsAPI, size_t pluginRandomAccessFileOpsSize,
|
||||
int pluginWritableFileOpsABI, int pluginWritableFileOpsAPI,
|
||||
size_t pluginWritableFileOpsSize, int pluginReadOnlyMemoryRegionOpsABI,
|
||||
int pluginReadOnlyMemoryRegionOpsAPI,
|
||||
size_t pluginReadOnlyMemoryRegionOpsSize, const char* scheme,
|
||||
const TF_FilesystemOps* pluginFilesystemOps,
|
||||
const TF_RandomAccessFileOps* pluginRandomAccessFileOps,
|
||||
const TF_WritableFileOps* pluginWritableFileOps,
|
||||
const TF_ReadOnlyMemoryRegionOps* pluginReadOnlyMemoryRegionOps,
|
||||
TF_Status* status);
|
||||
/// All memory allocated by the plugin that will be owned by core TensorFlow
|
||||
/// must be allocated using the allocator in this structure. Core TensorFlow
|
||||
/// will use the deallocator to free this memory once it no longer needs it.
|
||||
///
|
||||
/// IMPORTANT: To maintain binary compatibility, the layout of this structure
|
||||
/// must not change! In the unlikely case that new global operations must be
|
||||
/// provided, add them at the end of the structure.
|
||||
typedef struct TF_FilesystemPluginInfo {
|
||||
size_t num_schemes;
|
||||
TF_FilesystemPluginOps* ops;
|
||||
void* (*plugin_memory_allocate)(size_t size);
|
||||
void (*plugin_memory_free)(void* ptr);
|
||||
} TF_FilesystemPluginInfo;
|
||||
|
||||
/// This macro is just a convenience wrapper around `RegisterFilesystemPlugin`.
/// Plugins should prefer using this macro instead of a direct call.
#define TF_REGISTER_FILESYSTEM_PLUGIN(                                        \
    scheme, pluginFilesystemOps, pluginRandomAccessFileOps,                   \
    pluginWritableFileOps, pluginReadOnlyMemoryRegionOps, status)             \
  RegisterFilesystemPlugin(                                                   \
      TF_FILESYSTEM_OPS_ABI, TF_FILESYSTEM_OPS_API, TF_FILESYSTEM_OPS_SIZE,   \
      TF_RANDOM_ACCESS_FILE_OPS_ABI, TF_RANDOM_ACCESS_FILE_OPS_API,           \
      TF_RANDOM_ACCESS_FILE_OPS_SIZE, TF_WRITABLE_FILE_OPS_ABI,               \
      TF_WRITABLE_FILE_OPS_API, TF_WRITABLE_FILE_OPS_SIZE,                    \
      TF_READ_ONLY_MEMORY_REGION_OPS_ABI, TF_READ_ONLY_MEMORY_REGION_OPS_API, \
      TF_READ_ONLY_MEMORY_REGION_OPS_SIZE, scheme, pluginFilesystemOps,       \
      pluginRandomAccessFileOps, pluginWritableFileOps,                       \
      pluginReadOnlyMemoryRegionOps, status)

/// Convenience function for setting the versioning metadata.
///
/// The argument is guaranteed to not be `nullptr`.
///
/// We want this to be defined in the plugin's memory space and we guarantee
/// that core TensorFlow will never call this.
static inline void TF_SetFilesystemVersionMetadata(
    TF_FilesystemPluginOps* ops) {
  ops->filesystem_ops_abi = TF_FILESYSTEM_OPS_ABI;
  ops->filesystem_ops_api = TF_FILESYSTEM_OPS_API;
  ops->filesystem_ops_size = TF_FILESYSTEM_OPS_SIZE;
  ops->random_access_file_ops_abi = TF_RANDOM_ACCESS_FILE_OPS_ABI;
  ops->random_access_file_ops_api = TF_RANDOM_ACCESS_FILE_OPS_API;
  ops->random_access_file_ops_size = TF_RANDOM_ACCESS_FILE_OPS_SIZE;
  ops->writable_file_ops_abi = TF_WRITABLE_FILE_OPS_ABI;
  ops->writable_file_ops_api = TF_WRITABLE_FILE_OPS_API;
  ops->writable_file_ops_size = TF_WRITABLE_FILE_OPS_SIZE;
  ops->read_only_memory_region_ops_abi = TF_READ_ONLY_MEMORY_REGION_OPS_ABI;
  ops->read_only_memory_region_ops_api = TF_READ_ONLY_MEMORY_REGION_OPS_API;
  ops->read_only_memory_region_ops_size = TF_READ_ONLY_MEMORY_REGION_OPS_SIZE;
}

/// Initializes a TensorFlow plugin.
///
/// Must be implemented by the plugin DSO. It is called by the TensorFlow
/// runtime.
///
/// Filesystem plugins can be loaded on demand by users via
/// `Env::LoadLibrary` or during TensorFlow's startup if they are on certain
/// paths (although this has a security risk if two plugins register for the
/// same filesystem and the malicious one loads before the legitimate one;
/// we consider this something that users should be aware of and manage
/// themselves). In both of these cases, core TensorFlow looks for
/// the `TF_InitPlugin` symbol and calls this function.
///
/// For every filesystem URI scheme that this plugin supports, the plugin must
/// add one `TF_FilesystemPluginOps` entry in `plugin_info->ops` and call
/// `TF_SetFilesystemVersionMetadata` for that entry.
///
/// Plugins must also initialize `plugin_info->plugin_memory_allocate` and
/// `plugin_info->plugin_memory_free` so that memory allocated by the plugin
/// is freed in a compatible way.
TF_CAPI_EXPORT extern void TF_InitPlugin(TF_FilesystemPluginInfo* plugin_info);

#ifdef __cplusplus
}  // end extern "C"
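Putting the pieces together, a minimal `TF_InitPlugin` for a single hypothetical `ram` scheme might look like the sketch below (mirroring the POSIX plugin later in this diff; the filesystem function table is only partially filled, and `plugin_strdup` is the hypothetical helper from above):

void TF_InitPlugin(TF_FilesystemPluginInfo* plugin_info) {
  plugin_info->plugin_memory_allocate = plugin_memory_allocate;
  plugin_info->plugin_memory_free = plugin_memory_free;
  plugin_info->num_schemes = 1;
  plugin_info->ops = static_cast<TF_FilesystemPluginOps*>(
      plugin_memory_allocate(plugin_info->num_schemes *
                             sizeof(plugin_info->ops[0])));
  TF_FilesystemPluginOps* ops = &plugin_info->ops[0];
  TF_SetFilesystemVersionMetadata(ops);
  ops->scheme = plugin_strdup("ram");  // hypothetical URI scheme
  ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
      plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
  // Registration requires at least `init` and `cleanup`; entries left as
  // nullptr are reported by core TensorFlow as unimplemented.
  // ops->filesystem_ops->init = ...;
  // ops->filesystem_ops->cleanup = ...;
}
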
|
@ -18,11 +18,10 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/file_system_helper.h"
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
#include "tensorflow/core/util/ptr_util.h"
|
||||
|
||||
// TODO(mihaimaruseac): After all filesystems are converted, all calls to
|
||||
@ -165,16 +164,18 @@ Status ModularFileSystem::GetChildren(const std::string& dir,
|
||||
|
||||
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
|
||||
std::string translated_name = TranslateName(dir);
|
||||
char** children;
|
||||
// Note that `children` is allocated by the plugin and freed by core
|
||||
// TensorFlow, so we need to use `plugin_memory_free_` here.
|
||||
char** children = nullptr;
|
||||
const int num_children =
|
||||
ops_->get_children(filesystem_.get(), translated_name.c_str(), &children,
|
||||
plugin_status.get());
|
||||
if (num_children >= 0) {
|
||||
for (int i = 0; i < num_children; i++) {
|
||||
result->push_back(std::string(children[i]));
|
||||
free(children[i]);
|
||||
plugin_memory_free_(children[i]);
|
||||
}
|
||||
free(children);
|
||||
plugin_memory_free_(children);
|
||||
}
|
||||
|
||||
return StatusFromTF_Status(plugin_status.get());
|
||||
@ -186,15 +187,17 @@ Status ModularFileSystem::GetMatchingPaths(const std::string& pattern,
|
||||
return internal::GetMatchingPaths(this, Env::Default(), pattern, result);
|
||||
|
||||
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
|
||||
char** matches;
|
||||
// Note that `matches` is allocated by the plugin and freed by core
|
||||
// TensorFlow, so we need to use `plugin_memory_free_` here.
|
||||
char** matches = nullptr;
|
||||
const int num_matches = ops_->get_matching_paths(
|
||||
filesystem_.get(), pattern.c_str(), &matches, plugin_status.get());
|
||||
if (num_matches >= 0) {
|
||||
for (int i = 0; i < num_matches; i++) {
|
||||
result->push_back(std::string(matches[i]));
|
||||
free(matches[i]);
|
||||
plugin_memory_free_(matches[i]);
|
||||
}
|
||||
free(matches);
|
||||
plugin_memory_free_(matches);
|
||||
}
|
||||
|
||||
return StatusFromTF_Status(plugin_status.get());
|
||||
@ -358,7 +361,8 @@ std::string ModularFileSystem::TranslateName(const std::string& name) const {
|
||||
CHECK(p != nullptr) << "TranslateName(" << name << ") returned nullptr";
|
||||
|
||||
std::string ret(p);
|
||||
free(p);
|
||||
// Since `p` is allocated by plugin, free it using plugin's method.
|
||||
plugin_memory_free_(p);
|
||||
return ret;
|
||||
}
|
||||
|
||||
@ -435,4 +439,8 @@ Status ModularWritableFile::Tell(int64* position) {
|
||||
return StatusFromTF_Status(plugin_status.get());
|
||||
}
|
||||
|
||||
Status RegisterFilesystemPlugin(const std::string& dso_path) {
|
||||
return filesystem_registration::RegisterFilesystemPluginImpl(dso_path);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -32,7 +32,7 @@ namespace tensorflow {
|
||||
// TODO(b/143949615): After all filesystems are converted, this file will be
|
||||
// moved to core/platform, and this class can become a singleton and replace the
|
||||
// need for `Env::Default()`. At that time, we might decide to remove the need
|
||||
// for `Env::Default()` altoghether, but that's a different project, not in
|
||||
// for `Env::Default()` altogether, but that's a different project, not in
|
||||
// scope for now. I'm just mentioning this here as that transition will mean
|
||||
// removal of the registration part from `Env` and adding it here instead: we
|
||||
// will need tables to hold for each scheme the function tables that implement
|
||||
@ -46,12 +46,16 @@ class ModularFileSystem final : public FileSystem {
|
||||
std::unique_ptr<const TF_RandomAccessFileOps> random_access_file_ops,
|
||||
std::unique_ptr<const TF_WritableFileOps> writable_file_ops,
|
||||
std::unique_ptr<const TF_ReadOnlyMemoryRegionOps>
|
||||
read_only_memory_region_ops)
|
||||
read_only_memory_region_ops,
|
||||
std::function<void*(size_t)> plugin_memory_allocate,
|
||||
std::function<void(void*)> plugin_memory_free)
|
||||
: filesystem_(std::move(filesystem)),
|
||||
ops_(std::move(filesystem_ops)),
|
||||
random_access_file_ops_(std::move(random_access_file_ops)),
|
||||
writable_file_ops_(std::move(writable_file_ops)),
|
||||
read_only_memory_region_ops_(std::move(read_only_memory_region_ops)) {}
|
||||
read_only_memory_region_ops_(std::move(read_only_memory_region_ops)),
|
||||
plugin_memory_allocate_(std::move(plugin_memory_allocate)),
|
||||
plugin_memory_free_(std::move(plugin_memory_free)) {}
|
||||
|
||||
~ModularFileSystem() override { ops_->cleanup(filesystem_.get()); }
|
||||
|
||||
@ -93,6 +97,8 @@ class ModularFileSystem final : public FileSystem {
|
||||
std::unique_ptr<const TF_WritableFileOps> writable_file_ops_;
|
||||
std::unique_ptr<const TF_ReadOnlyMemoryRegionOps>
|
||||
read_only_memory_region_ops_;
|
||||
std::function<void*(size_t)> plugin_memory_allocate_;
|
||||
std::function<void(void*)> plugin_memory_free_;
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(ModularFileSystem);
|
||||
};
|
||||
|
||||
@ -156,6 +162,9 @@ class ModularReadOnlyMemoryRegion final : public ReadOnlyMemoryRegion {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(ModularReadOnlyMemoryRegion);
|
||||
};
|
||||
|
||||
// Registers a filesystem plugin so that core TensorFlow can use it.
|
||||
Status RegisterFilesystemPlugin(const std::string& dso_path);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_H_
|
||||
|
@ -0,0 +1,346 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h"
|
||||
|
||||
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
||||
#include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
|
||||
#include "tensorflow/c/tf_status_internal.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/util/ptr_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Checks that all schemes provided by a plugin are valid.
|
||||
// TODO(mihaimaruseac): More validation could be done here, based on supported
|
||||
// charset, maximum length, etc. Punting it for later.
|
||||
static Status ValidateScheme(const char* scheme) {
|
||||
if (scheme == nullptr)
|
||||
return errors::InvalidArgument(
|
||||
"Attempted to register filesystem with `nullptr` URI scheme");
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Checks if the plugin and core ABI numbers match.
|
||||
//
|
||||
// If the numbers don't match, plugin cannot be loaded.
|
||||
static Status CheckABI(int pluginABI, int coreABI, StringPiece where) {
|
||||
if (pluginABI != coreABI)
|
||||
return errors::FailedPrecondition(
|
||||
strings::StrCat("Plugin ABI (", pluginABI, ") for ", where,
|
||||
" operations doesn't match expected core ABI (",
|
||||
coreABI, "). Plugin cannot be loaded."));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Checks if the plugin and core ABI numbers match, for all operations.
|
||||
//
|
||||
// If the numbers don't match, plugin cannot be loaded.
|
||||
//
|
||||
// Uses the simpler `CheckABI(int, int, StringPiece)`.
|
||||
static Status ValidateABI(const TF_FilesystemPluginOps* ops) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
CheckABI(ops->filesystem_ops_abi, TF_FILESYSTEM_OPS_ABI, "filesystem"));
|
||||
|
||||
if (ops->random_access_file_ops != nullptr)
|
||||
TF_RETURN_IF_ERROR(CheckABI(ops->random_access_file_ops_abi,
|
||||
TF_RANDOM_ACCESS_FILE_OPS_ABI,
|
||||
"random access file"));
|
||||
|
||||
if (ops->writable_file_ops != nullptr)
|
||||
TF_RETURN_IF_ERROR(CheckABI(ops->writable_file_ops_abi,
|
||||
TF_WRITABLE_FILE_OPS_ABI, "writable file"));
|
||||
|
||||
if (ops->read_only_memory_region_ops != nullptr)
|
||||
TF_RETURN_IF_ERROR(CheckABI(ops->read_only_memory_region_ops_abi,
|
||||
TF_READ_ONLY_MEMORY_REGION_OPS_ABI,
|
||||
"read only memory region"));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Checks if the plugin and core API numbers match, logging mismatches.
|
||||
static void CheckAPI(int plugin_API, int core_API, StringPiece where) {
|
||||
if (plugin_API != core_API) {
|
||||
VLOG(0) << "Plugin API (" << plugin_API << ") for " << where
|
||||
<< " operations doesn't match expected core API (" << core_API
|
||||
<< "). Plugin will be loaded but functionality might be missing.";
|
||||
}
|
||||
}
|
||||
|
||||
// Checks if the plugin and core API numbers match, for all operations.
|
||||
//
|
||||
// Uses the simpler `CheckAPIHelper(int, int, StringPiece)`.
|
||||
static void ValidateAPI(const TF_FilesystemPluginOps* ops) {
|
||||
CheckAPI(ops->filesystem_ops_api, TF_FILESYSTEM_OPS_API, "filesystem");
|
||||
|
||||
if (ops->random_access_file_ops != nullptr)
|
||||
CheckAPI(ops->random_access_file_ops_api, TF_RANDOM_ACCESS_FILE_OPS_API,
|
||||
"random access file");
|
||||
|
||||
if (ops->writable_file_ops != nullptr)
|
||||
CheckAPI(ops->writable_file_ops_api, TF_WRITABLE_FILE_OPS_API,
|
||||
"writable file");
|
||||
|
||||
if (ops->read_only_memory_region_ops != nullptr)
|
||||
CheckAPI(ops->read_only_memory_region_ops_api,
|
||||
TF_READ_ONLY_MEMORY_REGION_OPS_API, "read only memory region");
|
||||
}
|
||||
|
||||
// Validates the filesystem operations supplied by the plugin.
|
||||
static Status ValidateHelper(const TF_FilesystemOps* ops) {
|
||||
if (ops == nullptr)
|
||||
return errors::FailedPrecondition(
|
||||
"Trying to register filesystem without operations");
|
||||
|
||||
if (ops->init == nullptr)
|
||||
return errors::FailedPrecondition(
|
||||
"Trying to register filesystem without `init` operation");
|
||||
|
||||
if (ops->cleanup == nullptr)
|
||||
return errors::FailedPrecondition(
|
||||
"Trying to register filesystem without `cleanup` operation");
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Validates the random access file operations supplied by the plugin.
|
||||
static Status ValidateHelper(const TF_RandomAccessFileOps* ops) {
|
||||
if (ops == nullptr) {
|
||||
// We allow filesystems where files can only be written to (from TF code)
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (ops->cleanup == nullptr)
|
||||
return errors::FailedPrecondition(
|
||||
"Trying to register filesystem without `cleanup` operation on random "
|
||||
"access files");
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Validates the writable file operations supplied by the plugin.
|
||||
static Status ValidateHelper(const TF_WritableFileOps* ops) {
|
||||
if (ops == nullptr) {
|
||||
// We allow read-only filesystems
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (ops->cleanup == nullptr)
|
||||
return errors::FailedPrecondition(
|
||||
"Trying to register filesystem without `cleanup` operation on writable "
|
||||
"files");
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Validates the read only memory region operations given by the plugin.
|
||||
static Status ValidateHelper(const TF_ReadOnlyMemoryRegionOps* ops) {
|
||||
if (ops == nullptr) {
|
||||
// read only memory region support is always optional
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (ops->cleanup == nullptr)
|
||||
return errors::FailedPrecondition(
|
||||
"Trying to register filesystem without `cleanup` operation on read "
|
||||
"only memory regions");
|
||||
|
||||
if (ops->data == nullptr)
|
||||
return errors::FailedPrecondition(
|
||||
"Trying to register filesystem without `data` operation on read only "
|
||||
"memory regions");
|
||||
|
||||
if (ops->length == nullptr)
|
||||
return errors::FailedPrecondition(
|
||||
"Trying to register filesystem without `length` operation on read only "
|
||||
"memory regions");
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Validates the operations supplied by the plugin.
|
||||
//
|
||||
// Uses the 4 simpler `ValidateHelper(const TF_...*)` to validate each
|
||||
// individual function table and then checks that the function table for a
|
||||
// specific file type exists if the plugin offers support for creating that
|
||||
// type of files.
|
||||
static Status ValidateOperations(const TF_FilesystemPluginOps* ops) {
|
||||
TF_RETURN_IF_ERROR(ValidateHelper(ops->filesystem_ops));
|
||||
TF_RETURN_IF_ERROR(ValidateHelper(ops->random_access_file_ops));
|
||||
TF_RETURN_IF_ERROR(ValidateHelper(ops->writable_file_ops));
|
||||
TF_RETURN_IF_ERROR(ValidateHelper(ops->read_only_memory_region_ops));
|
||||
|
||||
if (ops->filesystem_ops->new_random_access_file != nullptr &&
|
||||
ops->random_access_file_ops == nullptr)
|
||||
return errors::FailedPrecondition(
|
||||
"Filesystem allows creation of random access files but no "
|
||||
"operations on them have been supplied.");
|
||||
|
||||
if ((ops->filesystem_ops->new_writable_file != nullptr ||
|
||||
ops->filesystem_ops->new_appendable_file != nullptr) &&
|
||||
ops->writable_file_ops == nullptr)
|
||||
return errors::FailedPrecondition(
|
||||
"Filesystem allows creation of writable files but no "
|
||||
"operations on them have been supplied.");
|
||||
|
||||
if (ops->filesystem_ops->new_read_only_memory_region_from_file != nullptr &&
|
||||
ops->read_only_memory_region_ops == nullptr)
|
||||
return errors::FailedPrecondition(
|
||||
"Filesystem allows creation of readonly memory regions but no "
|
||||
"operations on them have been supplied.");
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Copies a function table from plugin memory space to core memory space.
|
||||
//
|
||||
// This has three benefits:
|
||||
// * allows having newer plugins than the current core TensorFlow: the
|
||||
// additional entries in the plugin's table are just discarded;
|
||||
// * allows having older plugins than the current core TensorFlow (though
|
||||
// we are still warning users): the entries that core TensorFlow expects
|
||||
// but plugins didn't provide will be set to `nullptr` values and core
|
||||
// TensorFlow will know to not call these on behalf of users;
|
||||
// * increased security as plugins will not be able to alter function table
|
||||
// after loading up. Thus, malicious plugins can't alter functionality to
|
||||
// probe for gadgets inside core TensorFlow. We can even protect the area
|
||||
// of memory where the copies reside to not allow any more writes to it
|
||||
// after all copies are created.
|
||||
template <typename T>
|
||||
static std::unique_ptr<const T> CopyToCore(const T* plugin_ops,
|
||||
size_t plugin_size) {
|
||||
if (plugin_ops == nullptr) return nullptr;
|
||||
|
||||
size_t copy_size = std::min(plugin_size, sizeof(T));
|
||||
auto core_ops = tensorflow::MakeUnique<T>();
|
||||
memset(core_ops.get(), 0, sizeof(T));
|
||||
memcpy(core_ops.get(), plugin_ops, copy_size);
|
||||
return core_ops;
|
||||
}
|
||||
|
||||
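To see why the `memset` plus size-clamped `memcpy` tolerates version skew, consider a toy sketch with two hypothetical versions of one function table (`OldOps` standing in for the layout an older plugin was built against):

struct OldOps { void (*init)(); };                   // older plugin's layout
struct NewOps { void (*init)(); void (*extra)(); };  // current core's layout

// A plugin built against OldOps reports plugin_size == sizeof(OldOps), so
// copy_size == sizeof(OldOps): `init` is copied, while `extra` keeps the
// nullptr written by the memset, and core knows not to call it.
// auto core_ops = CopyToCore<NewOps>(
//     reinterpret_cast<const NewOps*>(old_plugin_table), sizeof(OldOps));
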
// Registers one filesystem from the plugin.
|
||||
//
|
||||
// Must be called only with `index` a valid index in `info->ops`.
|
||||
static Status RegisterFileSystem(const TF_FilesystemPluginInfo* info,
|
||||
int index) {
|
||||
// Step 1: Copy all the function tables to core TensorFlow memory space
|
||||
auto core_filesystem_ops = CopyToCore<TF_FilesystemOps>(
|
||||
info->ops[index].filesystem_ops, info->ops[index].filesystem_ops_size);
|
||||
auto core_random_access_file_ops = CopyToCore<TF_RandomAccessFileOps>(
|
||||
info->ops[index].random_access_file_ops,
|
||||
info->ops[index].random_access_file_ops_size);
|
||||
auto core_writable_file_ops =
|
||||
CopyToCore<TF_WritableFileOps>(info->ops[index].writable_file_ops,
|
||||
info->ops[index].writable_file_ops_size);
|
||||
auto core_read_only_memory_region_ops =
|
||||
CopyToCore<TF_ReadOnlyMemoryRegionOps>(
|
||||
info->ops[index].read_only_memory_region_ops,
|
||||
info->ops[index].read_only_memory_region_ops_size);
|
||||
|
||||
// Step 2: Initialize the opaque filesystem structure
|
||||
auto filesystem = tensorflow::MakeUnique<TF_Filesystem>();
|
||||
TF_Status* c_status = TF_NewStatus();
|
||||
Status status = Status::OK();
|
||||
core_filesystem_ops->init(filesystem.get(), c_status);
|
||||
status = Status(c_status->status);
|
||||
TF_DeleteStatus(c_status);
|
||||
if (!status.ok()) return status;
|
||||
|
||||
// Step 3: Actual registration
|
||||
return Env::Default()->RegisterFileSystem(
|
||||
info->ops[index].scheme,
|
||||
tensorflow::MakeUnique<tensorflow::ModularFileSystem>(
|
||||
std::move(filesystem), std::move(core_filesystem_ops),
|
||||
std::move(core_random_access_file_ops),
|
||||
std::move(core_writable_file_ops),
|
||||
std::move(core_read_only_memory_region_ops),
|
||||
info->plugin_memory_allocate, info->plugin_memory_free));
|
||||
}
|
||||
|
||||
// Registers filesystem at `index`, if plugin is providing valid information.
|
||||
//
|
||||
// Extracted to a separate function so that pointers inside `info` are freed
|
||||
// by the caller regardless of whether validation/registration failed or not.
|
||||
//
|
||||
// Must be called only with `index` a valid index in `info->ops`.
|
||||
static Status ValidateAndRegisterFilesystems(
|
||||
const TF_FilesystemPluginInfo* info, int index) {
|
||||
TF_RETURN_IF_ERROR(ValidateScheme(info->ops[index].scheme));
|
||||
TF_RETURN_IF_ERROR(ValidateABI(&info->ops[index]));
|
||||
ValidateAPI(&info->ops[index]); // we just warn on API number mismatch
|
||||
TF_RETURN_IF_ERROR(ValidateOperations(&info->ops[index]));
|
||||
TF_RETURN_IF_ERROR(RegisterFileSystem(info, index));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Ensures that the plugin provides the required memory management operations.
static Status ValidatePluginMemoryRoutines(
    const TF_FilesystemPluginInfo* info) {
  if (info->plugin_memory_allocate == nullptr)
    return errors::FailedPrecondition(
        "Cannot load filesystem plugin which does not provide "
        "`plugin_memory_allocate`");

  if (info->plugin_memory_free == nullptr)
    return errors::FailedPrecondition(
        "Cannot load filesystem plugin which does not provide "
        "`plugin_memory_free`");

  return Status::OK();
}

namespace filesystem_registration {

Status RegisterFilesystemPluginImpl(const std::string& dso_path) {
  // Step 1: Load plugin
  Env* env = Env::Default();
  void* dso_handle;
  TF_RETURN_IF_ERROR(env->LoadLibrary(dso_path.c_str(), &dso_handle));

  // Step 2: Load symbol for `TF_InitPlugin`
  void* dso_symbol;
  TF_RETURN_IF_ERROR(
      env->GetSymbolFromLibrary(dso_handle, "TF_InitPlugin", &dso_symbol));

  // Step 3: Call `TF_InitPlugin`. The cast matches the declared
  // `void TF_InitPlugin(TF_FilesystemPluginInfo*)` signature.
  TF_FilesystemPluginInfo info;
  memset(&info, 0, sizeof(info));
  auto TF_InitPlugin =
      reinterpret_cast<void (*)(TF_FilesystemPluginInfo*)>(dso_symbol);
  TF_InitPlugin(&info);

  // Step 4: Ensure plugin provides the memory management functions.
  TF_RETURN_IF_ERROR(ValidatePluginMemoryRoutines(&info));

  // Step 5: Validate and register all filesystems.
  // Try to register as many filesystems as possible, accumulating failures,
  // and free the plugin-provided memory once we no longer need it.
  Status status;
  for (int i = 0; i < info.num_schemes; i++) {
    status.Update(ValidateAndRegisterFilesystems(&info, i));
    info.plugin_memory_free(info.ops[i].scheme);
    info.plugin_memory_free(info.ops[i].filesystem_ops);
    info.plugin_memory_free(info.ops[i].random_access_file_ops);
    info.plugin_memory_free(info.ops[i].writable_file_ops);
    info.plugin_memory_free(info.ops[i].read_only_memory_region_ops);
  }
  info.plugin_memory_free(info.ops);
  return status;
}

}  // namespace filesystem_registration

}  // namespace tensorflow
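For completeness, a sketch of how client code would drive this path end to end; the DSO path and function name below are hypothetical:

#include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
#include "tensorflow/core/platform/logging.h"

void LoadMyFilesystemPlugin() {
  // Hypothetical plugin location; on success, every scheme the plugin
  // registered becomes usable through the usual Env/FileSystem APIs.
  tensorflow::Status s =
      tensorflow::RegisterFilesystemPlugin("/tmp/libmy_filesystem_plugin.so");
  if (!s.ok()) {
    LOG(ERROR) << "Failed to register filesystem plugin: " << s;
  }
}
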
@ -0,0 +1,28 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_
|
||||
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace filesystem_registration {
|
||||
|
||||
Status RegisterFilesystemPluginImpl(const std::string& dso_path);
|
||||
|
||||
} // namespace filesystem_registration
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_
|
File diff suppressed because it is too large
@ -24,8 +24,6 @@ limitations under the License.
|
||||
#include <sys/stat.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
||||
#include "tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem_helper.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
@ -33,6 +31,9 @@ limitations under the License.
|
||||
// Implementation of a filesystem for POSIX environments.
|
||||
// This filesystem will support `file://` and empty (local) URI schemes.
|
||||
|
||||
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
|
||||
static void plugin_memory_free(void* ptr) { free(ptr); }
|
||||
|
||||
// SECTION 1. Implementation for `TF_RandomAccessFile`
|
||||
// ----------------------------------------------------------------------------
|
||||
namespace tf_random_access_file {
|
||||
@ -45,7 +46,9 @@ typedef struct PosixFile {
|
||||
static void Cleanup(TF_RandomAccessFile* file) {
|
||||
auto posix_file = static_cast<PosixFile*>(file->plugin_file);
|
||||
close(posix_file->fd);
|
||||
free(const_cast<char*>(posix_file->filename));
|
||||
// This would be safe to free using `free` directly as it is only opaque.
|
||||
// However, it is better to be consistent everywhere.
|
||||
plugin_memory_free(const_cast<char*>(posix_file->filename));
|
||||
delete posix_file;
|
||||
}
|
||||
|
||||
@ -100,7 +103,7 @@ typedef struct PosixFile {
|
||||
|
||||
static void Cleanup(TF_WritableFile* file) {
|
||||
auto posix_file = static_cast<PosixFile*>(file->plugin_file);
|
||||
free(const_cast<char*>(posix_file->filename));
|
||||
plugin_memory_free(const_cast<char*>(posix_file->filename));
|
||||
delete posix_file;
|
||||
}
|
||||
|
||||
@ -383,12 +386,13 @@ static int GetChildren(const TF_Filesystem* filesystem, const char* path,
|
||||
if (num_entries < 0) {
|
||||
TF_SetStatusFromIOError(status, errno, path);
|
||||
} else {
|
||||
*entries = static_cast<char**>(calloc(num_entries, sizeof((*entries)[0])));
|
||||
*entries = static_cast<char**>(
|
||||
plugin_memory_allocate(num_entries * sizeof((*entries)[0])));
|
||||
for (int i = 0; i < num_entries; i++) {
|
||||
(*entries)[i] = strdup(dir_entries[i]->d_name);
|
||||
free(dir_entries[i]);
|
||||
plugin_memory_free(dir_entries[i]);
|
||||
}
|
||||
free(dir_entries);
|
||||
plugin_memory_free(dir_entries);
|
||||
}
|
||||
|
||||
return num_entries;
|
||||
@ -396,48 +400,59 @@ static int GetChildren(const TF_Filesystem* filesystem, const char* path,
|
||||
|
||||
} // namespace tf_posix_filesystem
|
||||
|
||||
void TF_InitPlugin(TF_Status* status) {
|
||||
TF_RandomAccessFileOps random_access_file_ops = {
|
||||
tf_random_access_file::Cleanup,
|
||||
tf_random_access_file::Read,
|
||||
};
|
||||
TF_WritableFileOps writable_file_ops = {
|
||||
tf_writable_file::Cleanup, tf_writable_file::Append,
|
||||
tf_writable_file::Tell, tf_writable_file::Flush,
|
||||
tf_writable_file::Sync, tf_writable_file::Close,
|
||||
};
|
||||
TF_ReadOnlyMemoryRegionOps read_only_memory_region_ops = {
|
||||
tf_read_only_memory_region::Cleanup,
|
||||
tf_read_only_memory_region::Data,
|
||||
tf_read_only_memory_region::Length,
|
||||
};
|
||||
TF_FilesystemOps filesystem_ops = {
|
||||
tf_posix_filesystem::Init,
|
||||
tf_posix_filesystem::Cleanup,
|
||||
tf_posix_filesystem::NewRandomAccessFile,
|
||||
tf_posix_filesystem::NewWritableFile,
|
||||
tf_posix_filesystem::NewAppendableFile,
|
||||
tf_posix_filesystem::NewReadOnlyMemoryRegionFromFile,
|
||||
tf_posix_filesystem::CreateDir,
|
||||
/*recursively_create_dir=*/nullptr,
|
||||
tf_posix_filesystem::DeleteFile,
|
||||
tf_posix_filesystem::DeleteDir,
|
||||
/*delete_recursively=*/nullptr,
|
||||
tf_posix_filesystem::RenameFile,
|
||||
tf_posix_filesystem::CopyFile,
|
||||
tf_posix_filesystem::PathExists,
|
||||
/*paths_exist=*/nullptr,
|
||||
tf_posix_filesystem::Stat,
|
||||
/*is_directory=*/nullptr,
|
||||
/*get_file_size=*/nullptr,
|
||||
/*translate_name=*/nullptr,
|
||||
tf_posix_filesystem::GetChildren,
|
||||
/*get_matching_paths=*/nullptr,
|
||||
/*flush_caches=*/nullptr,
|
||||
};
|
||||
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
|
||||
const char* uri) {
|
||||
TF_SetFilesystemVersionMetadata(ops);
|
||||
ops->scheme = strdup(uri);
|
||||
|
||||
for (const char* scheme : {"", "file"})
|
||||
TF_REGISTER_FILESYSTEM_PLUGIN(scheme, &filesystem_ops,
|
||||
&random_access_file_ops, &writable_file_ops,
|
||||
&read_only_memory_region_ops, status);
|
||||
ops->random_access_file_ops = static_cast<TF_RandomAccessFileOps*>(
|
||||
plugin_memory_allocate(TF_RANDOM_ACCESS_FILE_OPS_SIZE));
|
||||
ops->random_access_file_ops->cleanup = tf_random_access_file::Cleanup;
|
||||
ops->random_access_file_ops->read = tf_random_access_file::Read;
|
||||
|
||||
ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
|
||||
plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
|
||||
ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
|
||||
ops->writable_file_ops->append = tf_writable_file::Append;
|
||||
ops->writable_file_ops->tell = tf_writable_file::Tell;
|
||||
ops->writable_file_ops->flush = tf_writable_file::Flush;
|
||||
ops->writable_file_ops->sync = tf_writable_file::Sync;
|
||||
ops->writable_file_ops->close = tf_writable_file::Close;
|
||||
|
||||
ops->read_only_memory_region_ops = static_cast<TF_ReadOnlyMemoryRegionOps*>(
|
||||
plugin_memory_allocate(TF_READ_ONLY_MEMORY_REGION_OPS_SIZE));
|
||||
ops->read_only_memory_region_ops->cleanup =
|
||||
tf_read_only_memory_region::Cleanup;
|
||||
ops->read_only_memory_region_ops->data = tf_read_only_memory_region::Data;
|
||||
ops->read_only_memory_region_ops->length = tf_read_only_memory_region::Length;
|
||||
|
||||
ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
|
||||
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
|
||||
ops->filesystem_ops->init = tf_posix_filesystem::Init;
|
||||
ops->filesystem_ops->cleanup = tf_posix_filesystem::Cleanup;
|
||||
ops->filesystem_ops->new_random_access_file =
|
||||
tf_posix_filesystem::NewRandomAccessFile;
|
||||
ops->filesystem_ops->new_writable_file = tf_posix_filesystem::NewWritableFile;
|
||||
ops->filesystem_ops->new_appendable_file =
|
||||
tf_posix_filesystem::NewAppendableFile;
|
||||
ops->filesystem_ops->new_read_only_memory_region_from_file =
|
||||
tf_posix_filesystem::NewReadOnlyMemoryRegionFromFile;
|
||||
ops->filesystem_ops->create_dir = tf_posix_filesystem::CreateDir;
|
||||
ops->filesystem_ops->delete_file = tf_posix_filesystem::DeleteFile;
|
||||
ops->filesystem_ops->delete_dir = tf_posix_filesystem::DeleteDir;
|
||||
ops->filesystem_ops->rename_file = tf_posix_filesystem::RenameFile;
|
||||
ops->filesystem_ops->copy_file = tf_posix_filesystem::CopyFile;
|
||||
ops->filesystem_ops->path_exists = tf_posix_filesystem::PathExists;
|
||||
ops->filesystem_ops->stat = tf_posix_filesystem::Stat;
|
||||
ops->filesystem_ops->get_children = tf_posix_filesystem::GetChildren;
|
||||
}
|
||||
|
||||
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
  info->plugin_memory_allocate = plugin_memory_allocate;
  info->plugin_memory_free = plugin_memory_free;
  info->num_schemes = 2;
  info->ops = static_cast<TF_FilesystemPluginOps*>(
      plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
  ProvideFilesystemSupportFor(&info->ops[0], "");
  ProvideFilesystemSupportFor(&info->ops[1], "file");
}
|
@ -44,7 +44,7 @@ int TransferFileContents(const char* src, const char* dst, mode_t mode,
|
||||
}
|
||||
|
||||
// Both files have been opened, do the transfer.
|
||||
// Since errno would be overriden by `close` below, save it here.
|
||||
// Since errno would be overridden by `close` below, save it here.
|
||||
int error_code = 0;
|
||||
if (CopyFileContents(dst_fd, src_fd, size) < 0) error_code = errno;
|
||||
|
||||
|
36
tensorflow/c/experimental/filesystem/plugins/windows/BUILD
Normal file
@ -0,0 +1,36 @@
|
||||
# Experimental windows filesystem plugin.
|
||||
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object")
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
# Filesystem implementation for Windows environment
|
||||
tf_cc_shared_object(
|
||||
name = "windows_filesystem.dll",
|
||||
framework_so = [],
|
||||
linkstatic = False,
|
||||
tags = [
|
||||
"manual",
|
||||
"nobuilder",
|
||||
"notap",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":windows_filesystem_impl"],
|
||||
)
|
||||
|
||||
# The real implementation of the filesystem.
|
||||
cc_library(
|
||||
name = "windows_filesystem_impl",
|
||||
srcs = ["windows_filesystem.cc"],
|
||||
copts = get_win_copts(),
|
||||
tags = [
|
||||
"manual",
|
||||
"nobuilder",
|
||||
"notap",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c/experimental/filesystem:filesystem_interface",
|
||||
],
|
||||
)
|
@ -0,0 +1,73 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
// Implementation of a filesystem for Windows environments.
|
||||
// This filesystem will support `file://` and empty (local) URI schemes.
|
||||
|
||||
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
|
||||
static void plugin_memory_free(void* ptr) { free(ptr); }
|
||||
|
||||
// SECTION 1. Implementation for `TF_RandomAccessFile`
|
||||
// ----------------------------------------------------------------------------
|
||||
namespace tf_random_access_file {
|
||||
|
||||
// TODO(mihaimaruseac): Implement later
|
||||
|
||||
} // namespace tf_random_access_file
|
||||
|
||||
// SECTION 2. Implementation for `TF_WritableFile`
|
||||
// ----------------------------------------------------------------------------
|
||||
namespace tf_writable_file {
|
||||
|
||||
// TODO(mihaimaruseac): Implement later
|
||||
|
||||
} // namespace tf_writable_file
|
||||
|
||||
// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion`
|
||||
// ----------------------------------------------------------------------------
|
||||
namespace tf_read_only_memory_region {
|
||||
|
||||
// TODO(mihaimaruseac): Implement later
|
||||
|
||||
} // namespace tf_read_only_memory_region
|
||||
|
||||
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
|
||||
// ----------------------------------------------------------------------------
|
||||
namespace tf_windows_filesystem {
|
||||
|
||||
// TODO(mihaimaruseac): Implement later
|
||||
|
||||
} // namespace tf_windows_filesystem
|
||||
|
||||
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
|
||||
const char* uri) {
|
||||
TF_SetFilesystemVersionMetadata(ops);
|
||||
ops->scheme = strdup(uri);
|
||||
}
|
||||
|
||||
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
|
||||
info->plugin_memory_allocate = plugin_memory_allocate;
|
||||
info->plugin_memory_free = plugin_memory_free;
|
||||
info->num_schemes = 2;
|
||||
info->ops = static_cast<TF_FilesystemPluginOps*>(
|
||||
plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
|
||||
ProvideFilesystemSupportFor(&info->ops[0], "");
|
||||
ProvideFilesystemSupportFor(&info->ops[1], "file");
|
||||
}
|
@ -18,19 +18,36 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/kernels.h"
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/device_base.h"
|
||||
#include "tensorflow/core/framework/kernel_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
struct MyCustomKernel {
|
||||
bool created;
|
||||
|
@ -133,7 +133,7 @@ TEST(OpsTest, TestShapeInference_VectorizeFunction) {
|
||||
|
||||
TEST(OpsTest, AttributeAccessors) {
|
||||
TF_OpDefinitionBuilder* builder =
|
||||
TF_NewOpDefinitionBuilder("AttributeAccesorsOp");
|
||||
TF_NewOpDefinitionBuilder("AttributeAccessorsOp");
|
||||
TF_OpDefinitionBuilderAddAttr(builder, "foo1: int >= 2");
|
||||
TF_OpDefinitionBuilderAddAttr(builder, "foo2: string=\"my string\"");
|
||||
TF_OpDefinitionBuilderSetIsCommutative(builder, true);
|
||||
@ -151,7 +151,7 @@ TEST(OpsTest, AttributeAccessors) {
|
||||
op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length);
|
||||
bool found = false;
|
||||
for (const auto& op : op_list.op()) {
|
||||
if (op.name() == "AttributeAccesorsOp") {
|
||||
if (op.name() == "AttributeAccessorsOp") {
|
||||
ASSERT_TRUE(op.is_commutative());
|
||||
ASSERT_TRUE(op.is_aggregate());
|
||||
ASSERT_TRUE(op.allows_uninitialized_input());
|
||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
@ -103,49 +105,35 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
|
||||
buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
|
||||
}
|
||||
|
||||
TF_Tensor* ret =
|
||||
new TF_Tensor{Tensor(static_cast<tensorflow::DataType>(dtype),
|
||||
tensorflow::TensorShape(dimvec), buf)};
|
||||
// TODO(gjn): Make the choice of interface a compile-time configuration.
|
||||
tensorflow::TensorInterface ret(
|
||||
Tensor(static_cast<tensorflow::DataType>(dtype),
|
||||
tensorflow::TensorShape(dimvec), buf));
|
||||
buf->Unref();
|
||||
size_t elem_size = TF_DataTypeSize(dtype);
|
||||
if (elem_size > 0 && len < (elem_size * ret->tensor.NumElements())) {
|
||||
delete ret;
|
||||
if (elem_size > 0 && len < (elem_size * ret.NumElements())) {
|
||||
return nullptr;
|
||||
}
|
||||
return ret;
|
||||
return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(ret)};
|
||||
}
|
||||
|
||||
TF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor) {
|
||||
// It is safe to move the Tensor if and only if we own the unique reference to
|
||||
// it. In that case, we might as well not delete and reallocate, but a future
|
||||
// implementation might need to do so.
|
||||
TensorBuffer* buf = tensorflow::TensorCApi::Buffer(tensor->tensor);
|
||||
if (buf->RefCountIsOne() && buf->root_buffer()->RefCountIsOne() &&
|
||||
buf->OwnsMemory()) {
|
||||
return tensor;
|
||||
}
|
||||
return nullptr;
|
||||
TF_Tensor* TF_TensorMaybeMove(TF_Tensor* t) {
|
||||
return t->tensor->CanMove() ? t : nullptr;
|
||||
}
|
||||
|
||||
void TF_DeleteTensor(TF_Tensor* t) { delete t; }
|
||||
|
||||
TF_DataType TF_TensorType(const TF_Tensor* t) {
|
||||
return static_cast<TF_DataType>(t->tensor.dtype());
|
||||
}
|
||||
TF_DataType TF_TensorType(const TF_Tensor* t) { return t->tensor->Type(); }
|
||||
|
||||
int TF_NumDims(const TF_Tensor* t) { return t->tensor.dims(); }
|
||||
int TF_NumDims(const TF_Tensor* t) { return t->tensor->NumDims(); }
|
||||
|
||||
int64_t TF_Dim(const TF_Tensor* t, int dim_index) {
|
||||
return static_cast<int64_t>(t->tensor.dim_size(dim_index));
|
||||
return t->tensor->Dim(dim_index);
|
||||
}
|
||||
|
||||
size_t TF_TensorByteSize(const TF_Tensor* t) {
|
||||
return tensorflow::TensorCApi::Buffer(t->tensor)->size();
|
||||
}
|
||||
size_t TF_TensorByteSize(const TF_Tensor* t) { return t->tensor->ByteSize(); }
|
||||
|
||||
void* TF_TensorData(const TF_Tensor* t) {
|
||||
return tensorflow::TensorCApi::Buffer(t->tensor)->data();
|
||||
}
|
||||
void* TF_TensorData(const TF_Tensor* t) { return t->tensor->Data(); }
|
||||
|
||||
int64_t TF_TensorElementCount(const TF_Tensor* t) {
|
||||
int64_t result = 1;
|
||||
@ -160,15 +148,63 @@ void TF_TensorBitcastFrom(const TF_Tensor* from, TF_DataType type,
|
||||
TF_Tensor* to, const int64_t* new_dims,
|
||||
int num_new_dims, TF_Status* status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
Status cc_status(
|
||||
static_cast<tensorflow::TensorInterface*>(to->tensor.get())
|
||||
->BitcastFrom(*static_cast<const tensorflow::TensorInterface*>(
|
||||
from->tensor.get()),
|
||||
type, new_dims, num_new_dims));
|
||||
Set_TF_Status_from_Status(status, cc_status);
|
||||
}
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
bool TensorInterface::CanMove() const {
|
||||
// It is safe to move the Tensor if and only if we own the unique reference to
|
||||
// it. In that case, we might as well not delete and reallocate, but a future
|
||||
// implementation might need to do so.
|
||||
TensorBuffer* buf = tensorflow::TensorCApi::Buffer(tensor_);
|
||||
if (buf->RefCountIsOne() && buf->root_buffer()->RefCountIsOne() &&
|
||||
buf->OwnsMemory()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
TF_DataType TensorInterface::Type() const {
|
||||
return static_cast<TF_DataType>(tensor_.dtype());
|
||||
}
|
||||
|
||||
int TensorInterface::NumDims() const { return tensor_.dims(); }
|
||||
|
||||
int64_t TensorInterface::Dim(int dim_index) const {
|
||||
return static_cast<int64_t>(tensor_.dim_size(dim_index));
|
||||
}
|
||||
|
||||
int64_t TensorInterface::NumElements() const {
|
||||
return static_cast<int64_t>(tensor_.NumElements());
|
||||
}
|
||||
|
||||
size_t TensorInterface::ByteSize() const {
|
||||
return tensorflow::TensorCApi::Buffer(tensor_)->size();
|
||||
}
|
||||
|
||||
void* TensorInterface::Data() const {
|
||||
return tensorflow::TensorCApi::Buffer(tensor_)->data();
|
||||
}
|
||||
|
||||
Status TensorInterface::BitcastFrom(const TensorInterface& from,
|
||||
TF_DataType type, const int64_t* new_dims,
|
||||
int num_new_dims) {
|
||||
tensorflow::TensorShape s;
|
||||
for (int i = 0; i < num_new_dims; ++i) {
|
||||
s.AddDim(new_dims[i]);
|
||||
}
|
||||
Status cc_status(to->tensor.BitcastFrom(
|
||||
from->tensor, static_cast<tensorflow::DataType>(type), s));
|
||||
Set_TF_Status_from_Status(status, cc_status);
|
||||
return tensor_.BitcastFrom(from.tensor_,
|
||||
static_cast<tensorflow::DataType>(type), s);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
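As a small usage sketch of the C entry point that wraps `TensorInterface::BitcastFrom` (using only existing C API calls; error handling abbreviated, and `ReshapeViaBitcast` is a hypothetical function name):

#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"

void ReshapeViaBitcast() {
  int64_t dims[1] = {4};
  TF_Tensor* from = TF_AllocateTensor(TF_FLOAT, dims, 1, 4 * sizeof(float));
  TF_Tensor* to = TF_AllocateTensor(TF_FLOAT, dims, 1, 4 * sizeof(float));
  int64_t new_dims[2] = {2, 2};  // same byte size, different shape
  TF_Status* status = TF_NewStatus();
  // Reinterprets `from`'s buffer as a 2x2 float tensor stored in `to`.
  TF_TensorBitcastFrom(from, TF_FLOAT, to, new_dims, 2, status);
  if (TF_GetCode(status) != TF_OK) { /* handle the error */ }
  TF_DeleteStatus(status);
  TF_DeleteTensor(from);
  TF_DeleteTensor(to);
}
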
|
||||
// --------------------------------------------------------------------------
|
||||
void StringEncode(const char* src, size_t src_len, char* dst) {
|
||||
dst = tensorflow::core::EncodeVarint64(dst, src_len);
|
||||
@ -277,12 +313,11 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status) {
|
||||
return t;
|
||||
}
|
||||
if (src.dtype() != tensorflow::DT_STRING) {
|
||||
auto* result = new TF_Tensor();
|
||||
if (!result->tensor.CopyFrom(src, src.shape())) {
|
||||
delete result;
|
||||
Tensor tensor;
|
||||
if (!tensor.CopyFrom(src, src.shape())) {
|
||||
return nullptr;
|
||||
}
|
||||
return result;
|
||||
return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(tensor)};
|
||||
}
|
||||
// DT_STRING tensors require a copy since TF_Tensor.buffer expects a flatly
|
||||
// encoded sequence of strings.
|
||||
@ -332,31 +367,35 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status) {
|
||||
}
|
||||
|
||||
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
|
||||
if (src->tensor.dtype() == DT_RESOURCE) {
|
||||
if (src->tensor.dims() != 0) {
|
||||
return static_cast<const tensorflow::TensorInterface*>(src->tensor.get())
|
||||
->ToTensor(dst);
|
||||
}
|
||||
|
||||
Status TensorInterface::ToTensor(Tensor* dst) const {
|
||||
if (tensor_.dtype() == DT_RESOURCE) {
|
||||
if (tensor_.dims() != 0) {
|
||||
return InvalidArgument(
|
||||
"Malformed TF_RESOURCE tensor: expected a scalar, got a tensor with "
|
||||
"shape ",
|
||||
src->tensor.shape().DebugString());
|
||||
tensor_.shape().DebugString());
|
||||
}
|
||||
*dst = Tensor(tensorflow::DT_RESOURCE, src->tensor.shape());
|
||||
*dst = Tensor(tensorflow::DT_RESOURCE, tensor_.shape());
|
||||
if (!dst->scalar<tensorflow::ResourceHandle>()().ParseFromString(
|
||||
string(static_cast<const char*>(TF_TensorData(src)),
|
||||
TF_TensorByteSize(src)))) {
|
||||
string(static_cast<const char*>(Data()), ByteSize()))) {
|
||||
return InvalidArgument(
|
||||
"Malformed TF_RESOUCE tensor: unable to parse resource handle");
|
||||
"Malformed TF_RESOURCE tensor: unable to parse resource handle");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
if (src->tensor.dtype() != DT_STRING) {
|
||||
*dst = src->tensor;
|
||||
if (tensor_.dtype() != DT_STRING) {
|
||||
*dst = tensor_;
|
||||
return Status::OK();
|
||||
}
|
||||
// TF_STRING tensors require copying since Tensor class expects a sequence of
|
||||
// string objects.
|
||||
const tensorflow::int64 num_elements = src->tensor.NumElements();
|
||||
const char* input = reinterpret_cast<const char*>(TF_TensorData(src));
|
||||
const size_t src_size = TF_TensorByteSize(src);
|
||||
const tensorflow::int64 num_elements = tensor_.NumElements();
|
||||
const char* input = reinterpret_cast<const char*>(Data());
|
||||
const size_t src_size = ByteSize();
|
||||
if (static_cast<tensorflow::int64>(src_size / sizeof(tensorflow::uint64)) <
|
||||
num_elements) {
|
||||
return InvalidArgument(
|
||||
@ -365,7 +404,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
|
||||
const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
|
||||
const char* limit = input + src_size;
|
||||
|
||||
*dst = Tensor(src->tensor.dtype(), src->tensor.shape());
|
||||
*dst = Tensor(tensor_.dtype(), tensor_.shape());
|
||||
auto dstarray = dst->flat<tstring>();
|
||||
for (tensorflow::int64 i = 0; i < num_elements; ++i) {
|
||||
tensorflow::uint64 offset =
|
||||
@ -384,8 +423,8 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool TensorInterface::IsAligned() const { return tensor_.IsAligned(); }
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
bool TF_TensorIsAligned(const TF_Tensor* tensor) {
|
||||
return tensor->tensor.IsAligned();
|
||||
}
|
||||
bool TF_TensorIsAligned(const TF_Tensor* t) { return t->tensor->IsAligned(); }
|
||||
|
@ -16,9 +16,12 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
|
||||
#define TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_interface.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
|
||||
// Internal structures used by the C API. These are likely to change and should
|
||||
@ -28,7 +31,7 @@ limitations under the License.
|
||||
// passed to or returned from C functions *by pointer*. Otherwise, changes to
|
||||
// its internal structure will break the C API's binary interface.
|
||||
typedef struct TF_Tensor {
|
||||
::tensorflow::Tensor tensor;
|
||||
std::unique_ptr<AbstractTensorInterface> tensor;
|
||||
} TF_Tensor;
|
||||
|
||||
class TF_ManagedBuffer : public tensorflow::TensorBuffer {
|
||||
@ -83,4 +86,5 @@ void* allocate_tensor(const char* operation, size_t len, Allocator* allocator);
|
||||
// a different Allocator as `arg`.
|
||||
void deallocate_buffer(void* data, size_t len, void* arg);
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
|
||||
|
@ -41,6 +41,16 @@ filegroup(
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "pywrap_required_hdrs",
|
||||
srcs = [
|
||||
"training/coordinator.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/python:__pkg__",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gradients",
|
||||
srcs = [
|
||||
|
@ -96,7 +96,7 @@ class SymbolicGradientBuilder {
|
||||
// Used to identify nodes at which to stop backprop.
|
||||
std::unordered_set<int> GetStopBackpropNodes(
|
||||
const std::vector<bool>& reachable_nodes,
|
||||
const std::unordered_set<int>& output_nodes);
|
||||
const std::unordered_set<int>& output_nodes) const;
|
||||
|
||||
const Scope& scope_;
|
||||
const ops::GradOpRegistry* registry_;
|
||||
@ -190,7 +190,7 @@ std::vector<bool> SymbolicGradientBuilder::GetReachableNodes() {
|
||||
|
||||
std::unordered_set<int> SymbolicGradientBuilder::GetStopBackpropNodes(
|
||||
const std::vector<bool>& reachable_nodes,
|
||||
const std::unordered_set<int>& output_nodes) {
|
||||
const std::unordered_set<int>& output_nodes) const {
|
||||
// Output nodes that get transitively consumed by other `outputs_` are stored
|
||||
// in `internal_outputs`.
|
||||
std::unordered_set<int> internal_outputs;
|
||||
@ -346,8 +346,8 @@ Status SymbolicGradientBuilder::SumGradients(const Output& src, Output* grad) {
|
||||
"Unable to find backprop list for node.id ", src.node()->name());
|
||||
}
|
||||
const auto& grads = iter->second;
|
||||
// Filter any backproped 'NoGradient' Outputs from 'grads' (if needed).
|
||||
// Return any valid backproped gradients that remain after filtering,
|
||||
// Filter any backpropped 'NoGradient' Outputs from 'grads' (if needed).
|
||||
// Return any valid backpropped gradients that remain after filtering,
|
||||
// or 'NoGradient' otherwise.
|
||||
std::vector<Output> grads_to_keep;
|
||||
for (const Output& o : grads) {
|
||||
@ -519,7 +519,7 @@ Status SymbolicGradientBuilder::AddGradients() {
|
||||
// Backprop along the in edges.
|
||||
// TODO(andydavis) Find cleaner way to map each grad output returned by
|
||||
// gradient function to the src node/output to which it should be
|
||||
// backproped. Maybe grad functions can return a vector of Output pairs to
|
||||
// backpropped. Maybe grad functions can return a vector of Output pairs to
|
||||
// make this association explicit.
|
||||
size_t dx_index = 0;
|
||||
for (const Edge* e : n->in_edges()) {
|
||||
|
@ -64,7 +64,7 @@ bool IsZero(const Scope& scope, const Output& grad) {
|
||||
// Multiply after broadcasting vec to match dimensions of mat.
|
||||
// Args:
|
||||
// vec: A 1-D tensor of dimension [D0]
|
||||
// mat: A 2-D tensor of dimesnion [D0, D1]
|
||||
// mat: A 2-D tensor of dimension [D0, D1]
|
||||
//
|
||||
// Returns:
|
||||
// A tensor of dimension [D0, D1], the result of vec * mat.
|
||||
|
@ -124,13 +124,12 @@ cc_library(
|
||||
hdrs = ["bundle_v2.h"],
|
||||
deps = [
|
||||
":constants",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
] + if_not_mobile([
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:strcat",
|
||||
"//tensorflow/core/util/tensor_bundle",
|
||||
]),
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
|
@ -1,5 +1,6 @@
|
||||
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
|
||||
load("//tensorflow/core/platform:build_config.bzl", "if_llvm_aarch64_available")
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:private"],
|
||||
@ -27,9 +28,15 @@ cc_library(
|
||||
"compile.h",
|
||||
"flags.h",
|
||||
],
|
||||
defines = if_llvm_aarch64_available(["TF_LLVM_AARCH64_AVAILABLE=1"]),
|
||||
visibility = ["//tensorflow/python:__pkg__"],
|
||||
deps = [
|
||||
":aot_only_var_handle_op",
|
||||
":embedded_protocol_buffers",
|
||||
"@com_google_absl//absl/base",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"//tensorflow/compiler/tf2xla",
|
||||
"//tensorflow/compiler/tf2xla:mlir_tf2xla",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
|
||||
@ -53,10 +60,13 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:target",
|
||||
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
|
||||
] + if_llvm_aarch64_available([
|
||||
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
|
||||
]),
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
@ -86,6 +96,19 @@ tf_cc_binary(
|
||||
deps = [":tfcompile_main"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "llvm_targets",
|
||||
visibility = ["//tensorflow/python:__pkg__"],
|
||||
deps = [
|
||||
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:target",
|
||||
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
|
||||
] + if_llvm_aarch64_available([
|
||||
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
|
||||
]),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfcompile_main",
|
||||
srcs = ["tfcompile_main.cc"],
|
||||
@ -104,11 +127,6 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:aarch64_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:target",
|
||||
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
|
||||
],
|
||||
)
|
||||
|
||||
@ -214,8 +232,13 @@ cc_library(
|
||||
cc_library(
|
||||
name = "aot_only_var_handle_op",
|
||||
srcs = ["aot_only_var_handle_op.cc"],
|
||||
hdrs = ["aot_only_var_handle_op.h"],
|
||||
visibility = [
|
||||
"//tensorflow/compiler/tf2xla:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@@ -13,9 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/aot/aot_only_var_handle_op.h"

#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {
namespace {

@@ -51,6 +54,31 @@ void XlaAotOnlyVarHandleOp::Compile(XlaOpKernelContext* context) {
}
}  // namespace

REGISTER_XLA_OP(Name("VarHandleOp").CompilationOnly(), XlaAotOnlyVarHandleOp);
REGISTER_OP(tfcompile::kXlaAotOnlyVarHandleOp)
    .Doc(R"doc(
Internal VarHandleOp registration used for XLA AOT compilation.
)doc")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("dtype: type")
    .Attr("shape: shape")
    .Output("resource: resource")
    .SetIsStateful()
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->Scalar());
      DataType t;
      TF_RETURN_IF_ERROR(c->GetAttr("dtype", &t));
      PartialTensorShape p;
      TF_RETURN_IF_ERROR(c->GetAttr("shape", &p));
      shape_inference::ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s));
      c->set_output_handle_shapes_and_types(
          0, std::vector<shape_inference::ShapeAndType>{{s, t}});

      return Status::OK();
    });

REGISTER_XLA_OP(Name(tfcompile::kXlaAotOnlyVarHandleOp).CompilationOnly(),
                XlaAotOnlyVarHandleOp);

}  // namespace tensorflow
27
tensorflow/compiler/aot/aot_only_var_handle_op.h
Normal file
@@ -0,0 +1,27 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_
#define TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_

namespace tensorflow {
namespace tfcompile {

static constexpr const char* const kXlaAotOnlyVarHandleOp =
    "_XlaAotOnlyVarHandleOp";

}  // namespace tfcompile
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_
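Editor's note: the constant above is the name under which the AOT-only variable handle op is registered, so a graph rewrite can retarget a node at it purely by name. A minimal hypothetical sketch (the helper and its use in a rewrite are illustrative assumptions, not part of this change):

#include "tensorflow/compiler/aot/aot_only_var_handle_op.h"
#include "tensorflow/core/framework/node_def.pb.h"

// Hypothetical helper: point a VarHandleOp NodeDef at the AOT-only
// registration so tfcompile resolves it instead of the runtime kernel.
void RetargetToAotVarHandle(tensorflow::NodeDef* node) {
  node->set_op(tensorflow::tfcompile::kXlaAotOnlyVarHandleOp);
}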
@@ -74,16 +74,16 @@ void DumpStatsToStdout(const Stats& stats) {
  const int kBufSize = 1000;
  char buf[kBufSize];
  snprintf(buf, kBufSize, "Mean with %2.0f%% trimmed:", trim_ratio * 100);
  const string label_trimmed(buf);
  std::string label_trimmed(buf);
  snprintf(buf, kBufSize, "Mean of %2.0f%% best:", best_ratio * 100);
  const string label_best(buf);
  std::vector<std::pair<string, double>> groups = {
  std::string label_best(buf);
  std::vector<std::pair<std::string, double>> groups = {
      {"Best:", sorted_us.front()},
      {"Worst:", sorted_us.back()},
      {"Median:", sorted_us[count_us / 2]},
      {"Mean:", sum_us / count_us},
      {label_trimmed, sum_us_trimmed / count_us_trimmed},
      {label_best, sum_us_best / count_us_best},
      {std::move(label_trimmed), sum_us_trimmed / count_us_trimmed},
      {std::move(label_best), sum_us_best / count_us_best},
  };
  int max_label_size = 0;
  double max_us = 0;

@@ -102,7 +102,7 @@ void DumpStatsToStdout(const Stats& stats) {
  }
  // Dump stats out.
  printf("Benchmark ran %zu iterations over %lld us\n", count_us,
         stats.total_us);
         static_cast<long long>(stats.total_us));  // NOLINT
  for (const auto& g : groups) {
    printf(" %-*s %*.3f us\n", max_label_size, g.first.c_str(), max_digits + 4,
           g.second);

@@ -114,7 +114,8 @@ void Benchmark(const Options& options, const BenchmarkFn& fn, Stats* stats) {
  const int64 max_us = (options.max_micros <= 0 && options.max_iters <= 0)
                           ? Options::kDefaultMicros
                           : options.max_micros;
  printf("Running benchmark for %lld us\n", max_us);
  // NOLINTNEXTLINE
  printf("Running benchmark for %lld us\n", static_cast<long long>(max_us));
  const int64 start_us = NowMicros();
  int64 iters = 0;
  while (true) {
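Editor's note on the printf changes above: tensorflow::int64 is not guaranteed to be the same underlying type as long long on every platform (on LP64 Linux it has historically been long), so passing it straight to a %lld conversion is formally mismatched. A standalone sketch of the cast the diff adds, with std::int64_t standing in for tensorflow::int64:

#include <cstdint>
#include <cstdio>

int main() {
  std::int64_t total_us = 123456789;  // stand-in for tensorflow::int64
  // The cast guarantees the argument type matches the %lld conversion.
  std::printf("Benchmark ran over %lld us\n",
              static_cast<long long>(total_us));
  return 0;
}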
@@ -423,8 +423,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
      GenNameToIndexCode(config.fetch(), opts.gen_name_to_index);
  const string include_xla_data_proto =
      opts.gen_program_shape
          ?
          R"(#include "tensorflow/compiler/xla/xla_data.pb.h")"
          ? R"(#include "tensorflow/compiler/xla/xla_data.pb.h")"
          : "";

  const string include_hlo_profile_printer_data_proto =
@@ -20,6 +20,9 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/base/call_once.h"
#include "llvm-c/Target.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"

@@ -90,7 +93,7 @@ Status CompileXla(xla::CompileOnlyClient* client,

}  // namespace

Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
                    const MainFlags& flags, CompileResult* compile_result) {
  // Converts the graph into an XLA computation, and compiles the
  // computation.

@@ -108,8 +111,8 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
    if (!flags.mlir_components.empty()) {
      return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
    }
    TF_RETURN_IF_ERROR(
        ConvertGraphDefToXla(graph_def, config, client, &computation));
    TF_RETURN_IF_ERROR(ConvertGraphDefToXla(std::move(graph_def), config,
                                            client, &computation));
  }
  if (!flags.out_session_module.empty()) {
    TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,

@@ -132,5 +135,96 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
  return CompileXla(client, computation, aot_opts, compile_result);
}

static Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
  if (absl::EndsWith(fname, ".pbtxt")) {
    return ReadTextProto(Env::Default(), fname, proto);
  } else {
    return ReadBinaryProto(Env::Default(), fname, proto);
  }
}

static absl::once_flag targets_init;

static void InitializeTargets() {
  // Initialize all LLVM targets so we can cross compile.
#if TF_LLVM_AARCH64_AVAILABLE
  LLVMInitializeAArch64Target();
  LLVMInitializeAArch64TargetInfo();
  LLVMInitializeAArch64TargetMC();
  LLVMInitializeAArch64AsmPrinter();
#endif
  LLVMInitializeARMTarget();
  LLVMInitializeARMTargetInfo();
  LLVMInitializeARMTargetMC();
  LLVMInitializeARMAsmPrinter();
  LLVMInitializePowerPCTarget();
  LLVMInitializePowerPCTargetInfo();
  LLVMInitializePowerPCTargetMC();
  LLVMInitializePowerPCAsmPrinter();
  LLVMInitializeX86Target();
  LLVMInitializeX86TargetInfo();
  LLVMInitializeX86TargetMC();
  LLVMInitializeX86AsmPrinter();
}

Status Main(const MainFlags& flags) {
  absl::call_once(targets_init, &InitializeTargets);

  // Process config.
  tf2xla::Config config;
  if (flags.config.empty()) {
    return errors::InvalidArgument("Must specify --config");
  }
  TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
  TF_RETURN_IF_ERROR(ValidateConfig(config));
  if (flags.dump_fetch_nodes) {
    std::set<string> nodes;
    for (const tf2xla::Fetch& fetch : config.fetch()) {
      nodes.insert(fetch.id().node_name());
    }
    std::cout << absl::StrJoin(nodes, ",");
    return Status::OK();
  }

  // Read and initialize the graph.
  if (flags.graph.empty()) {
    return errors::InvalidArgument("Must specify --graph");
  }
  GraphDef graph_def;
  TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
  CompileResult compile_result;
  TF_RETURN_IF_ERROR(
      CompileGraph(std::move(graph_def), config, flags, &compile_result));

  // Write output files.
  Env* env = Env::Default();
  const std::vector<char>& obj = compile_result.aot->object_file_data();
  TF_RETURN_IF_ERROR(
      WriteStringToFile(env, flags.out_function_object,
                        absl::string_view(obj.data(), obj.size())));
  CodegenOpts codegen_opts;
  codegen_opts.gen_name_to_index = flags.gen_name_to_index;
  codegen_opts.gen_program_shape = flags.gen_program_shape;
  codegen_opts.target_triple = flags.target_triple;
  if (flags.cpp_class.empty()) {
    return errors::InvalidArgument("Must specify --cpp_class");
  }
  codegen_opts.gen_hlo_profile_printer_data =
      xla::GetDebugOptionsFromFlags().xla_hlo_profile();
  TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
                                   &codegen_opts.namespaces));

  MetadataResult metadata_result;
  TF_RETURN_IF_ERROR(
      GenerateMetadata(codegen_opts, compile_result, &metadata_result));
  TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_metadata_object,
                                       metadata_result.object_file_data));
  string header;
  TF_RETURN_IF_ERROR(GenerateHeader(codegen_opts, config, compile_result,
                                    metadata_result, &header));
  TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header));
  return Status::OK();
}

}  // namespace tfcompile
}  // namespace tensorflow
@@ -42,9 +42,12 @@ struct CompileResult {
// that performs the graph operations.
//
// The XLA compilation options are specified in the flags.
Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
                    const MainFlags& flags, CompileResult* compile_result);

// The full compilation method, for reuse in a library setting.
Status Main(const MainFlags& flags);

}  // namespace tfcompile
}  // namespace tensorflow
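Editor's note: changing CompileGraph from `const GraphDef&` to a by-value GraphDef lets the function consume the proto; a caller that is done with its copy can std::move it in and skip a deep copy of a potentially large graph. A generic sketch of the idiom under assumed names (Payload/Consume are illustrative, not TensorFlow APIs):

#include <string>
#include <utility>

struct Payload {
  std::string data;  // stands in for a large proto such as GraphDef
};

// By-value parameter: lvalue callers pay one copy, rvalue callers none,
// and the callee may move from its own parameter.
void Consume(Payload payload) {
  Payload sink = std::move(payload);  // no deep copy on the move path
  (void)sink;
}

int main() {
  Payload p{std::string(1 << 20, 'x')};
  Consume(std::move(p));  // transfers the buffer instead of copying it
  return 0;
}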
@@ -25,6 +25,7 @@ namespace tensorflow {
namespace tfcompile {

// Flags for the tfcompile binary. See *.cc file for descriptions.

struct MainFlags {
  string graph;
  string config;
@@ -25,6 +25,7 @@ test_suite(
        ":test_graph_tfmatmulandadd_test",
        ":test_graph_tfsplits_test",
        ":test_graph_tftop_k_test",
        ":test_graph_tfvariable_readonly_test",
        ":test_graph_tfvariable_sequential_updates_test",
        ":test_graph_tfvariable_test",
        ":tfcompile_test",

@@ -73,6 +74,7 @@ genrule(
        "test_graph_tfsplits.pb",
        "test_graph_tftop_k.pb",
        "test_graph_tfvariable.pb",
        "test_graph_tfvariable_readonly.pb",
        "test_graph_tfvariable_sequential_updates.pb",
    ],
    # Set CUDA_VISIBLE_DEVICES='' to prevent the code we launch from using any

@@ -238,6 +240,17 @@ tf_library(
    ],
)

tf_library(
    name = "test_graph_tfvariable_readonly",
    testonly = 1,
    config = "test_graph_tfvariable_readonly.config.pbtxt",
    cpp_class = "VariableReadonlyComp",
    graph = "test_graph_tfvariable_readonly.pb",
    tags = [
        "manual",
    ],
)

tf_library(
    name = "test_graph_tfvariable_sequential_updates",
    testonly = 1,

@@ -269,6 +282,7 @@ tf_cc_test(
        ":test_graph_tfsplits",
        ":test_graph_tftop_k",
        ":test_graph_tfvariable",
        ":test_graph_tfvariable_readonly",
        ":test_graph_tfvariable_sequential_updates",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:test",

@@ -323,6 +337,42 @@ tf_library(
    ],
)

tf_library(
    name = "test_graph_tfcond_mlir_bridge",
    testonly = 1,
    config = "test_graph_tfcond.config.pbtxt",
    cpp_class = "CondComp",
    graph = "test_graph_tfcond.pb",
    mlir_components = "Bridge",
    tags = [
        "manual",
    ],
)

tf_library(
    name = "test_graph_tfassert_eq_mlir_bridge",
    testonly = 1,
    config = "test_graph_tfassert_eq.config.pbtxt",
    cpp_class = "AssertComp",
    graph = "test_graph_tfassert_eq.pb",
    mlir_components = "Bridge",
    tags = [
        "manual",
    ],
)

tf_library(
    name = "test_graph_tfgather_mlir_bridge",
    testonly = 1,
    config = "test_graph_tfgather.config.pbtxt",
    cpp_class = "GatherComp",
    graph = "test_graph_tfgather.pb",
    mlir_components = "Bridge",
    tags = [
        "manual",
    ],
)

tf_library(
    name = "test_graph_tfmatmul_mlir_bridge",
    testonly = 1,

@@ -361,6 +411,66 @@ tf_library(
    ],
)

tf_library(
    name = "test_graph_tfsplits_mlir_bridge",
    testonly = 1,
    config = "test_graph_tfsplits.config.pbtxt",
    cpp_class = "SplitsComp",
    graph = "test_graph_tfsplits.pb",
    mlir_components = "Bridge",
    tags = [
        "manual",
    ],
)

tf_library(
    name = "test_graph_tftop_k_mlir_bridge",
    testonly = 1,
    config = "test_graph_tftop_k.config.pbtxt",
    cpp_class = "TopKComp",
    graph = "test_graph_tftop_k.pb",
    mlir_components = "Bridge",
    tags = [
        "manual",
    ],
)

tf_library(
    name = "test_graph_tfvariable_readonly_mlir_bridge",
    testonly = 1,
    config = "test_graph_tfvariable_readonly.config.pbtxt",
    cpp_class = "VariableReadonlyComp",
    graph = "test_graph_tfvariable_readonly.pb",
    mlir_components = "Bridge",
    tags = [
        "manual",
    ],
)

tf_library(
    name = "test_graph_tfvariable_mlir_bridge",
    testonly = 1,
    config = "test_graph_tfvariable.config.pbtxt",
    cpp_class = "VariableComp",
    graph = "test_graph_tfvariable.pb",
    mlir_components = "Bridge",
    tags = [
        "manual",
    ],
)

tf_library(
    name = "test_graph_tfvariable_sequential_updates_mlir_bridge",
    testonly = 1,
    config = "test_graph_tfvariable_sequential_updates.config.pbtxt",
    cpp_class = "VariableSequentialUpdatesComp",
    graph = "test_graph_tfvariable_sequential_updates.pb",
    mlir_components = "Bridge",
    tags = [
        "manual",
    ],
)

tf_cc_test(
    name = "tfcompile_test_mlir_bridge",
    srcs = ["tfcompile_test.cc"],

@@ -372,9 +482,17 @@ tf_cc_test(
        ":test_graph_tfadd_mlir_bridge",
        ":test_graph_tfadd_with_ckpt_mlir_bridge",
        ":test_graph_tfadd_with_ckpt_saver_mlir_bridge",
        ":test_graph_tfassert_eq_mlir_bridge",
        ":test_graph_tfcond_mlir_bridge",
        ":test_graph_tfgather_mlir_bridge",
        ":test_graph_tfmatmul_mlir_bridge",
        ":test_graph_tfmatmulandadd_mlir_bridge",
        ":test_graph_tfmatmulandadd_with_profiling_mlir_bridge",
        ":test_graph_tfsplits_mlir_bridge",
        ":test_graph_tftop_k_mlir_bridge",
        ":test_graph_tfvariable_mlir_bridge",
        ":test_graph_tfvariable_readonly_mlir_bridge",
        ":test_graph_tfvariable_sequential_updates_mlir_bridge",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
@@ -34,6 +34,7 @@ from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables

@@ -153,11 +154,21 @@ def tftop_k(_):
  array_ops.identity(output[1], name='indices')


def tfvariable(_):
def tfvariable_readonly(_):
  x = variables.Variable(1000.0, name='x')
  old_x = x.value()
  with ops.control_dependencies([old_x]):
    new_x = x.assign_add(42.0)
    new_value = math_ops.add(old_x, 42.0)
  array_ops.identity(new_value, name='result')


# TODO(b/147908587): Change x and the two constants back to have a scalar shape
# when the bug is fixed.
def tfvariable(_):
  x = variables.Variable([1000.0], name='x', shape=[1])
  old_x = x.value()
  with ops.control_dependencies([old_x]):
    new_x = x.assign_add([42.0])
  array_ops.stack([old_x, new_x], name='result')


@@ -184,6 +195,7 @@ def write_graph(build_graph, out_dir):


def main(_):
  control_flow_util.enable_control_flow_v2()
  write_graph(tfadd, FLAGS.out_dir)
  write_graph(tfadd_with_ckpt, FLAGS.out_dir)
  write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir)

@@ -196,6 +208,7 @@ def main(_):
  write_graph(tfsplits, FLAGS.out_dir)
  write_graph(tftop_k, FLAGS.out_dir)
  write_graph(tfvariable, FLAGS.out_dir)
  write_graph(tfvariable_readonly, FLAGS.out_dir)
  write_graph(tfvariable_sequential_updates, FLAGS.out_dir)
@@ -0,0 +1,12 @@
# Text form of tensorflow.tf2xla.Config proto.
fetch {
  id { node_name: "result" }
}

variable {
  node_name: "x"
  shape {
  }
  type: DT_FLOAT
  readonly: true
}
@@ -30,9 +30,17 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfcond_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfgather_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tftop_k_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates_mlir_bridge.h"
#else
#include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"

@@ -47,6 +55,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h"
#include "tensorflow/compiler/aot/tests/test_graph_tftop_k.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates.h"
#endif

@@ -167,8 +176,6 @@ TEST(TFCompileTest, AddWithCkptSaver) {
  EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
}

// TODO(bixia): the following tests failed with MLIR bridge.
#if !defined(ENABLE_MLIR_BRIDGE_TEST)
TEST(TFCompileTest, Cond) {
  CondComp cond;
  EXPECT_EQ(cond.arg0_data(), cond.arg_data(0));

@@ -233,7 +240,6 @@ TEST(TFCompileTest, Gather) {
    EXPECT_EQ(gather_const.result0_data(), gather.results()[0]);
  }
}
#endif

TEST(TFCompileTest, MatMul2) {
  Eigen::ThreadPool tp(2);

@@ -439,6 +445,7 @@ TEST(TFCompileTest, Function) {
  EXPECT_EQ(add_fn.result0_data()[0], 3);
  EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]);
}
#endif

TEST(TFCompileTest, Splits) {
  Eigen::ThreadPool tp(1);

@@ -492,6 +499,20 @@ TEST(TFCompileTest, TopK) {
  EXPECT_EQ(expected_indices[1], fn.result1(1));
}

TEST(TFCompileTest, VariableReadonly) {
  Eigen::ThreadPool tp(1);
  Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());

  VariableReadonlyComp fn;
  float x = 23;
  fn.set_var_x_data(&x);

  fn.set_thread_pool(&device);
  fn.Run();
  EXPECT_EQ(fn.result0(), 65);
  EXPECT_EQ(fn.var_x(), 23);
}

TEST(TFCompileTest, Variable) {
  Eigen::ThreadPool tp(1);
  Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());

@@ -665,6 +686,11 @@ TEST(TFCompileTest, HloProfiling) {
                                        /*clock_rate_ghz=*/1.0);
  VLOG(1) << "Original HLO profile string:\n" << hlo_profile_as_string;

  // Replace Arg_n with argn when the MLIR bridge is used.
#if defined(ENABLE_MLIR_BRIDGE_TEST)
  RE2::GlobalReplace(&hlo_profile_as_string, "(Arg_)([0-9].)", "arg\\2");
#endif

  // Strip away identifier details from the profile string to avoid this test
  // being a change detector for xla internals. Identifiers such as '%dot.0.7'
  // just become '%dot'.

@@ -690,7 +716,6 @@ TEST(TFCompileTest, HloProfiling) {
      IsSupersetOf({header, total_cycles_profile_line, dot_profile_line,
                    add_profile_line, tuple_profile_line}));
}
#endif

}  // namespace
}  // namespace tfcompile
@@ -21,7 +21,6 @@ limitations under the License.
#include "absl/strings/match.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "llvm-c/Target.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/compile.h"
#include "tensorflow/compiler/aot/flags.h"

@@ -56,88 +55,6 @@ const char kUsageHeader[] =
    "--cpp_class=\"mynamespace::MyComputation\"\n"
    "\n";

Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
  if (absl::EndsWith(fname, ".pbtxt")) {
    return ReadTextProto(Env::Default(), fname, proto);
  } else {
    return ReadBinaryProto(Env::Default(), fname, proto);
  }
}

Status Main(const MainFlags& flags) {
  // Initialize all LLVM targets so we can cross compile.
  LLVMInitializeAArch64Target();
  LLVMInitializeAArch64TargetInfo();
  LLVMInitializeAArch64TargetMC();
  LLVMInitializeAArch64AsmPrinter();
  LLVMInitializeARMTarget();
  LLVMInitializeARMTargetInfo();
  LLVMInitializeARMTargetMC();
  LLVMInitializeARMAsmPrinter();
  LLVMInitializePowerPCTarget();
  LLVMInitializePowerPCTargetInfo();
  LLVMInitializePowerPCTargetMC();
  LLVMInitializePowerPCAsmPrinter();
  LLVMInitializeX86Target();
  LLVMInitializeX86TargetInfo();
  LLVMInitializeX86TargetMC();
  LLVMInitializeX86AsmPrinter();

  // Process config.
  tf2xla::Config config;
  if (flags.config.empty()) {
    return errors::InvalidArgument("Must specify --config");
  }
  TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
  TF_RETURN_IF_ERROR(ValidateConfig(config));
  if (flags.dump_fetch_nodes) {
    std::set<string> nodes;
    for (const tf2xla::Fetch& fetch : config.fetch()) {
      nodes.insert(fetch.id().node_name());
    }
    std::cout << absl::StrJoin(nodes, ",");
    return Status::OK();
  }

  // Read and initialize the graph.
  if (flags.graph.empty()) {
    return errors::InvalidArgument("Must specify --graph");
  }
  GraphDef graph_def;
  TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
  CompileResult compile_result;
  TF_RETURN_IF_ERROR(CompileGraph(graph_def, config, flags, &compile_result));

  // Write output files.
  Env* env = Env::Default();
  const std::vector<char>& obj = compile_result.aot->object_file_data();
  TF_RETURN_IF_ERROR(
      WriteStringToFile(env, flags.out_function_object,
                        absl::string_view(obj.data(), obj.size())));
  CodegenOpts codegen_opts;
  codegen_opts.gen_name_to_index = flags.gen_name_to_index;
  codegen_opts.gen_program_shape = flags.gen_program_shape;
  codegen_opts.target_triple = flags.target_triple;
  if (flags.cpp_class.empty()) {
    return errors::InvalidArgument("Must specify --cpp_class");
  }
  codegen_opts.gen_hlo_profile_printer_data =
      xla::GetDebugOptionsFromFlags().xla_hlo_profile();
  TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
                                   &codegen_opts.namespaces));

  MetadataResult metadata_result;
  TF_RETURN_IF_ERROR(
      GenerateMetadata(codegen_opts, compile_result, &metadata_result));
  TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_metadata_object,
                                       metadata_result.object_file_data));
  string header;
  TF_RETURN_IF_ERROR(GenerateHeader(codegen_opts, config, compile_result,
                                    metadata_result, &header));
  TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header));
  return Status::OK();
}

}  // end namespace tfcompile
}  // end namespace tensorflow
@@ -2,6 +2,7 @@ load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "tf_cc_
load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library", "tf_jit_compilation_passes_extra_deps")
load("//tensorflow/core/platform:build_config.bzl", "tf_additional_all_protos", "tf_proto_library")
load("//tensorflow/core/platform:build_config_root.bzl", "tf_cuda_tests_tags")

package(
    default_visibility = [":internal"],

@@ -56,6 +57,7 @@ cc_library(
    visibility = ["//visibility:public"],
    deps = [
        ":jit_compilation_passes",
        ":xla_kernel_creator",  # buildcleaner: keep
        "//tensorflow/compiler/jit/kernels:xla_ops",
        "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
        "//tensorflow/compiler/tf2xla/kernels:xla_ops",

@@ -69,6 +71,7 @@ cc_library(
    visibility = ["//visibility:public"],
    deps = if_cuda_or_rocm([
        ":jit_compilation_passes",
        ":xla_kernel_creator",  # buildcleaner: keep
        "//tensorflow/compiler/jit/kernels:xla_ops",
        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
        "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",

@@ -77,19 +80,6 @@ cc_library(
    alwayslink = 1,
)

cc_library(
    name = "xla_mlir_gpu_jit",
    visibility = ["//visibility:public"],
    deps = if_cuda_or_rocm([
        ":jit_compilation_passes",
        "//tensorflow/compiler/jit/kernels:xla_ops",
        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
        "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
        "//tensorflow/compiler/xla/service:mlir_gpu_plugin",
    ]),
    alwayslink = 1,
)

cc_library(
    name = "xla_cpu_device",
    srcs = ["xla_cpu_device.cc"],

@@ -115,6 +105,7 @@ cc_library(
    srcs = ["xla_gpu_device.cc"],
    visibility = [":friends"],
    deps = [
        ":flags",
        ":jit_compilation_passes",
        ":xla_device",
        ":xla_kernel_creator",  # buildcleaner: keep

@@ -123,6 +114,7 @@ cc_library(
        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
        "//tensorflow/compiler/xla/service:gpu_plugin",  # buildcleaner: keep
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:gpu_init",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",

@@ -167,7 +159,9 @@ XLA_DEVICE_DEPS = [
    ":common",
    ":xla_launch_util",
    ":xla_tensor",
    "@com_google_absl//absl/base",
    "@com_google_absl//absl/memory",
    "@com_google_absl//absl/strings",
    "@com_google_absl//absl/synchronization",
    "@com_google_absl//absl/types:optional",
    "//tensorflow/compiler/jit/ops:xla_ops",

@@ -260,13 +254,26 @@ cc_library(
    }),
)

# Internal targets below this point.

cc_library(
    name = "flags",
    srcs = ["flags.cc"],
    hdrs = ["flags.h"],
    visibility = [":friends"],
    deps = [
        "//tensorflow/compiler/xla:parse_flags_from_env",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/strings",
    ],
)

# Header-only version of "flags" library, for linking from the shared object
# without ODR violations.
cc_library(
    name = "flags_headers_only",
    hdrs = ["flags.h"],
    visibility = [":friends"],
    deps = [
        "//tensorflow/compiler/xla:parse_flags_from_env",
        "//tensorflow/core:framework_internal",

@@ -286,6 +293,8 @@ cc_library(
    visibility = [":friends"],
)

# Internal targets below this point.

cc_library(
    name = "xla_launch_util",
    srcs = ["xla_launch_util.cc"],

@@ -407,6 +416,7 @@ cc_library(
        "xla_kernel_creator.h",
    ],
    deps = [
        ":flags",
        ":jit_compilation_passes",
        ":xla_kernel_creator_util",
        "//tensorflow/core:core_cpu_internal",

@@ -635,6 +645,7 @@ cc_library(
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/stream_executor/lib",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/container:inlined_vector",

@@ -766,7 +777,7 @@ tf_cc_test(
    ],
    # TODO(b/141643254) Re-enable msan after fixing use-of-uninitialized-value
    # error.
    tags = ["nomsan"],
    tags = ["nomsan"] + tf_cuda_tests_tags(),
    deps = [
        ":common",
        ":compilation_passes",
@@ -1584,7 +1584,6 @@ DeadnessAnalysis::~DeadnessAnalysis() {}
absl::flat_hash_map<TensorId, string, TensorId::Hasher>
DeadnessAnalysisImpl::PredicateMapAsString() const {
  absl::flat_hash_map<TensorId, string, TensorId::Hasher> result;
  std::vector<TensorId> tensor_ids;
  for (const auto& kv_pair : predicate_map_) {
    CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second);
  }
@@ -374,39 +374,6 @@ xla::StatusOr<NodeDef> BuildXlaHostComputeNodeDef(
  return new_def;
}

TF_ATTRIBUTE_NOINLINE Status
ValidateOutsideCompilationCallNode(Node* call_node) {
  // DT_INT64 as input/output for outside compilation is not supported yet:
  // b/120809951.
  for (const Edge* e : call_node->in_edges()) {
    if (e->IsControlEdge()) {
      continue;
    }
    DataType dtype = e->src()->output_type(e->src_output());
    if (dtype == DT_INT64) {
      return errors::Unimplemented(
          "int64 input for outside compilation is not supported yet: "
          "b/120809951. Please cast output of node ",
          e->src()->DebugString(),
          " to int32 before feeding it into outside compilation.");
    }
  }
  for (const Edge* e : call_node->out_edges()) {
    if (e->IsControlEdge()) {
      continue;
    }
    DataType dtype = e->dst()->input_type(e->dst_input());
    if (dtype == DT_INT64) {
      return errors::Unimplemented(
          "int64 output for outside compilation is not supported yet: "
          "b/120809951. Please cast input of node ",
          e->dst()->DebugString(),
          " to int32 before returning it from outside compilation.");
    }
  }
  return Status::OK();
}

// Replace outside compilation function call node with XlaHostCompute node.
TF_ATTRIBUTE_NOINLINE xla::StatusOr<Node*> ReplaceOutsideCompilationCallNode(
    Graph* g, Node* call_node, const std::map<string, int>& host_compute_core,

@@ -2384,7 +2351,6 @@ Status ExtractOutsideCompilationForFunction(
  }
  std::map<string, Node*> host_compute_nodes;
  for (Node* n : outside_compilation_nodes) {
    TF_RETURN_IF_ERROR(ValidateOutsideCompilationCallNode(n));
    auto host_compute_node_or = ReplaceOutsideCompilationCallNode(
        graph_out.get(), n, host_compute_core, *cluster_deps);
    TF_RETURN_IF_ERROR(host_compute_node_or.status());
@@ -13,13 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/flags.h"

#include <mutex>  // NOLINT

#include "absl/base/call_once.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "absl/strings/strip.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace tensorflow {

@@ -32,7 +35,7 @@ XlaOpsCommonFlags* ops_flags;
IntroduceFloatingPointJitterPassFlags* jitter_flags;

std::vector<Flag>* flag_list;
std::once_flag flags_init;
absl::once_flag flags_init;

bool SetterForXlaAutoJitFlag(const string& value) {
  int32 opt_level;

@@ -155,6 +158,7 @@ void AllocateAndParseFlags() {

  device_flags = new XlaDeviceFlags;
  device_flags->tf_xla_compile_on_demand = false;
  device_flags->tf_xla_enable_xla_devices = true;

  ops_flags = new XlaOpsCommonFlags;
  ops_flags->tf_xla_always_defer_compilation = false;

@@ -187,6 +191,12 @@ void AllocateAndParseFlags() {
           "Switch a device into 'on-demand' mode, where instead of "
           "autoclustering ops are compiled one by one just-in-time."),

      Flag("tf_xla_enable_xla_devices",
           &device_flags->tf_xla_enable_xla_devices,
           "Generate XLA_* devices, where placing a computation on such a "
           "device forces compilation by XLA. Deprecated."),

      Flag("tf_xla_always_defer_compilation",
           &ops_flags->tf_xla_always_defer_compilation, ""),

@@ -206,38 +216,45 @@ void AllocateAndParseFlags() {
}  // namespace

bool SetXlaAutoJitFlagFromFlagString(const string& value) {
  std::call_once(flags_init, &AllocateAndParseFlags);
  absl::call_once(flags_init, &AllocateAndParseFlags);
  return SetterForXlaAutoJitFlag(value);
}

BuildXlaOpsPassFlags* GetBuildXlaOpsPassFlags() {
  std::call_once(flags_init, &AllocateAndParseFlags);
  absl::call_once(flags_init, &AllocateAndParseFlags);
  return build_ops_flags;
}

MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() {
  std::call_once(flags_init, &AllocateAndParseFlags);
  absl::call_once(flags_init, &AllocateAndParseFlags);
  return mark_for_compilation_flags;
}

XlaDeviceFlags* GetXlaDeviceFlags() {
  std::call_once(flags_init, &AllocateAndParseFlags);
  absl::call_once(flags_init, &AllocateAndParseFlags);
  return device_flags;
}

const XlaOpsCommonFlags& GetXlaOpsCommonFlags() {
  std::call_once(flags_init, &AllocateAndParseFlags);
  absl::call_once(flags_init, &AllocateAndParseFlags);
  return *ops_flags;
}

const IntroduceFloatingPointJitterPassFlags&
GetIntroduceFloatingPointJitterPassFlags() {
  std::call_once(flags_init, &AllocateAndParseFlags);
  absl::call_once(flags_init, &AllocateAndParseFlags);
  return *jitter_flags;
}

void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
  std::call_once(flags_init, &AllocateAndParseFlags);
  absl::call_once(flags_init, &AllocateAndParseFlags);
  AppendMarkForCompilationPassFlagsInternal(flag_list);
}

static bool xla_is_enabled = false;

void SetXlaIsEnabled() { xla_is_enabled = true; }

bool IsXlaEnabled() { return xla_is_enabled; }

}  // namespace tensorflow
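Editor's note: the std::call_once to absl::call_once migration above keeps the once-only semantics but uses Abseil's implementation, which is safe during static initialization and drops the <mutex> dependency. A minimal sketch of the pattern, assuming only Abseil is available:

#include <cstdio>

#include "absl/base/call_once.h"

static absl::once_flag init_flag;

static void InitializeOnce() { std::puts("initialized exactly once"); }

void Accessor() {
  // All callers funnel through the same once_flag; InitializeOnce runs at
  // most once regardless of how many threads reach this line.
  absl::call_once(init_flag, &InitializeOnce);
}

int main() {
  Accessor();
  Accessor();  // no second initialization
  return 0;
}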
@@ -87,6 +87,9 @@ struct XlaDeviceFlags {
  // Enabling this mode by a legacy flag is a temporary mechanism. When this
  // feature is battle-tested, we will switch this to be a session option.
  bool tf_xla_compile_on_demand;

  // Enables "XLA" devices if this flag is set.
  bool tf_xla_enable_xla_devices;
};

// Flags common to the _Xla* ops and their kernels.

@@ -151,6 +154,15 @@ GetIntroduceFloatingPointJitterPassFlags();
// Has the side-effect of parsing TF_XLA_FLAGS if that hasn't happened yet.
void AppendMarkForCompilationPassFlags(
    std::vector<tensorflow::Flag>* flag_list);

// Makes all future calls to `IsXlaEnabled()` return `true`.
//
// Should only be called when XLA is linked in.
void SetXlaIsEnabled();

// Returns whether XLA is enabled.
bool IsXlaEnabled();

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_FLAGS_H_
@@ -21,6 +21,7 @@ limitations under the License.
#include <unordered_map>
#include <unordered_set>

#include "absl/base/call_once.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"

@@ -1616,8 +1617,8 @@ StatusOr<bool> MarkForCompilationPassImpl::ShouldCompileClusterImpl(

  if (!should_compile && global_jit_level_ != OptimizerOptions::OFF &&
      device_type.type_string() == DEVICE_CPU) {
    static std::once_flag once;
    std::call_once(once, [] {
    static absl::once_flag once;
    absl::call_once(once, [] {
      LOG(WARNING)
          << "(One-time warning): Not using XLA:CPU for cluster because envvar "
             "TF_XLA_FLAGS=--tf_xla_cpu_global_jit was not set. If you want "

@@ -1776,9 +1777,9 @@ absl::flat_hash_map<string, std::vector<string>>* GetWhitelistTable() {
      "Lgamma", "Digamma",
      // Binary
      "Add", "AddV2", "Sub", "Mul", "Div", "Atan2", "Complex", "DivNoNan",
      "MulNoNan", "FloorDiv", "Xlogy", "Xdivy", "FloorMod", "BitwiseAnd",
      "BitwiseOr", "BitwiseXor", "LeftShift", "RightShift", "LogicalAnd",
      "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
      "MulNoNan", "FloorDiv", "Xlogy", "Xlog1py", "Xdivy", "FloorMod",
      "BitwiseAnd", "BitwiseOr", "BitwiseXor", "LeftShift", "RightShift",
      "LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
      "ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "TruncateDiv",
      "TruncateMod", "Equal", "NotEqual", "Greater", "GreaterEqual",
      "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad", "SoftsignGrad",

@@ -1872,6 +1873,8 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
      "Einsum",
      "EmptyTensorList",
      "ExtractImagePatches",
      "Igamma",
      "Igammac",
      "FFT",
      "FFT2D",
      "FFT3D",
@@ -163,12 +163,11 @@ Status XlaCompilationCache::BuildExecutable(
  build_options.set_device_allocator(options.device_allocator);
  build_options.set_alias_passthrough_params(options.alias_passthrough_params);

  auto compile_result =
      client_->Compile(*result.computation, argument_layouts, build_options);
  if (!compile_result.ok()) {
    return compile_result.status();
  }
  *executable = std::move(compile_result.ValueOrDie());
  TF_ASSIGN_OR_RETURN(
      auto executables,
      client_->Compile(*result.computation, argument_layouts, build_options));
  TF_RET_CHECK(executables.size() == 1);
  *executable = std::move(executables[0]);
  return Status::OK();
}
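Editor's note: the rewrite above replaces the manual ok()/status()/ValueOrDie() sequence with TF_ASSIGN_OR_RETURN, which unwraps a StatusOr<T> or propagates its error in one statement. A self-contained sketch of the same control flow with hand-rolled stand-ins, since the real Status/StatusOr/macro live in TensorFlow and XLA headers:

#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Minimal stand-ins for tensorflow::Status and xla::StatusOr.
struct Status {
  bool ok_ = true;
  std::string msg;
  bool ok() const { return ok_; }
};

template <typename T>
struct StatusOr {
  Status status_;
  T value_{};
  bool ok() const { return status_.ok(); }
  const Status& status() const { return status_; }
  T& value() { return value_; }
};

// Simplified analogue of TF_ASSIGN_OR_RETURN: evaluate, bail out on error,
// otherwise bind the unwrapped value.
#define ASSIGN_OR_RETURN(lhs, expr)             \
  auto lhs##_or = (expr);                       \
  if (!lhs##_or.ok()) return lhs##_or.status(); \
  auto lhs = std::move(lhs##_or.value())

StatusOr<std::vector<int>> Compile() { return {{}, {1, 2, 3}}; }

Status Build() {
  ASSIGN_OR_RETURN(executables, Compile());
  std::cout << "got " << executables.size() << " executables\n";
  return Status{};
}

int main() { return Build().ok() ? 0 : 1; }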
@@ -36,8 +36,13 @@ class XlaCpuDeviceFactory : public DeviceFactory {
};

Status XlaCpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
  devices->push_back(absl::StrCat("/physical_device:", DEVICE_XLA_CPU, ":0"));
  XlaDeviceFlags* flags = GetXlaDeviceFlags();
  if (!flags->tf_xla_enable_xla_devices) {
    LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
    return Status::OK();
  }

  devices->push_back(absl::StrCat("/physical_device:", DEVICE_XLA_CPU, ":0"));
  return Status::OK();
}

@@ -45,6 +50,10 @@ Status XlaCpuDeviceFactory::CreateDevices(
    const SessionOptions& session_options, const string& name_prefix,
    std::vector<std::unique_ptr<Device>>* devices) {
  XlaDeviceFlags* flags = GetXlaDeviceFlags();
  if (!flags->tf_xla_enable_xla_devices) {
    LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
    return Status::OK();
  }
  bool compile_on_demand = flags->tf_xla_compile_on_demand;

  XlaOpRegistry::DeviceRegistration registration;
@@ -20,7 +20,9 @@ limitations under the License.
#include <unordered_set>
#include <utility>

#include "absl/base/call_once.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
#include "tensorflow/compiler/jit/xla_device_context.h"

@@ -386,14 +388,33 @@ Status XlaDevice::TryGetDeviceContext(DeviceContext** out_context) {
  return Status::OK();
}

// Warn about XLA_CPU/XLA_GPU exactly once.
static void ShowXlaDeviceDeprecationWarning(
    absl::string_view compilation_device_name) {
  static absl::once_flag once;
  if (absl::StrContains(compilation_device_name, "CPU") ||
      absl::StrContains(compilation_device_name, "GPU")) {
    absl::call_once(once, [] {
      LOG(WARNING)
          << "XLA_GPU and XLA_CPU devices are deprecated and will be "
             "removed in subsequent releases. Instead, use either "
             "@tf.function(experimental_compile=True) for must-compile "
             "semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 "
             "for auto-clustering best-effort compilation.";
    });
  }
}

void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
  VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":"
          << op_kernel->type_string();
  ShowXlaDeviceDeprecationWarning(jit_device_name_.type_string());
  op_kernel->Compute(context);
}

void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
                             AsyncOpKernel::DoneCallback done) {
  ShowXlaDeviceDeprecationWarning(jit_device_name_.type_string());
  VLOG(2) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":"
          << op_kernel->type_string();
  op_kernel->ComputeAsync(context, done);
@@ -140,7 +140,6 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
  // The device tensor should always be fresh.
  TF_RET_CHECK(!xla_tensor->has_shaped_buffer());

  xla_tensor->set_host_tensor(*cpu_tensor);
  TF_RETURN_IF_ERROR(
      xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
                                       stream_->parent()->device_ordinal()));
@@ -14,17 +14,20 @@ limitations under the License.
==============================================================================*/

// Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
// operators using XLA via the XLA "CUDA" (GPU) backend.
// operators using XLA via the XLA "CUDA" or "ROCM" (GPU) backend.

#include <set>

#include "absl/memory/memory.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

@@ -61,7 +64,14 @@ class XlaGpuDeviceFactory : public DeviceFactory {
};

Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
  auto platform = se::MultiPlatformManager::PlatformWithName("CUDA");
  XlaDeviceFlags* flags = GetXlaDeviceFlags();
  if (!flags->tf_xla_enable_xla_devices) {
    LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
    return Status::OK();
  }

  auto platform =
      se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
  if (!platform.ok()) {
    // Treat failures as non-fatal; there might not be a GPU in the machine.
    VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();

@@ -84,6 +94,12 @@ Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
Status XlaGpuDeviceFactory::CreateDevices(
    const SessionOptions& session_options, const string& name_prefix,
    std::vector<std::unique_ptr<Device>>* devices) {
  XlaDeviceFlags* flags = GetXlaDeviceFlags();
  if (!flags->tf_xla_enable_xla_devices) {
    LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
    return Status::OK();
  }

  XlaOpRegistry::DeviceRegistration registration;
  registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
  registration.autoclustering_policy =

@@ -103,7 +119,8 @@ Status XlaGpuDeviceFactory::CreateDevices(
      RegisterXlaDeviceKernels(DEVICE_XLA_GPU, DEVICE_GPU_XLA_JIT);
  (void)registrations;

  auto platform = se::MultiPlatformManager::PlatformWithName("CUDA");
  auto platform =
      se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
  if (!platform.ok()) {
    // Treat failures as non-fatal; there might not be a GPU in the machine.
    VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_kernel_creator.h"

#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_kernel_creator_util.h"
#include "tensorflow/core/common_runtime/function.h"

@@ -39,6 +40,10 @@ bool RegisterLaunchOpCreator() {
}

static bool register_me = RegisterLaunchOpCreator();
static bool register_xla = [] {
  SetXlaIsEnabled();
  return true;
}();

}  // end namespace
}  // namespace tensorflow
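Editor's note: `static bool register_xla = [] { ... }();` above is the standard link-time registration trick: the immediately invoked lambda runs during static initialization of any binary that links this object file, which is how merely depending on the xla_kernel_creator target flips IsXlaEnabled(). A stripped-down sketch of the idiom:

#include <iostream>

static bool xla_is_enabled = false;
void SetXlaIsEnabled() { xla_is_enabled = true; }
bool IsXlaEnabled() { return xla_is_enabled; }

// Runs before main() in any binary that links this translation unit; the
// bool exists only to force evaluation of the lambda.
static bool register_xla_sketch = [] {
  SetXlaIsEnabled();
  return true;
}();

int main() {
  std::cout << std::boolalpha << IsXlaEnabled() << "\n";  // prints true
  return 0;
}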
@@ -222,8 +222,9 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
  OpKernelConstruction construction(
      DeviceType(dev->device_type()), dev,
      dev->GetAllocator(AllocatorAttributes()), &node_def,
      &fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types,
      fbody->ret_types, output_memory_types, flr->graph_def_version(), &s);
      &fbody->fdef.signature(), flr, dev->resource_manager(), fbody->arg_types,
      input_memory_types, fbody->ret_types, output_memory_types,
      flr->graph_def_version(), &s);

  *kernel = absl::make_unique<XlaLocalLaunchBase>(
      &construction, constant_arg_indices, resource_arg_indices, function);
@@ -66,6 +66,8 @@ cc_library(
        "//tensorflow/compiler/mlir/lite:tensorflow_lite_optimize",
        "//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_passes",
        "//tensorflow/compiler/mlir/lite/quantization/tensorflow:tf_to_quant",
        "//tensorflow/compiler/mlir/lite/quantization/xla:hlo_xla_quantization_passes",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",

@@ -77,10 +79,10 @@ cc_library(
        "//tensorflow/compiler/mlir/xla:lhlo_fuse_linalg",
        "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine",
        "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_gpu",
        "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_linalg",
        "//tensorflow/compiler/mlir/xla:xla_dialect_registration",
        "//tensorflow/compiler/mlir/xla:xla_legalize_control_flow",
        "//tensorflow/compiler/mlir/xla:xla_legalize_tf",
        "//tensorflow/compiler/mlir/xla:xla_legalize_to_linalg",
        "//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
        "//tensorflow/compiler/mlir/xla:xla_lower",
        "//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts",
@@ -26,9 +26,11 @@ package_group(
filegroup(
    name = "tensorflow_lite_ops_td_files",
    srcs = [
        "ir/tfl_op_interfaces.td",
        "ir/tfl_ops.td",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
        "@llvm-project//mlir:OpBaseTdFiles",
        "@llvm-project//mlir:include/mlir/Transforms/LoopLikeInterface.td",
    ],
)

@@ -55,6 +57,25 @@ gentbl(
    ],
)

gentbl(
    name = "tensorflow_lite_op_interfaces_inc_gen",
    tbl_outs = [
        (
            "-gen-op-interface-decls",
            "ir/tfl_ops_interface.h.inc",
        ),
        (
            "-gen-op-interface-defs",
            "ir/tfl_ops_interface.cc.inc",
        ),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "ir/tfl_op_interfaces.td",
    td_srcs = [
        ":tensorflow_lite_ops_td_files",
    ],
)

gentbl(
    name = "tensorflow_lite_prepare_tf_inc_gen",
    tbl_outs = [

@@ -177,11 +198,12 @@ cc_library(
        "ir/tfl_ops.cc",
        "ir/tfl_ops.cc.inc",
        "ir/tfl_ops.h.inc",
        "ir/tfl_ops_interface.cc.inc",
        "ir/tfl_ops_interface.h.inc",
        "utils/attribute_utils.cc",
    ],
    hdrs = [
        "ir/tfl_ops.h",
        "ir/tfl_traits.h",
        "transforms/passes.h",
        "utils/attribute_utils.h",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h",

@@ -190,8 +212,6 @@ cc_library(
    deps = [
        ":tensorflow_lite_ops_inc_gen",
        ":validators",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/lite/schema:schema_fbs",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:Analysis",
        "@llvm-project//mlir:Dialect",

@@ -200,6 +220,10 @@ cc_library(
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Support",
        # TODO(jpienaar): Move this out after splitting out LoopLikeOpInterface.
        "@llvm-project//mlir:Transforms",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/lite/schema:schema_fbs",
    ],
    alwayslink = 1,
)

@@ -258,6 +282,7 @@ tf_cc_test(
cc_library(
    name = "tensorflow_lite_legalize_tf",
    srcs = [
        "transforms/dilated_conv.cc",
        "transforms/extract_ophint.cc",
        "transforms/generated_legalize_tf.inc",
        "transforms/generated_lower_static_tensor_list.inc",

@@ -273,6 +298,7 @@ cc_library(
        "transforms/unroll_batch_matmul.cc",
    ],
    hdrs = [
        "transforms/dilated_conv.h",
        "transforms/passes.h",
        "transforms/unroll_batch_matmul.h",
    ],

@@ -284,13 +310,16 @@ cc_library(
        ":validators",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:convert_tensor",
        "//tensorflow/compiler/mlir/tensorflow:mangling_util",
        "//tensorflow/compiler/xla:status",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/kernels:tensor_list",
        "//tensorflow/core/platform:logging",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/memory",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:Analysis",

@@ -316,6 +345,7 @@ cc_library(
    deps = [
        ":tensorflow_lite",
        ":validators",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
        "//tensorflow/compiler/mlir/tensorflow",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:Analysis",

@@ -347,6 +377,7 @@ cc_library(
        ":validators",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_config",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
        "//tensorflow/compiler/mlir/lite/quantization/lite:tfl_to_std",
        "//tensorflow/core:protos_all_cc",
        "@com_google_absl//absl/memory",
        "@llvm-project//llvm:support",

@@ -371,6 +402,8 @@ genrule(
    name = "op_quant_spec_getters_inc",
    srcs = [
        "ir/tfl_ops.td",
        "ir/tfl_op_interfaces.td",
        "@llvm-project//mlir:include/mlir/Transforms/LoopLikeInterface.td",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
    ],
    outs = [

@@ -673,12 +706,16 @@ cc_library(
    ],
)

exports_files(
    ["transforms/passes.h"],
cc_library(
    name = "empty_passes",
    hdrs = ["transforms/passes.h"],
    visibility = [
        "//configs/devtools/hawkeye/tflite:__subpackages__",
        "//learning/brain/models/app_benchmarks:__subpackages__",
        "//tensorflow/compiler/mlir/lite:friends",
        "//tensorflow/lite/experimental/mlir:__subpackages__",
    ],
    deps = [
        "@llvm-project//llvm:support",
    ],
)
@@ -31,10 +31,11 @@ struct PassConfig {
      : emit_builtin_tflite_ops(true),
        lower_tensor_list_ops(false),
        trim_functions_whitelist({}),
        quant_specs(specs),
        quant_specs(std::move(specs)),
        skip_control_dialect(false),
        form_clusters(false),
        inline_functions(false) {}
        inline_functions(true),
        unfold_batch_matmul(true) {}

  // If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
  // added, which produces TF Lite ops.

@@ -57,6 +58,9 @@ struct PassConfig {
  // Inline function calls within the main function in the MLIR module, prior
  // to legalization to TFLite.
  bool inline_functions;

  // If `unfold_batch_matmul` is true, the tf.BatchMatMul is unfolded to a set
  // of tfl.fully_connected ops.
  bool unfold_batch_matmul;
};

}  // namespace TFL
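Editor's note: `quant_specs(std::move(specs))` pairs a by-value constructor parameter with a move into the member, so an rvalue argument flows through with zero copies and an lvalue argument costs exactly one. A generic sketch, with QuantizationSpecsSketch standing in for the real TFL::QuantizationSpecs:

#include <string>
#include <utility>

struct QuantizationSpecsSketch {
  std::string weight_quantization;  // illustrative field
};

struct PassConfigSketch {
  // Sink parameter + std::move: the constructor body never copies.
  explicit PassConfigSketch(QuantizationSpecsSketch specs)
      : quant_specs(std::move(specs)) {}

  QuantizationSpecsSketch quant_specs;
};

int main() {
  QuantizationSpecsSketch specs{"float16"};
  PassConfigSketch config(std::move(specs));  // moves straight through
  return config.quant_specs.weight_quantization.empty() ? 1 : 0;
}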
@@ -389,7 +389,6 @@ StatusOr<mlir::ElementsAttr> ConvertIntBuffer(
    mlir::RankedTensorType shaped_type, mlir::Type elem_type,
    const std::vector<uint8_t>& buffer) {
  unsigned bit_width;
  mlir::RankedTensorType buffer_type;
  if (auto itype = elem_type.dyn_cast<mlir::IntegerType>()) {
    bit_width = itype.getWidth();
  } else if (auto qtype = elem_type.dyn_cast<QuantizedType>()) {
@@ -920,15 +919,13 @@ StatusOr<FuncOp> ConvertSubgraph(
// represents TFLite, this entry point must be called "main"
// TODO(b/131175224,b/132239787) Support multiple entry points
std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) {
  if (subgraph.name.empty()) {
    if (index == 0) {
      return "main";
    } else {
      return llvm::formatv("fn_{0}", index).str();
    }
  } else {
    return subgraph.name;
  if (index == 0) {
    return "main";
  }
  if (subgraph.name.empty()) {
    return llvm::formatv("fn_{0}", index).str();
  }
  return subgraph.name;
}
}  // namespace
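A small illustration of the reordered naming logic (a sketch only; tflite::SubGraphT is the flatbuffer object-API type and the names here are made up):

    tflite::SubGraphT named, unnamed;
    named.name = "my_fn";
    // After this change the entry subgraph is always "main", even when named:
    assert(SubgraphName(0, named) == "main");
    assert(SubgraphName(0, unnamed) == "main");
    assert(SubgraphName(2, unnamed) == "fn_2");  // unnamed non-entry subgraph
    assert(SubgraphName(2, named) == "my_fn");   // named non-entry subgraph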
@@ -90,6 +90,7 @@ using mlir::MLIRContext;
using mlir::ModuleOp;
using mlir::NoneType;
using mlir::Operation;
using mlir::Region;
using mlir::StringAttr;
using mlir::TensorType;
using mlir::TranslateFromMLIRRegistration;
@@ -309,7 +310,7 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) {
  return true;
}

static std::unique_ptr<::tensorflow::NodeDef> getTensorFlowNodeDef(
static std::unique_ptr<::tensorflow::NodeDef> GetTensorFlowNodeDef(
    ::mlir::Operation* inst) {
  // We pass empty string for the original node_def name since Flex runtime
  // does not care about this being set correctly on node_def. There is no
@@ -425,6 +426,11 @@ class Translator {
      mlir::TF::WhileOp op, const std::vector<int32_t>& operands,
      const std::vector<int32_t>& results);

  // Build while operator where cond & body are regions.
  BufferOffset<tflite::Operator> BuildWhileOperator(
      mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
      const std::vector<int32_t>& results);

  // Builds custom operators.
  // Templated on a) data type of custom_option to be stored into flatbuffer,
  // and b) TFL custom op type.
@@ -472,7 +478,10 @@ class Translator {
      Operation* inst, const std::vector<int32_t>& operands,
      const std::vector<int32_t>& results);

  Optional<BufferOffset<tflite::SubGraph>> BuildSubGraph(FuncOp fn);
  // Build a subgraph with a given name out of a region corresponding either
  // to a function's body or to a while op.
  Optional<BufferOffset<tflite::SubGraph>> BuildSubGraph(
      const std::string& name, Region* region);

  // Builds Metadata with the given `name` and buffer `content`.
  BufferOffset<tflite::Metadata> BuildMetadata(StringRef name,
@@ -494,6 +503,12 @@ class Translator {
  // Returns a unique name for `val`.
  std::string UniqueName(mlir::Value val);

  // Returns the names of the subgraphs corresponding to the regions of the
  // op. The names are supposed to be unique as the op name is unique and the
  // suffix is not a valid name.
  std::string GetWhileBodyName(mlir::TFL::WhileOp while_op);
  std::string GetWhileCondName(mlir::TFL::WhileOp while_op);

  ModuleOp module_;

  tensorflow::OpOrArgNameMapper& name_mapper_;
@@ -523,7 +538,7 @@ class Translator {
};

std::string Translator::UniqueName(mlir::Value val) {
  return name_mapper_.GetUniqueName(val);
  return std::string(name_mapper_.GetUniqueName(val));
}

Optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
@@ -595,6 +610,7 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
  };

  std::vector<int32_t> shape;
  std::vector<int32_t> shape_signature;
  if (type.hasStaticShape()) {
    llvm::ArrayRef<int64_t> shape_ref = type.getShape();
    if (mlir::failed(check_shape(shape_ref))) return llvm::None;
@@ -612,7 +628,17 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(

    shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
  }
  } else if (type.hasRank()) {
    llvm::ArrayRef<int64_t> shape_ref = type.getShape();
    if (mlir::failed(check_shape(shape_ref))) return llvm::None;

    shape.reserve(shape_ref.size());
    for (auto& dim : shape_ref) {
      shape.push_back(dim == -1 ? 1 : dim);
    }
    shape_signature = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
  }

  Type element_type = type.getElementType();
  tflite::TensorType tflite_element_type =
      GetTFLiteType(type.getElementType()).ValueOrDie();
@@ -649,10 +675,19 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
      break;
    }
  }
  return tflite::CreateTensor(
      builder_, builder_.CreateVector(shape), tflite_element_type,
      (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
      /*is_variable=*/is_variable);

  if (shape_signature.empty()) {
    return tflite::CreateTensor(
        builder_, builder_.CreateVector(shape), tflite_element_type,
        (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
        /*is_variable=*/is_variable);
  } else {
    return tflite::CreateTensor(
        builder_, builder_.CreateVector(shape), tflite_element_type,
        (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
        /*is_variable=*/is_variable, /*sparsity=*/0,
        /*shape_signature=*/builder_.CreateVector(shape_signature));
  }
}

BufferOffset<tflite::Operator> Translator::BuildIfOperator(
@@ -687,6 +722,30 @@ BufferOffset<tflite::Operator> Translator::BuildWhileOperator(
                                builtin_options);
}

std::string Translator::GetWhileBodyName(mlir::TFL::WhileOp while_op) {
  return (name_mapper_.GetUniqueName(while_op.getOperation()) + "$body").str();
}

std::string Translator::GetWhileCondName(mlir::TFL::WhileOp while_op) {
  return (name_mapper_.GetUniqueName(while_op.getOperation()) + "$cond").str();
}

BufferOffset<tflite::Operator> Translator::BuildWhileOperator(
    mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
    const std::vector<int32_t>& results) {
  auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE);
  int body_subgraph_index = subgraph_index_map_.at(GetWhileBodyName(op));
  int cond_subgraph_index = subgraph_index_map_.at(GetWhileCondName(op));
  auto builtin_options = tflite::CreateWhileOptions(
                             builder_, cond_subgraph_index, body_subgraph_index)
                             .Union();
  auto inputs = builder_.CreateVector(operands);
  auto outputs = builder_.CreateVector(results);
  return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
                                tflite::BuiltinOptions_WhileOptions,
                                builtin_options);
}
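Taken together, these helpers give each while op's regions stable, collision-free subgraph names: the mapper's unique name for the op (say "while_2") is suffixed to "while_2$cond" and "while_2$body", and BuildWhileOperator then resolves those names through subgraph_index_map_ into the flatbuffer WhileOptions subgraph indices. Per the comment in the class declaration, the suffix is not itself a valid mapper name, so these names cannot clash with function subgraphs.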
template <typename CustomOptionType, typename TFLOp>
BufferOffset<tflite::Operator> Translator::BuildCustomOperator(
    const CustomOptionType& custom_option, const std::string& opcode_name,
@@ -908,6 +967,16 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
      return BuildMaxUnpooling2DOperator(inst, max_unpooling_op, operands,
                                         results);
    }
    if (auto whileOp = dyn_cast<mlir::TFL::WhileOp>(inst)) {
      if (inst->getNumOperands() != inst->getNumResults()) {
        inst->emitOpError(
            "number of operands and results don't match, only canonical "
            "TFL While supported");
        return llvm::None;
      }
      return BuildWhileOperator(whileOp, operands, results);
    }

    inst->emitOpError("is not a supported TFLite op");
    return llvm::None;
  }
@@ -944,7 +1013,7 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
    // we emit op as flex.
    // if custom is enabled
    // we emit the op as custom.
    auto node_def = getTensorFlowNodeDef(inst);
    auto node_def = GetTensorFlowNodeDef(inst);
    if (!node_def) {
      return llvm::None;
    }
@@ -1043,18 +1112,16 @@ void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) {

bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) {
  std::vector<int> operand_indices;
  // TODO(b/138254427): When the bug is addressed, we'll be able to inspect
  // for the presence of a specific OpTrait using mlir::Operation, without
  // having to cast it to specific ops like below.
  // Until then, when a new RNN/LSTM op is added to TFLite and has stateful
  // tensors as operands, they will need to be added here as well.
  if (!mlir::TFL::IsStatefulOp(op, &operand_indices)) return false;
  return absl::c_find(operand_indices, operand_index) != operand_indices.end();
}

Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
    const std::string& name, Region* region) {
  bool has_input_attr = false;
  InitializeNamesFromAttribute(fn, &has_input_attr);
  if (auto fn = dyn_cast<FuncOp>(region->getParentOp())) {
    InitializeNamesFromAttribute(fn, &has_input_attr);
  }
  std::vector<BufferOffset<tflite::Tensor>> tensors;
  llvm::DenseMap<Value, int> tensor_index_map;

@@ -1086,7 +1153,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
  };

  std::vector<BufferOffset<tflite::Operator>> operators;
  auto& bb = fn.getBlocks().front();
  auto& bb = region->front();

  // Main function's arguments are first passed to `input` op so they don't
  // have associated tensor and buffer. Build FlatBuffer tensor and buffer for
@@ -1094,7 +1161,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
  for (unsigned i = 0, e = bb.getNumArguments(); i < e; ++i) {
    mlir::BlockArgument arg = bb.getArgument(i);
    std::string name;
    if (has_input_attr) name = name_mapper_.GetUniqueName(arg);
    if (has_input_attr) name = std::string(name_mapper_.GetUniqueName(arg));
    if (name.empty()) name = absl::StrCat("arg", i);
    if (!build_tensor_and_buffer(arg, name)) return llvm::None;
  }
@@ -1146,7 +1213,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
  return tflite::CreateSubGraph(
      builder_, builder_.CreateVector(tensors), builder_.CreateVector(inputs),
      builder_.CreateVector(outputs), builder_.CreateVector(operators),
      /*name=*/builder_.CreateString(fn.getName().str()));
      /*name=*/builder_.CreateString(name));
}

BufferOffset<tflite::Metadata> Translator::BuildMetadata(StringRef name,
@@ -1189,35 +1256,45 @@ Optional<std::string> Translator::Translate(
}

Optional<std::string> Translator::TranslateInternal() {
  // Create a list of functions in the module with the main function being the
  // first function in the list. This is required as the first subgraph in the
  // model is the entry point for the model.
  std::vector<FuncOp> functions;
  functions.reserve(std::distance(module_.begin(), module_.end()));
  // A list of named regions in the module with the main function being the
  // first in the list. The main function is required to be first as the first
  // subgraph in the model is the entry point for the model.
  std::vector<std::pair<std::string, Region*>> named_regions;
  named_regions.reserve(std::distance(module_.begin(), module_.end()));

  int subgraph_idx = 0;
  FuncOp main_fn = module_.lookupSymbol<FuncOp>("main");
  subgraph_index_map_[main_fn.getName().str()] = subgraph_idx++;
  functions.push_back(main_fn);
  for (auto fn : module_.getOps<FuncOp>()) {
    if (fn == main_fn) continue;
  named_regions.emplace_back("main", &main_fn.getBody());
  // Walk over the module, collecting functions and while ops.
  module_.walk([&](Operation* op) {
    if (auto fn = dyn_cast<FuncOp>(op)) {
      if (fn != main_fn) {
        subgraph_index_map_[fn.getName().str()] = subgraph_idx++;
        named_regions.emplace_back(fn.getName().str(), &fn.getBody());
      }
    } else if (auto wo = dyn_cast<mlir::TFL::WhileOp>(op)) {
      std::string name = GetWhileCondName(wo);
      subgraph_index_map_[name] = subgraph_idx++;
      named_regions.emplace_back(GetWhileCondName(wo), &wo.cond());
      name = GetWhileBodyName(wo);
      subgraph_index_map_[name] = subgraph_idx++;
      named_regions.emplace_back(name, &wo.body());
    }
  });

    subgraph_index_map_[fn.getName().str()] = subgraph_idx++;
    functions.push_back(fn);
  }

  // Build subgraph for each of the functions.
  // Build subgraph for each of the named regions.
  std::vector<BufferOffset<tflite::SubGraph>> subgraphs;
  subgraphs.reserve(functions.size());
  subgraphs.reserve(named_regions.size());
  int first_failed_func = -1;
  for (int i = 0; i < functions.size(); ++i) {
    auto subgraph_or = BuildSubGraph(functions[i]);
  for (auto it : llvm::enumerate(named_regions)) {
    auto subgraph_or = BuildSubGraph(it.value().first, it.value().second);
    if (!subgraph_or) {
      if (first_failed_func == -1)
        // Record the index of the first function that cannot be converted.
        // Record the index of the first region that cannot be converted.
        // Keep looping through all subgraphs in the module to make sure that
        // we collect the list of missing ops from the entire module.
        first_failed_func = i;
        first_failed_func = it.index();
    } else {
      subgraphs.push_back(*subgraph_or);
    }
@@ -1238,9 +1315,10 @@ Optional<std::string> Translator::TranslateInternal() {
           "-emit-custom-ops flag): " +
           failed_custom_ops_list;

    return functions[first_failed_func].emitError("failed while converting: '")
               << functions[first_failed_func].getName() << "\'\n"
               << err,
    auto& failed_region = named_regions[first_failed_func];
    return failed_region.second->getParentOp()->emitError()
               << "failed while converting: '" << failed_region.first
               << "': " << err,
           llvm::None;
  }
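As a concrete example of the new traversal: a module whose main function calls a helper function and contains a single tfl.while is exported with subgraph 0 = "main", followed, in module walk order, by the helper's subgraph and the while op's "$cond" and "$body" subgraphs, so the model entry point always stays at index 0.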
58  tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td (new file)
@@ -0,0 +1,58 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is the operation interface definition file for TensorFlow Lite.

#ifndef TFL_OP_INTERFACES
#define TFL_OP_INTERFACES

include "mlir/IR/OpBase.td"

//===----------------------------------------------------------------------===//
// TFL op interface for stateful operands.

def TFL_StatefulOp : OpInterface<"StatefulOpInterface"> {
  let description = [{
    Interface for ops that are stateful and need to identify stateful operands.

    Stateful operands correspond to TF's variables semantics. An op that has 1
    or more stateful operands is a stateful op.
  }];

  let methods = [
    InterfaceMethod<
      [{Returns the indices of stateful operands.}],
      "std::vector<int>", "GetStatefulOperands", (ins)
    >,
  ];
}

//===----------------------------------------------------------------------===//
// TFL op interface for output channel index.

def TFL_ChannelDimIndexInterface : OpInterface<"ChannelDimIndexInterface"> {
  let description = [{
    Interface for defining the index of the output channel dimension.
  }];

  let methods = [
    InterfaceMethod<
      [{Returns the dimension index of the output channels.}],
      "int", "GetChannelDimIndex", (ins)
    >,
  ];
}

#endif // TFL_OP_INTERFACES
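A hedged C++ sketch of how the two generated interface classes can be queried on an arbitrary operation (the class names StatefulOpInterface and ChannelDimIndexInterface come from the OpInterface definitions above; the surrounding function is hypothetical):

    #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

    // Inspect an op via the TableGen-generated interface classes.
    void InspectTflOp(mlir::Operation* op) {
      if (auto stateful = llvm::dyn_cast<mlir::TFL::StatefulOpInterface>(op)) {
        // e.g. {18, 19} for tfl.lstm, {4} for tfl.svdf, per the op defs below.
        std::vector<int> indices = stateful.GetStatefulOperands();
        (void)indices;
      }
      if (auto channel = llvm::dyn_cast<mlir::TFL::ChannelDimIndexInterface>(op)) {
        // 0 for tfl.conv_2d and tfl.fully_connected, 3 for tfl.depthwise_conv_2d.
        int dim = channel.GetChannelDimIndex();
        (void)dim;
      }
    }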
@@ -797,8 +797,7 @@ struct RemoveAdjacentReshape : public RewritePattern {
    // With
    //   %2 = "tfl.reshape"(%0, %shape1)
    rewriter.replaceOpWithNewOp<ReshapeOp>(
        {prevOp.getResult()}, op, thisOp.getType(), prevOp.getOperand(0),
        thisOp.getOperand(1));
        op, thisOp.getType(), prevOp.getOperand(0), thisOp.getOperand(1));
  }
};

@@ -1302,6 +1301,19 @@ OpFoldResult AbsOp::fold(ArrayRef<Attribute> operands) {
  return ConstFoldUnaryOp(result_type, operands[0], compute);
}

//===----------------------------------------------------------------------===//
// NegOp
//===----------------------------------------------------------------------===//

OpFoldResult NegOp::fold(ArrayRef<Attribute> operands) {
  Type result_type = getType();
  // Only constant fold for tensor of f32 is implemented.
  if (!IsF32ShapedType(result_type)) return nullptr;

  auto compute = [](APFloat value) -> APFloat { return llvm::neg(value); };
  return ConstFoldUnaryOp(result_type, operands[0], compute);
}

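For instance, folding a tfl.neg whose operand is a dense f32 constant [1.0, -2.5] produces the constant [-1.0, 2.5]; for any non-f32 result type the folder returns nullptr and the op is left in place.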
//===----------------------------------------------------------------------===//
// SinOp
//===----------------------------------------------------------------------===//
@@ -1724,10 +1736,97 @@ static LogicalResult Verify(TransposeOp op) {
  return success();
}

namespace {
struct WhileResultOperandsMatch : public OpRewritePattern<WhileOp> {
  using OpRewritePattern<WhileOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(WhileOp while_op,
                                     PatternRewriter &rewriter) const override {
    auto size = while_op.body().front().getArguments().size();
    Operation *op = while_op.getOperation();
    auto old_size = op->getNumResults();
    // No change needed as the number of operands matches the number of
    // results.
    if (size == old_size) return matchFailure();

    // Collect the new types by combining results of the old op with the
    // additional operand results.
    llvm::SmallVector<Type, 4> types;
    types.reserve(size);
    for (auto type : while_op.getResultTypes()) types.push_back(type);
    for (auto arg : while_op.body().front().getArguments().drop_front(old_size))
      types.push_back(arg.getType());
    // Collect operands.
    llvm::SmallVector<Value, 8> operands;
    operands.reserve(while_op.getNumOperands());
    for (auto operand : while_op.getOperands()) operands.push_back(operand);

    // Replace with a new While with matching operands and results.
    Operation *new_op = rewriter.insert(
        Operation::create(op->getLoc(), op->getName(), types, operands,
                          op->getAttrs(), {}, /*numRegions=*/2,
                          /*resizableOperandList=*/true));
    for (int i = 0; i < 2; ++i) new_op->getRegion(i).takeBody(op->getRegion(i));
    rewriter.replaceOp(op,
                       new_op->getResults().take_front(op->getNumResults()));

    return matchSuccess();
  }
};
}  // namespace

void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context) {
  results.insert<WhileResultOperandsMatch>(context);
}

Region &WhileOp::getLoopBody() { return body(); }

bool WhileOp::isDefinedOutsideOfLoop(Value value) {
  // TODO(jpienaar): This is overly conservative and initially disables
  // anything other than constant hoisting.
  return false;
}

LogicalResult WhileOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) {
  // TODO(jpienaar): This can be removed once the resizable trait is added.
  Operation &while_op = *this->getOperation();
  if (!while_op.hasResizableOperandsList()) return failure();
  if (ops.empty()) return success();

  // Operands to the while op.
  llvm::SmallVector<Value, 4> operands(getOperands());
  // Results that have to be returned by the body.
  llvm::SmallVector<Value, 4> results(
      body().front().getTerminator()->getOperands());
  for (auto op : ops) {
    // Move the hoisted value to just before the while.
    op->moveBefore(&while_op);

    // Each result of the hoisted op becomes an input to while, cond and body.
    for (auto result : op->getResults()) {
      operands.push_back(result);
      auto type = result.getType();
      auto arg = body().front().addArgument(type);
      // A loop-invariant value passes through the body function unchanged.
      result.replaceAllUsesWith(arg);
      results.push_back(arg);
      // Operand types match for body and cond. The value is hoisted out of
      // the body and so is not necessarily used in cond. This could be
      // expanded to consider common usage across cond and body.
      cond().front().addArgument(type);
    }
  }

  body().front().getTerminator()->setOperands(results);
  while_op.setOperands(operands);
  return success();
}
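Concretely, hoisting one loop-invariant op out of the body does four things: the op is moved to just before the while, each of its results becomes an extra while operand, a matching block argument is appended to both the cond and body regions, and all former uses inside the body are rewired to the new body argument, which the terminator also returns so the value passes through every iteration unchanged.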

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.cc.inc"
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
@@ -27,7 +27,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/Support/Functional.h"  // TF:llvm-project
#include "mlir/Support/LLVM.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
#include "mlir/Transforms/LoopLikeInterface.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/lite/schema/schema_generated.h"

@@ -44,6 +44,7 @@ class TensorFlowLiteDialect : public Dialect {
                             Location loc) override;
};

#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.h.inc"
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc"
@@ -19,6 +19,8 @@ limitations under the License.
#define TFL_OPS

include "mlir/IR/OpBase.td"
include "mlir/Transforms/LoopLikeInterface.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td"
include "tensorflow/compiler/mlir/lite/quantization/quantization.td"

def TFL_Dialect : Dialect {
@@ -248,16 +250,6 @@ def TFL_ComparisonBinaryBuilder : OpBuilder<
    buildComparisonBinOp(builder, result, lhs, rhs);
  }]>;

//===----------------------------------------------------------------------===//
// TFL native op trait for stateful operands and channel indices.

class StatefulOperands<list<int> operands>
  : ParamNativeOpTrait<"TFL::StatefulOperands", StrJoinInt<operands>.result>;

class ChannelDimIndex<int index>
  : ParamNativeOpTrait<"TFL::ChannelDimIndex", !cast<string>(index)>;

//===----------------------------------------------------------------------===//
// TFL op base class.
//===----------------------------------------------------------------------===//
@@ -285,7 +277,7 @@ class TFL_Op<string mnemonic, list<OpTrait> traits = []> :

class TFL_ConvOp<string mnemonic, string opSummary, int index> :
    TFL_Op<mnemonic, [NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
                      ChannelDimIndex<index>, AffineOpCoefficient<index, 1>]> {
                      TFL_ChannelDimIndexInterface, AffineOpCoefficient<index, 1>]> {
  let summary = opSummary # " operator";

  let description = [{
@@ -335,7 +327,7 @@ an output element, this operation computes \\(y = |x|\\).
  let hasFolder = 1;
}

def TFL_AddOp : TFL_Op<"add", [Broadcastable, NoSideEffect, Commutative]> {
def TFL_AddOp : TFL_Op<"add", [ResultsBroadcastableShape, NoSideEffect, Commutative]> {
  let summary = "Addition operator";

  let description = [{
@@ -486,8 +478,7 @@ def TFL_ArgMaxOp : TFL_Op<"arg_max", [NoSideEffect]> {
  }];

  let arguments = (
    // TODO: Add support for uint8.
    ins TensorOf<[F32, I32, I8]>:$input,
    ins TensorOf<[F32, I32, I8, TFL_Uint8, QI8, QUI8]>:$input,
    TFL_I32OrI64Tensor:$dim
  );

@@ -515,8 +506,7 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> {
  }];

  let arguments = (
    // TODO(pkanwar): Add support for uint8.
    ins TensorOf<[F32, I32, I8]>:$input,
    ins TensorOf<[F32, I32, I8, TFL_Uint8, QI8, QUI8]>:$input,
    TFL_I32OrI64Tensor:$dim
  );

@@ -617,7 +607,12 @@ def TFL_ExternalConstOp : Op<TFL_Dialect, "external_const", [NoSideEffect]> {
  let results = (outs AnyTensor:$output);
}

def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0>;
def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0> {
  let extraClassDeclaration = [{
    // ChannelDimIndexInterface:
    int GetChannelDimIndex() { return 0; }
  }];
}

def TFL_CosOp: TFL_Op<"cos", [
    NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
@@ -637,6 +632,11 @@ def TFL_CosOp: TFL_Op<"cos", [
def TFL_DepthwiseConv2DOp :
    TFL_ConvOp<"depthwise_conv_2d", "Depthwise-separable convolution", 3> {
  let arguments = !con(TFL_Conv2DOp.arguments, (ins I32Attr:$depth_multiplier));

  let extraClassDeclaration = [{
    // ChannelDimIndexInterface:
    int GetChannelDimIndex() { return 3; }
  }];
}

def TFL_FCWO_Default : StrEnumAttrCase<"DEFAULT">;
@@ -650,7 +650,8 @@ def TFL_FullyConnectedOptionsWeightFormatAttr :

// TODO(jpienaar): Update post discussion on semantics of FC OP.
def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
    NoSideEffect, AccumulatorUniformScale<2, 0, 1>, ChannelDimIndex<0>,
    NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
    TFL_ChannelDimIndexInterface,
    AffineOpCoefficient<-1, 1>]> {
  let summary = "Fully connected op";

@@ -672,6 +673,11 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
  let verifier = [{ return Verify(*this); }];

  let hasOptions = 1;

  let extraClassDeclaration = [{
    // ChannelDimIndexInterface:
    int GetChannelDimIndex() { return 0; }
  }];
}

def TFL_GatherOp : TFL_Op<"gather", [
@@ -679,7 +685,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
    SameOperandsAndResultsScale,
    TFL_OperandHasAtleastRank<0, 1>,
    PredOpTrait<"params and output must have same element type",
                TCresVTEtIsSameAsOp<0, 0>>
                TFL_TCresVTEtIsSameAsOp<0, 0>>
  ]> {
  let summary = "Gather operator";

@@ -688,7 +694,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
  }];

  let arguments = (ins
    TensorOf<[F32, I1, I8, I32, I64, TFL_Str, QI8, QUI8]>:$params,
    TensorOf<[F32, I1, I8, I32, I64, TFL_Str, QI8, QUI8, QI16]>:$params,
    TensorOf<[I32, I64]>:$indices,
    I32Attr:$axis
  );
@@ -701,7 +707,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
  ];

  let results = (outs
    TensorOf<[F32, I1, I8, I32, I64, TFL_Str, QI8, QUI8]>:$output
    TensorOf<[F32, I1, I8, I32, I64, TFL_Str, QI8, QUI8, QI16]>:$output
  );

  let hasOptions = 1;
@@ -724,9 +730,9 @@ def TFL_GatherNdOp : TFL_Op<"gather_nd", [NoSideEffect]> {
  );
}

// Same type check of lhs and rhs is handled by the Broadcastable trait.
// Same type check of lhs and rhs is handled by the ResultsBroadcastableShape trait.
def TFL_LessEqualOp : TFL_Op<"less_equal", [
    Broadcastable, NoSideEffect, NoQuantizableResult]> {
    ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
  let summary = "Less_equal operator";

  let description = [{
@@ -782,7 +788,7 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag
}

def TFL_GreaterEqualOp : TFL_Op<"greater_equal", [
    Broadcastable, NoSideEffect, NoQuantizableResult]> {
    ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
  let summary = "Greater_equal operator";

  let description = [{
@@ -943,7 +949,7 @@ larger than 0.
}

def TFL_NotEqualOp : TFL_Op<"not_equal", [
    Broadcastable, Commutative, NoSideEffect, NoQuantizableResult]> {
    ResultsBroadcastableShape, Commutative, NoSideEffect, NoQuantizableResult]> {
  let summary = "Not_equal operator";

  let description = [{
@@ -970,7 +976,7 @@ def TFL_NotEqualOp : TFL_Op<"not_equal", [
  let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
}

def TFL_DivOp : TFL_Op<"div", [Broadcastable, NoSideEffect]> {
def TFL_DivOp : TFL_Op<"div", [ResultsBroadcastableShape, NoSideEffect]> {
  let summary = "Division operator";

  let description = [{
@@ -1029,7 +1035,7 @@ def TFL_EmbeddingLookupOp: TFL_Op<"embedding_lookup",
  let results = (outs TensorOf<[F32, I8, TFL_Uint8]>:$output);
}

def TFL_EqualOp: TFL_Op<"equal", [Commutative, Broadcastable,
def TFL_EqualOp: TFL_Op<"equal", [Commutative, ResultsBroadcastableShape,
    NoQuantizableResult,
    PredOpTrait<"Operands have same value type", TCopVTEtIsSameAs<0, 1>>]> {
  let summary = "Equal operator";
@@ -1063,7 +1069,8 @@ def TFL_ExpOp: TFL_Op<"exp", [NoSideEffect, SameOperandsAndResultType]> {
  let hasOptions = 0b1;
}

def TFL_ExpandDimsOp: TFL_Op<"expand_dims", [NoSideEffect]> {
def TFL_ExpandDimsOp: TFL_Op<"expand_dims", [
    NoSideEffect, SameOperandsAndResultsScale]> {
  let summary = "Inserts a dimension of 1 into a tensor's shape.";

  let description = [{
@@ -1173,7 +1180,7 @@ def TFL_FloorOp: TFL_Op<"floor", [NoSideEffect, SameOperandsAndResultType]> {
}

def TFL_FloorDivOp : TFL_Op<"floor_div", [
    Broadcastable, NoSideEffect, BinaryOpSameElementTypeConstraint]> {
    ResultsBroadcastableShape, NoSideEffect, BinaryOpSameElementTypeConstraint]> {
  let summary = "Floor div operator";

  let description = [{
@@ -1192,7 +1199,7 @@ def TFL_FloorDivOp : TFL_Op<"floor_div", [
  let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
}

def TFL_FloorModOp : TFL_Op<"floor_mod", [Broadcastable, NoSideEffect]> {
def TFL_FloorModOp : TFL_Op<"floor_mod", [ResultsBroadcastableShape, NoSideEffect]> {
  let summary = "Division remainder";

  let description = [{
@@ -1209,7 +1216,7 @@ def TFL_FloorModOp : TFL_Op<"floor_mod", [ResultsBroadcastableShape, NoSideEffect]> {
}

def TFL_GreaterOp : TFL_Op<"greater", [
    Broadcastable, NoSideEffect, NoQuantizableResult]> {
    ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
  let summary = "Greater operator";

  let description = [{
@@ -1291,7 +1298,7 @@ def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [NoSideEffect, SameOperandsAndResultTy
}

def TFL_LessOp : TFL_Op<"less", [
    Broadcastable, NoSideEffect, NoQuantizableResult]> {
    ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
  let summary = "Less operator";

  let description = [{
@@ -1516,7 +1523,7 @@ def TFL_MaxUnpooling2DOp :
}

def TFL_MaximumOp : TFL_Op<"maximum", [
    Broadcastable, NoSideEffect, Commutative, SameOperandsAndResultsScale,
    ResultsBroadcastableShape, NoSideEffect, Commutative, SameOperandsAndResultsScale,
    TFL_OperandHasRankLessThan<0, 4>, TFL_OperandHasRankLessThan<1, 4>]> {
  let summary = "Max operator";
  let description = [{
@@ -1655,7 +1662,8 @@ def TFL_SumOp: TFL_Op<"sum", [NoSideEffect]> {
  let customOption = "ReducerOptions";
}

def TFL_ReduceMinOp: TFL_Op<"reduce_min", [NoSideEffect]> {
def TFL_ReduceMinOp: TFL_Op<"reduce_min", [
    NoSideEffect, SameOperandsAndResultsScale]> {
  let summary = "Min-reduction operator";

  let description = [{
@@ -1674,7 +1682,8 @@ def TFL_ReduceMinOp: TFL_Op<"reduce_min", [
  let customOption = "ReducerOptions";
}

def TFL_ReduceMaxOp: TFL_Op<"reduce_max", [NoSideEffect]> {
def TFL_ReduceMaxOp: TFL_Op<"reduce_max", [
    NoSideEffect, SameOperandsAndResultsScale]> {
  let summary = "Max-reduction operator";

  let description = [{
@@ -1713,7 +1722,7 @@ def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [NoSideEffect]> {
}

def TFL_MinimumOp : TFL_Op<"minimum", [
    Broadcastable, NoSideEffect, Commutative, SameOperandsAndResultsScale,
    ResultsBroadcastableShape, NoSideEffect, Commutative, SameOperandsAndResultsScale,
    TFL_OperandHasRankLessThan<0, 4>, TFL_OperandHasRankLessThan<1, 4>]> {
  let summary = "Min operator";
  let description = [{
@@ -1734,7 +1743,7 @@ def TFL_MinimumOp : TFL_Op<"minimum", [
  let hasOptions = 0;
}

def TFL_MulOp : TFL_Op<"mul", [Broadcastable, NoSideEffect, Commutative]> {
def TFL_MulOp : TFL_Op<"mul", [ResultsBroadcastableShape, NoSideEffect, Commutative]> {
  let summary = "Multiplication operator";

  let description = [{
@@ -1771,6 +1780,8 @@ def TFL_NegOp: TFL_Op<"neg", [NoSideEffect, SameOperandsAndResultType]> {
  let results = (outs AnyTensor:$y);

  let hasOptions = 0b1;

  let hasFolder = 1;
}
def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> {
@@ -1804,14 +1815,14 @@ def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> {
  }];

  let arguments = (ins
    Variadic<TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8]>>:$values,
    Variadic<TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8, QI16]>>:$values,

    I32Attr:$values_count,
    I32Attr:$axis
  );

  let results = (outs
    TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8]>:$output
    TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8, QI16]>:$output
  );

  let verifier = [{ return Verify(*this); }];
@@ -1909,7 +1920,7 @@ def TFL_PadV2Op : TFL_Op<"padv2", [
  let hasOptions = 1;
}

def TFL_PowOp : TFL_Op<"pow", [Broadcastable, NoSideEffect, NoQuantizableResult]> {
def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
  let summary = "Power operator";

  let description = [{
@@ -2278,7 +2289,7 @@ def TFL_SquareOp: TFL_Op<"square", [
  let hasFolder = 1;
}

def TFL_SubOp : TFL_Op<"sub", [Broadcastable, NoSideEffect]> {
def TFL_SubOp : TFL_Op<"sub", [ResultsBroadcastableShape, NoSideEffect]> {
  let summary = "Subtraction operator";

  let description = [{
@@ -2306,7 +2317,7 @@ def TFL_SubOp : TFL_Op<"sub", [Broadcastable, NoSideEffect]> {
// TODO(jpienaar): Expand the kernel implementation to support all types besides
// I32 and F32.
def TFL_SquaredDifferenceOp : TFL_Op<"squared_difference", [
    Broadcastable, NoSideEffect, NoQuantizableResult]> {
    ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
  let summary = "Squared difference operator";

  let description = [{
@@ -2345,9 +2356,9 @@ def TFL_TanhOp: TFL_Op<"tanh", [
  let results = (outs TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$y);
}

def TFL_TileOp: TFL_Op<"tile", [NoSideEffect,
def TFL_TileOp: TFL_Op<"tile", [NoSideEffect, SameOperandsAndResultsScale,
    PredOpTrait<"resultant element type needs to match first operand type",
                TCresVTEtIsSameAsOp<0,0>>]> {
                TFL_TCresVTEtIsSameAsOp<0,0>>]> {
  let summary = "Tile operator.";
  let description = [{
    Constructs a tensor by tiling a given tensor.
@@ -2360,10 +2371,11 @@ def TFL_TileOp: TFL_Op<"tile", [NoSideEffect,
  }];

  let arguments = (ins
    TensorOf<[F32, I1, I32, I64, TFL_Uint8]>:$input,
    TensorOf<[F32, I1, I32, I64, TFL_Uint8, QUI8]>:$input,
    TFL_I32OrI64Tensor:$multiples);

  let results = (outs TensorOf<[F32, I1, I32, I64, TFL_Uint8]>:$output);
  let results = (outs
    TensorOf<[F32, I1, I32, I64, TFL_Uint8, QUI8]>:$output);

  let hasOptions = 0;
}
@@ -2373,7 +2385,7 @@ def TFL_TileOp: TFL_Op<"tile", [NoSideEffect,
// TODO(jpienaar): Check that k is less than or equal to the internal dimension
def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>,
    PredOpTrait<"result and input element type match",
                TCresVTEtIsSameAsOp<0,0>>]> {
                TCresVTEtIsSameAsOp<0,0>>, SameOperandsAndResultsScale]> {
  let summary = "TopK operator";

  let description = [{
@@ -2383,11 +2395,11 @@ def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>,
  }];

  let arguments = (ins
    TensorOf<[F32, I8, I32, I64, TFL_Uint8]>:$input,
    TensorOf<[F32, I8, I32, I64, TFL_Uint8, QI8, QUI8]>:$input,
    I32Tensor:$k);

  let results = (outs
    AnyTensor:$values,
    TensorOf<[F32, I8, I32, I64, TFL_Uint8, QI8, QUI8]>:$values,
    I32Tensor:$indices);

  let builders = [OpBuilder<"Builder *builder, OperationState &result, "
@@ -2426,7 +2438,7 @@ def TFL_TransposeOp : TFL_Op<"transpose",
  let hasFolder = 1;
}

def TFL_UnpackOp : TFL_Op<"unpack", [NoSideEffect]> {
def TFL_UnpackOp : TFL_Op<"unpack", [NoSideEffect, SameOperandsAndResultsScale]> {
  let summary = "Unpacks a tensor along a dimension into multiple tensors";

  let description = [{
@@ -2642,7 +2654,9 @@ def TFL_ResizeBilinearOp: TFL_Op<"resize_bilinear", [
    // TODO(ycling): Support quantized types.
    TensorOf<[F32, I32, QI8, QUI8]>:$input,
    TensorOf<[I32]>:$size,
    BoolAttr:$align_corners);
    BoolAttr:$align_corners,
    DefaultValuedAttr<BoolAttr, "false">:$half_pixel_centers
  );

  let results = (outs
    TensorOf<[F32, QI8, QUI8]>:$output
@@ -2751,12 +2765,11 @@ def TFL_CastOp : TFL_Op<"cast", [
    Casts input from input type to output type.
  }];

  // TODO(b/135538711): Add complex types here.
  let arguments = (ins
    TensorOf<[F32, I1, I32, I64, TFL_Quint8, TFL_Uint8]>:$input
    TensorOf<[F32, I1, I32, I64, TFL_Quint8, TFL_Uint8, Complex<F<32>>]>:$input
  );

  let results = (outs TensorOf<[F32, I1, I32, I64]>:$output);
  let results = (outs TensorOf<[F32, I1, I32, I64, Complex<F<32>>]>:$output);

  // TFLite's cast op does not utilize CastOptions, instead derives types
  // from the TfLiteTensors.
@@ -2856,7 +2869,9 @@ def TFL_FakeQuantOp : TFL_Op<"fake_quant", [NoSideEffect]> {
  let arguments = (
    ins AnyTensor:$input,
    // The expected [min, max] range of values.
    MinMaxAttr:$minmax,
    F32Attr:$min,
    F32Attr:$max,

    // The bitwidth of the quantization; between 2 and 16, inclusive.
    I32Attr:$num_bits,
    // Quantization range starts from 0 or 1; starts from 1 if true.
@@ -2865,6 +2880,8 @@ def TFL_FakeQuantOp : TFL_Op<"fake_quant", [NoSideEffect]> {
  let results = (outs AnyTensor:$output);

  let hasCanonicalizer = 0b1;

  let hasOptions = 1;
}

def TFL_QConstOp : Op<TFL_Dialect, "pseudo_qconst", [
@@ -2911,6 +2928,20 @@ def TFL_QuantizeOp: TFL_Op<"quantize", [
  let results = (outs AnyTensor:$output);
}

def TFL_DensifyOp: TFL_Op<"densify", [NoSideEffect,
                                      SameOperandsAndResultType,
                                      NoQuantizableResult]> {
  let summary = "Densify operator";

  let description = [{
    Converts sparse tensor to dense format.
  }];

  let arguments = (ins AnyTensor:$input);

  let results = (outs AnyTensor:$output);
}

//===----------------------------------------------------------------------===//
// LSTM Ops
//===----------------------------------------------------------------------===//
@@ -3000,7 +3031,7 @@ def TFL_LSTMOp :
    LstmOptionalPeepholeWeightConstraint,
    LstmProjectionWeightBiasConstraint,
    LstmResultConstraint,
    StatefulOperands<[18, 19]>]> {
    TFL_StatefulOp]> {
  let summary = "The full lstm operator";

  let description = [{
@@ -3084,6 +3115,11 @@ Ba et al. “Layer Normalization”
  let hasOptions = 1;

  let verifier = [{ return Verify(*this); }];

  let extraClassDeclaration = [{
    // StatefulOpInterface:
    std::vector<int> GetStatefulOperands() { return {18, 19}; }
  }];
}

// UnidirectionalSequenceLstm op.
@@ -3095,7 +3131,7 @@ def TFL_UnidirectionalSequenceLSTMOp :
    LstmOptionalPeepholeWeightConstraint,
    LstmProjectionWeightBiasConstraint,
    LstmResultConstraint,
    StatefulOperands<[18, 19]>]> {
    TFL_StatefulOp]> {
  let summary = "Unidirectional sequence lstm operator";

  let description = [{
@@ -3164,6 +3200,11 @@ def TFL_UnidirectionalSequenceLSTMOp :
  let hasOptions = 1;

  let verifier = [{ return Verify(*this); }];

  let extraClassDeclaration = [{
    // StatefulOpInterface:
    std::vector<int> GetStatefulOperands() { return {18, 19}; }
  }];
}

def RnnResultConstraint : PredOpTrait<
@@ -3173,7 +3214,7 @@ def RnnResultConstraint : PredOpTrait<
// UnidirectionalSequenceRNN op.
def TFL_UnidirectionalSequenceRNNOp :
  TFL_Op<"unidirectional_sequence_rnn",
         [RnnResultConstraint, StatefulOperands<[4]>]> {
         [RnnResultConstraint, TFL_StatefulOp]> {

  let summary = "Unidirectional sequence rnn operator";

@@ -3217,6 +3258,11 @@ def TFL_UnidirectionalSequenceRNNOp :
  let customOption = "SequenceRNNOptions";

  let verifier = [{ return Verify(*this); }];

  let extraClassDeclaration = [{
    // StatefulOpInterface:
    std::vector<int> GetStatefulOperands() { return {4}; }
  }];
}

def TFL_WhereOp : TFL_Op<"where", [NoSideEffect]> {
@@ -3268,7 +3314,7 @@ def SVDFResultConstraint: PredOpTrait<
// SVDF op.
def TFL_SVDFOp :
  TFL_Op<"svdf",
         [SVDFResultConstraint, StatefulOperands<[4]>]> {
         [SVDFResultConstraint, TFL_StatefulOp]> {

  let summary = "Single value decomposition filter operator";

@@ -3304,6 +3350,72 @@ def TFL_SVDFOp :
  let hasOptions = 1;

  let verifier = [{ return Verify(*this); }];

  let extraClassDeclaration = [{
    // StatefulOpInterface:
    std::vector<int> GetStatefulOperands() { return {4}; }
  }];
}

def TFL_SegmentSumOp: TFL_Op<"segment_sum", [NoSideEffect]> {
  let summary = "SegmentSum operator";

  let description = [{
    Computes the sum along segments of a tensor.
  }];

  let arguments = (ins
    TensorOf<[F32, I32]>:$data,
    I32Tensor:$segment_ids
  );
  let results = (outs TensorOf<[F32, I32]>:$output);
}

def TFL_YieldOp : Op<TFL_Dialect, "yield", [Terminator]> {
  let summary = "Yield operation";
  let description = [{
    The "yield" operation represents a return operation within the conditional
    and body of structured control flow (e.g., while). The operation takes a
    variable number of operands and produces no results. The operand number
    and types must match the signature of the region that contains the
    operation.
  }];

  let arguments = (ins Variadic<AnyType>:$operands);
}

def TFL_WhileOp : Op<TFL_Dialect, "while", [
    DeclareOpInterfaceMethods<LoopLikeOpInterface>,
    SingleBlockImplicitTerminator<"YieldOp">,
    // Make isolated from above to force values through operands to simplify
    // exporting to subgraphs.
    IsolatedFromAbove]> {
  let summary = [{While loop}];

  let description = [{
    output = input; while (cond(output)) { output = body(output) }

    A while loop where all values are passed through arguments with no
    implicit capture.

    input: A list of input tensors whose types are T.
    output: A list of output tensors whose types are T.
    cond: A region that takes 'input' and returns a boolean scalar tensor.
    body: A region that takes a list of tensors and returns another
      list of tensors. Both lists have the same types.
  }];

  let arguments = (ins
    Variadic<AnyTensor>:$input,

    // Used to map StatelessWhile and While ops defined in TensorFlow to a
    // common op.
    DefaultValuedAttr<BoolAttr, "false">:$is_stateless
  );
  let results = (outs Variadic<AnyTensor>:$output);

  let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);

  let hasCanonicalizer = 1;
}

#endif // TFL_OPS
@@ -1,67 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file defines the op traits used in the MLIR TensorFlow Lite dialect.

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_

#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h"  // TF:llvm-project

namespace mlir {
namespace OpTrait {
namespace TFL {

// The trait to specify that the specified operands of the TFL op are stateful.
// This is used as a trait like this:
//
//   class LSTMOp
//       : public Op<LSTMOp, OpTrait::TFL::StatefulOperands<18, 19>::Impl> {
//
template <int... Operands>
class StatefulOperands {
 public:
  template <typename ConcreteType>
  class Impl
      : public TraitBase<ConcreteType, StatefulOperands<Operands...>::Impl> {
   public:
    static std::vector<int> GetStatefulOperands() {
      return std::vector<int>({Operands...});
    }
  };
};

// The trait to specify the channel dimension index of the input (first operand)
// of an affine TFL op (Conv2D, DepthwiseConv2D, FullyConnected).
//
//   class Conv2DOp
//       : public Op<Conv2DOp, OpTrait::TFL::ChannelDimIndex<0>::Impl> {
//
template <int Index>
class ChannelDimIndex {
 public:
  template <typename ConcreteType>
  class Impl : public TraitBase<ConcreteType, ChannelDimIndex<Index>::Impl> {
   public:
    static int GetChannelDimIndex() { return Index; }
  };
};

}  // namespace TFL
}  // namespace OpTrait
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
@@ -41,13 +41,20 @@ limitations under the License.
#include "tensorflow/lite/delegates/flex/delegate.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/optional_debug_tools.h"

using llvm::cl::desc;
using llvm::cl::init;
using llvm::cl::opt;

// NOLINTNEXTLINE
static opt<std::string> inputFileName(llvm::cl::Positional,
                                      llvm::cl::desc("<input file>"),
                                      llvm::cl::init("-"));
static opt<std::string> input_filename(llvm::cl::Positional,
                                       desc("<input file>"), init("-"));

// NOLINTNEXTLINE
static opt<bool> dump_state("dump-interpreter-state",
                            desc("dump interpreter state post execution"),
                            init(false));

// TODO(jpienaar): Move these functions to some debug utils.
static std::string TfLiteTensorDimString(const TfLiteTensor& tensor) {
@@ -82,9 +89,9 @@ int main(int argc, char** argv) {
  llvm::InitLLVM y(argc, argv);
  llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR TFLite runner\n");

  auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(inputFileName.c_str());
  auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(input_filename.c_str());
  if (std::error_code error = file_or_err.getError()) {
    LOG(ERROR) << argv[0] << ": could not open input file '" << inputFileName
    LOG(ERROR) << argv[0] << ": could not open input file '" << input_filename
               << "': " << error.message() << "\n";
    return 1;
  }
@@ -133,5 +140,7 @@ int main(int argc, char** argv) {
                                TfLiteTensorString(out).c_str());
  }

  if (dump_state) tflite::PrintInterpreterState(interpreter.get());

  return 0;
}
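With this change the runner gains a -dump-interpreter-state flag: when set, tflite::PrintInterpreterState is called on the interpreter after the outputs are printed.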
@@ -122,7 +122,7 @@ static void EmitOptionBuilders(const RecordKeeper &record_keeper,
      os << formatv(
          "    auto {0} = Convert{1}ForOptionWriter(op.{0}(), fbb);\n",
          val.getName(), record->getClasses()[0]->getName());
      options.push_back(val.getName());
      options.push_back(std::string(val.getName()));
    }
  }
}
@@ -71,18 +71,17 @@ cc_library(
        "quantization_utils.cc",
    ],
    hdrs = [
        "quantization_traits.h",
        "quantization_utils.h",
    ],
    deps = [
        "//tensorflow/core:lib_proto_parsing",
        "@com_google_absl//absl/memory",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Support",
        # TODO(fengliuai): remove this dependence.
        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
        "//tensorflow/core:lib_proto_parsing",
    ],
)
@@ -206,10 +206,17 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
std::unique_ptr<OpPassBase<FuncOp>>
CreateImportQuantStatsPassForTFControlDialect(const std::string &stats_str) {
  auto get_name_func = [](Operation *op) {
    if (auto name = op->getAttrOfType<StringAttr>("name"))
      return name.getValue();
    else
      return llvm::StringRef("");
    Location loc = op->getLoc();
    if (auto name = loc.dyn_cast<NameLoc>()) {
      return name.getName().strref();
    } else if (auto fused_name = loc.dyn_cast<FusedLoc>()) {
      for (auto sub_loc : fused_name.getLocations()) {
        if (auto named_sub_loc = sub_loc.dyn_cast<NameLoc>()) {
          return named_sub_loc.getName().strref();
        }
      }
    }
    return llvm::StringRef("");
  };

  return CreateImportQuantStatsPass(get_name_func, stats_str);
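In practice this means an op imported from the TF control dialect, whose TF node name survives as a NameLoc (possibly wrapped inside a FusedLoc), still resolves to that name for stats matching even though it carries no "name" attribute; ops with neither location form fall back to the empty string as before.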
@@ -12,6 +12,7 @@ package_group(
    includes = ["//third_party/mlir:subpackages"],
    packages = [
        "//learning/brain/experimental/mlir/...",
        "//tensorflow/compiler/mlir/lite/...",
        "//tensorflow/lite/...",
    ],
)
@@ -23,7 +24,6 @@ cc_library(
    ],
    hdrs = [
        "quantize_model.h",
        "//tensorflow/compiler/mlir/lite:transforms/passes.h",
    ],
    deps = [
        "//tensorflow/compiler/mlir/lite:common",
@@ -42,6 +42,24 @@ cc_library(
    ],
)

cc_library(
    name = "tfl_to_std",
    srcs = [
        "tfl_to_std.cc",
    ],
    hdrs = [
        "tfl_to_std.h",
    ],
    deps = [
        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:QuantOps",
    ],
)

# Binary to apply quantization on the annotated files.
tf_cc_binary(
    name = "tfl_quantizer",
@@ -73,19 +73,19 @@ TfLiteStatus QuantizeModel(

  // Apply quantization passes
  PassManager pm(module->getContext());
  TFL::QuantizationSpecs pass_config;
  pass_config.inference_type = tensorflow::DT_QINT8;
  pass_config.post_training_quantization = true;
  TFL::QuantizationSpecs quant_specs;
  quant_specs.inference_type = tensorflow::DT_QINT8;
  quant_specs.post_training_quantization = true;

  bool emit_adaptor = false;
  auto input_tf_type = tflite::TflTypeToTfType(input_type);
  if (input_tf_type == tensorflow::DT_FLOAT) {
    emit_adaptor = true;
  } else if (input_tf_type == tensorflow::DT_UINT8) {
    pass_config.inference_type = tensorflow::DT_QUINT8;
    quant_specs.inference_type = tensorflow::DT_QUINT8;
  }

  pm.addPass(TFL::CreatePrepareQuantizePass(pass_config));
  pm.addPass(TFL::CreatePrepareQuantizePass(quant_specs));
  pm.addPass(TFL::CreateQuantizePass());
  pm.addPass(TFL::CreatePostQuantizePass(emit_adaptor));
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h"
|
||||
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
||||
void ConvertTFLQuantOpsToMlirQuantOps(FuncOp func) {
|
||||
OpBuilder b(func);
|
||||
func.walk([&](Operation* op) {
|
||||
b.setInsertionPoint(op);
|
||||
if (auto dq = llvm::dyn_cast<DequantizeOp>(op)) {
|
||||
auto dcast = b.create<quant::DequantizeCastOp>(
|
||||
dq.getLoc(), dq.output().getType(), dq.input());
|
||||
dq.output().replaceAllUsesWith(dcast);
|
||||
dq.erase();
|
||||
} else if (auto q = llvm::dyn_cast<QuantizeOp>(op)) {
|
||||
auto qcast = b.create<quant::QuantizeCastOp>(
|
||||
q.getLoc(), q.output().getType(), q.input());
|
||||
q.output().replaceAllUsesWith(qcast);
|
||||
q.erase();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void ConvertMlirQuantOpsToTFLQuantOps(FuncOp func) {
|
||||
OpBuilder b(func);
|
||||
func.walk([&](Operation* op) {
|
||||
b.setInsertionPoint(op);
|
||||
if (auto dq = llvm::dyn_cast<quant::DequantizeCastOp>(op)) {
|
||||
auto dcast = b.create<DequantizeOp>(dq.getLoc(), dq.getResult().getType(),
|
||||
dq.arg());
|
||||
dq.getResult().replaceAllUsesWith(dcast);
|
||||
dq.erase();
|
||||
} else if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(op)) {
|
||||
auto out_type = q.getResult().getType();
|
||||
auto qcast = b.create<QuantizeOp>(q.getLoc(), out_type, q.arg(),
|
||||
TypeAttr::get(out_type));
|
||||
q.getResult().replaceAllUsesWith(qcast);
|
||||
q.erase();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace mlir
|
34
tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h
Normal file
@ -0,0 +1,34 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_

#include "mlir/IR/Function.h"  // TF:llvm-project

namespace mlir {
namespace TFL {

// Converts all the tfl.quantize/tfl.dequantize ops in the function to the
// corresponding ops in the mlir.quant dialect.
void ConvertTFLQuantOpsToMlirQuantOps(FuncOp func);

// Converts all the mlir.quant dialect ops in the function back to the
// tfl.quantize/tfl.dequantize ops.
void ConvertMlirQuantOpsToTFLQuantOps(FuncOp func);

}  // namespace TFL
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_
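Note: the two functions form a round-trip so that dialect-agnostic passes can run between them. A minimal usage sketch (the middle step is an assumption, not part of this commit):

// Rewrite tfl.quantize/tfl.dequantize into quant.qcast/quant.dcast, run
// passes that only understand the mlir.quant dialect, then rewrite back.
void RunGenericQuantPasses(mlir::FuncOp func) {
  mlir::TFL::ConvertTFLQuantOpsToMlirQuantOps(func);
  // ... dialect-agnostic quantization passes would run here ...
  mlir::TFL::ConvertMlirQuantOpsToTFLQuantOps(func);
}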
@ -22,21 +22,6 @@ limitations under the License.
include "mlir/IR/OpBase.td"
include "mlir/Dialect/QuantOps/QuantPredicates.td"


//===----------------------------------------------------------------------===//
// Min-max range pair definitions.
//===----------------------------------------------------------------------===//

// A pair of floating point values which defines the min and max of a value
// range for quantization. The attribute is allowed to be empty or
// have 2 elements.
def MinMaxAttr : Attr<Or<[CPred<"$_self.cast<ArrayAttr>().size() == 0">,
                          CPred<"$_self.cast<ArrayAttr>().size() == 2">]>,
                      "min-max range pair"> {
  let storageType = [{ ArrayAttr }];
  let returnType = [{ ArrayRef<Attribute> }];
}

//===----------------------------------------------------------------------===//
// QuantizedType definitions.
//===----------------------------------------------------------------------===//
@ -23,6 +23,8 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
#include "mlir/IR/Attributes.h"  // TF:llvm-project
@ -34,14 +36,14 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/IR/Value.h"  // TF:llvm-project
#include "mlir/Support/LLVM.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/core/platform/logging.h"

#define DEBUG_TYPE "quantization-driver"

namespace mlir {
namespace TFL {
namespace quant {
namespace {
static bool EmptyParams(QuantParams p) { return p == quant::QuantizedType(); }

@ -282,6 +284,37 @@ class QuantizationDriver {
    cached.first->second = InitializeState(op, index, res, /*as_result=*/true);
  }

  void DumpStates(Operation *current_op) {
    if (current_op) {
      llvm::errs() << "\n\n\n" << current_op->getName() << "\n";
    }
    fn_.walk([&](Operation *op) {
      if (llvm::isa<quant::QuantizeCastOp>(op) ||
          llvm::isa<quant::DequantizeCastOp>(op) || llvm::isa<ConstantOp>(op))
        return;
      if (current_op == op) llvm::errs() << "===>>>";
      llvm::errs() << op->getName() << " : (";
      for (auto i = 0; i < op->getNumOperands(); ++i) {
        if (auto params = GetOperandQuantState(op, i).params)
          params.print(llvm::errs());
        else
          op->getOperand(i).getType().cast<ShapedType>().getElementType().print(
              llvm::errs());
        llvm::errs() << ",";
      }
      llvm::errs() << ") -> (";
      for (auto i = 0; i < op->getNumResults(); ++i) {
        if (auto params = GetResultQuantState(op, i).params)
          params.print(llvm::errs());
        else
          op->getResult(i).getType().cast<ShapedType>().getElementType().print(
              llvm::errs());
        llvm::errs() << ",";
      }
      llvm::errs() << ")\n";
    });
  }

  FuncOp fn_;
  OpBuilder builder_;
  bool is_signed_;
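Note: the call site added in the PropagateParams hunk further down wraps this dump in LLVM_DEBUG, so it costs nothing in release builds. A minimal sketch of that logging pattern, assuming a standard LLVM debug build driven with -debug-only=quantization-driver:

#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "quantization-driver"

void ExampleTrace() {
  // Printed only in builds with debug logging enabled, and only when the
  // tool is run with -debug-only=quantization-driver.
  LLVM_DEBUG(llvm::errs() << "propagating quantization parameters\n");
}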
@ -351,7 +384,7 @@ int QuantizationDriver::InitializeState(Operation *op, int index, Value val,
}

bool QuantizationDriver::SetConstantResultParams(Operation *op) {
  ElementsAttr attr;
  DenseFPElementsAttr attr;
  Value res = op->getResult(0);
  if (!matchPattern(res, m_Constant(&attr))) {
    return false;
@ -458,11 +491,9 @@ void QuantizationDriver::QuantizeValue(Value value, QuantParams params,
  // This value isn't an expressed type (float), skip.
  if (!new_type) return;

  TypeAttr type_attr = TypeAttr::get(new_type);
  auto quantize =
      builder_.create<TFL::QuantizeOp>(loc, new_type, value, type_attr);
  auto dequantize = builder_.create<TFL::DequantizeOp>(loc, expressed_type,
                                                       quantize.output());
  auto quantize = builder_.create<quant::QuantizeCastOp>(loc, new_type, value);
  auto dequantize = builder_.create<quant::DequantizeCastOp>(
      loc, expressed_type, quantize.getResult());
  // `original_result` has a use to `quantize`, so this will replace that use
  // with the result of `dequantize`. Remember to reset that use afterwards.
  value.replaceAllUsesWith(dequantize);
@ -476,7 +507,7 @@ void QuantizationDriver::RequantizeOpResult(Operation *op, int index,
  Value value = op->getResult(index);
  if (state->pos == RequantizeState::ON_OUTPUT) {
    Operation *user = value.getUses().begin().getUser();
    if (llvm::isa<TFL::QuantizeOp>(user)) {
    if (llvm::isa<quant::QuantizeCastOp>(user)) {
      // The requantize op is inserted between `quantize` and `dequantize` ops.
      value = user->getResult(0);
      builder_.setInsertionPointAfter(user);
@ -491,8 +522,8 @@ void QuantizationDriver::RequantizeArg(BlockArgument arg,
  builder_.setInsertionPointToStart(arg.getOwner());
  if (value.hasOneUse()) {
    auto user = value.use_begin().getUser();
    if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
      value = q.output();
    if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
      value = q.getResult();
      builder_.setInsertionPoint(arg.getOwner(), ++Block::iterator(user));
    }
  }
@ -519,9 +550,8 @@ void QuantizationDriver::RequantizeValue(Value value, RequantizeState *state,
  // This value isn't an expressed type (float), skip.
  if (!new_type) return;

  TypeAttr type_attr = TypeAttr::get(new_type);
  auto requantize_op =
      builder_.create<TFL::QuantizeOp>(loc, new_type, value, type_attr);
      builder_.create<quant::QuantizeCastOp>(loc, new_type, value);
  value.replaceAllUsesWith(requantize_op);
  requantize_op.getOperation()->replaceUsesOfWith(requantize_op, value);
}
@ -651,8 +681,8 @@ void QuantizationDriver::SetupAllStates() {
  // If the argument is quantized, it should only have one user.
  if (arg.hasOneUse()) {
    auto user = value.use_begin().getUser();
    if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
      value = q.output();
    if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
      value = q.getResult();
    }
  }
  InitializeArgState(arg, value, &value_to_state);
@ -660,7 +690,9 @@ void QuantizationDriver::SetupAllStates() {

  fn_.walk([&](Operation *op) {
    if (op->isKnownTerminator() ||
        op->hasTrait<OpTrait::quant::NoQuantizableResult>())
        op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
        llvm::isa<quant::DequantizeCastOp>(op) ||
        llvm::isa<quant::QuantizeCastOp>(op))
      return;
    work_list_.push_back(op);
@ -669,8 +701,8 @@ void QuantizationDriver::SetupAllStates() {
    if (auto *inst = operand.getDefiningOp()) {
      // If the operand comes from a dequantize op, we use the quantized
      // input of that dequantize op to set the state.
      if (auto dq = llvm::dyn_cast<TFL::DequantizeOp>(inst)) {
        operand = dq.input();
      if (auto dq = llvm::dyn_cast<quant::DequantizeCastOp>(inst)) {
        operand = dq.arg();
      }
    }
    InitializeOperandState(op, i, operand, &value_to_state);
@ -683,8 +715,8 @@ void QuantizationDriver::SetupAllStates() {
    // create the state and mark it immutable.
    if (result.hasOneUse()) {
      auto user = result.use_begin().getUser();
      if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
        result = q.output();
      if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
        result = q.getResult();
      }
    }
    InitializeResultState(op, res, result, &value_to_state);
@ -714,6 +746,8 @@ bool QuantizationDriver::PropagateParams() {
    Operation *op = work_list_.back();
    work_list_.pop_back();

    LLVM_DEBUG(DumpStates(op));

    // This op has been quantized, so we should not consider it again.
    if (llvm::is_contained(quantized_, op)) continue;
    quantized_.insert(op);
@ -738,12 +772,23 @@ bool QuantizationDriver::PropagateParams() {
    }

    // Use the final state to set all the operands' parameters.
    for (int i = 0, e = op->getNumOperands(); i != e; ++i)
      changed |= SetOperandParams(op, i, params);
    for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
      if (auto type = op->getOperand(i).getType().dyn_cast<ShapedType>()) {
        // Without this check, it would accidentally propagate the quantization
        // information through shared non-float tensors.
        if (type.getElementType().isa<FloatType>())
          changed |= SetOperandParams(op, i, params);
      }
    }

    // Use the final state to set all the results' parameters.
    for (int res = 0, e = op->getNumResults(); res != e; ++res)
      changed |= SetResultParams(op, res, params);
      if (auto type = op->getResult(res).getType().dyn_cast<ShapedType>()) {
        // Without this check, it would accidentally propagate the quantization
        // information through shared non-float tensors.
        if (type.getElementType().isa<FloatType>())
          changed |= SetResultParams(op, res, params);
      }
  }

  // TODO(fengliuai): make the bit width configurable.
@ -822,5 +867,5 @@ void ApplyQuantizationParamsPropagation(
      .Run();
}

}  // namespace TFL
}  // namespace quant
}  // namespace mlir
@ -70,7 +73,8 @@ class FixedResultUniformScale {
  QuantizedType GetResultQuantizedType(int index) {
    auto op = this->getOperation();
    auto result_type =
        op->getResult(index).getType().template cast<TensorType>();
        op->getResult(index).getType().template cast<ShapedType>();
    if (!result_type.getElementType().template isa<FloatType>()) return {};
    Builder builder(op->getContext());
    IntegerType storage_type = builder.getIntegerType(BitWidth);
    const double scale = static_cast<double>(ScaleMantissa) *
@ -30,10 +30,9 @@ limitations under the License.
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/Support/LLVM.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"

namespace mlir {
namespace TFL {
namespace quant {

const float kNearZeroTolerance = 1.0e-6;

@ -66,6 +65,37 @@ static Type GetQuantizedType(Builder builder, Type input_type,
  return converter.convert(quantizedEleType);
}

// TODO(fengliuai): promote this utility method to mlir QuantOps.
TypeAttr RescaleQuantizedType(Type input, Attribute factor) {
  auto factor_values = factor.dyn_cast_or_null<DenseFPElementsAttr>();
  if (!factor_values) return {};
  auto ele_type = quant::QuantizedType::getQuantizedElementType(input);
  if (!ele_type) return {};
  if (auto qtype = ele_type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
    ArrayRef<double> scales = qtype.getScales();
    // Broadcasting hasn't been implemented yet.
    if (scales.size() != factor_values.getNumElements()) return {};
    SmallVector<double, 4> new_scales;
    new_scales.reserve(scales.size());
    auto scales_iter = scales.begin();
    for (auto f : factor_values) {
      new_scales.push_back(*(scales_iter++) *
                           std::fabs(FloatAttr::getValueAsDouble(f)));
    }
    // We are assuming symmetric quantization.
    auto new_ele_type = quant::UniformQuantizedPerAxisType::get(
        qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(),
        new_scales, qtype.getZeroPoints(), qtype.getQuantizedDimension(),
        qtype.getStorageTypeMin(), qtype.getStorageTypeMax());
    if (auto new_type = new_ele_type.castFromExpressedType(
            quant::QuantizedType::castToExpressedType(input))) {
      return TypeAttr::get(new_type);
    }
  }
  // Currently, we only support per-axis quantized type.
  return {};
}

TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min,
                              Attribute max, int quant_dim,
                              IntegerAttr num_bits, BoolAttr narrow_range,
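Note: a usage sketch for the helper above; the caller name is hypothetical, and the quant namespace follows this commit's rename:

// Fold per-output-channel multipliers into a per-axis quantized weight type.
// Returns a null TypeAttr when `weight_type` is not per-axis quantized or the
// factor count does not match the number of scales.
mlir::TypeAttr FoldMultiplierIntoWeights(mlir::Type weight_type,
                                         mlir::DenseFPElementsAttr factors) {
  return mlir::quant::RescaleQuantizedType(weight_type, factors);
}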
@ -369,7 +399,7 @@ static bool PreferResultScale(Operation* op) {
  for (auto operand : op->getOperands()) {
    if (auto operand_type = operand.getType().dyn_cast<ShapedType>()) {
      if (operand_type.getElementType().isa<FloatType>()) {
        if (float_operands++ > 1) return true;
        if (++float_operands > 1) return true;
      }
    }
  }
@ -429,7 +459,7 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
  }

  // Step 2: backward pass: For the ops skipped in the forward pass, propagate
  // their result scales backwards.
  // their result scales backwards as far as possible.
  func.walk([&](quant::StatisticsOp stats_op) {
    if (redundant_stats_ops.find(stats_op) == redundant_stats_ops.end()) {
      all_stats_ops.push_back(stats_op);
@ -441,8 +471,7 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
    all_stats_ops.pop_back();

    if (auto def = stats_op.arg().getDefiningOp()) {
      if (def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() &&
          PreferResultScale(def)) {
      if (def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>()) {
        for (auto input : def->getOperands()) {
          if (auto next_stats = llvm::dyn_cast_or_null<quant::StatisticsOp>(
                  input.getDefiningOp())) {
@ -465,5 +494,5 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
  // Returns false if the steps finish without errors.
  return false;
}
}  // namespace TFL
}  // namespace quant
}  // namespace mlir
@ -38,7 +38,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"

namespace mlir {
namespace TFL {
namespace quant {

using QuantParams = quant::QuantizedType;
using SignedInteger = std::pair<unsigned, unsigned>;  // bitwidth and sign
@ -113,8 +113,7 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {

    rewriter.setInsertionPointAfter(op);
    Type result_type = quant_type.castFromExpressedType(op.getType());
    auto q = rewriter.create<Q>(op.getLoc(), result_type, op.arg(),
                                TypeAttr::get(result_type));
    auto q = rewriter.create<Q>(op.getLoc(), result_type, op.arg());
    auto dq = rewriter.create<DQ>(op.getLoc(), op.getType(), q);
    op.getResult().replaceAllUsesWith(dq);
    q.getOperation()->replaceUsesOfWith(dq, op.arg());
@ -168,9 +167,12 @@ struct QuantizationPattern : public RewritePattern {
      return matchFailure();
    }

    // If it is a terminator or not quantizable, we shouldn't rewrite.
    // If it is a terminator, is not quantizable, or is any op from the mlir
    // quant ops dialect, we shouldn't rewrite.
    if (quantized_op->isKnownTerminator() ||
        quantized_op->hasTrait<OpTrait::quant::NoQuantizableResult>()) {
        quantized_op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
        llvm::isa<quant::QuantizeCastOp>(quantized_op) ||
        llvm::isa<quant::DequantizeCastOp>(quantized_op)) {
      return matchFailure();
    }

@ -316,7 +318,7 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {

  PatternMatchResult matchAndRewrite(Q op,
                                     PatternRewriter& rewriter) const override {
    Type output_type = op.output().getType();
    Type output_type = op.getResult().getType();
    auto qtype = QType::getQuantizedElementType(output_type);
    if (!qtype || qtype.isSigned()) return this->matchFailure();

@ -352,14 +354,19 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
      return this->matchFailure();
    }

    if (!new_qtype) return this->matchFailure();
    Type new_output_type = new_qtype.castFromExpressedType(
        QType::castToExpressedType(output_type));
    rewriter.replaceOpWithNewOp<Q>(op, new_output_type, op.input(),
                                   TypeAttr::get(new_output_type));
    rewriter.replaceOpWithNewOp<Q>(op, new_output_type, op.arg());
    return this->matchSuccess();
  }
};
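Note: the unsigned-to-signed rewrite above amounts to keeping the scale and shifting only the zero point; for the common 8-bit case the offset is 2^(bits-1) = 128. A minimal sketch of the arithmetic (the helper name is hypothetical):

// real = scale * (q_u8 - zp_u8) = scale * (q_i8 - (zp_u8 - 128)),
// so converting uint8 -> int8 only subtracts 128 from the zero point.
int64_t SignedZeroPointFromUnsigned(int64_t zp_u8) {
  return zp_u8 - 128;  // e.g. zero point 128 (uint8) maps to 0 (int8)
}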
// Given a quantized type `input`, magnifies its scales by the factors stored
// in `factor`. If `input` isn't a quantized type, or `factor` isn't
// floating-point or doesn't match the quantized dimension size of `input`, a
// null attribute will be returned.
TypeAttr RescaleQuantizedType(Type input, Attribute factor);

// Converts the min/max/num_bits/narrow_range information to a
// QuantizedType, and then returns the attribute containing the QuantizedType.
// The `min` and `max` arguments can be FloatAttr or DenseFPElementsAttr and
@ -438,7 +445,7 @@ void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed,
bool RemoveRedundantStatsOps(mlir::FuncOp func,
                             OpQuantSpecGetter op_quant_spec_getter);

}  // namespace TFL
}  // namespace quant
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_
36
tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD
Normal file
@ -0,0 +1,36 @@
package(
    default_visibility = [
        ":friends",
    ],
    licenses = ["notice"],  # Apache 2.0
)

package_group(
    name = "friends",
    includes = ["//third_party/mlir:subpackages"],
    packages = [
        "//tensorflow/compiler/mlir/...",
        "//tensorflow/compiler/mlir/lite/...",
    ],
)

cc_library(
    name = "tf_to_quant",
    srcs = [
        "tf_to_quant.cc",
    ],
    hdrs = [
        "passes.h",
    ],
    deps = [
        "//tensorflow/compiler/mlir/lite/quantization:quantization_config",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
        "//tensorflow/compiler/mlir/tensorflow",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:QuantOps",
    ],
    alwayslink = 1,
)
@ -0,0 +1,32 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_TENSORFLOW_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_TENSORFLOW_PASSES_H_

#include <memory>

#include "mlir/IR/Function.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project

namespace mlir {
namespace TF {

// Legalize the tf ops to the quant ops, so the quantization passes can work.
std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFToQuantPass();

}  // namespace TF
}  // namespace mlir
#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_TENSORFLOW_PASSES_H_
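Note: the factory declared above plugs into a pass manager the same way the TFL passes do in the quantize_model.cc hunk earlier; a minimal sketch, with the surrounding setup assumed:

#include "mlir/Pass/PassManager.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/tensorflow/passes.h"

// Run the TF->quant legalization ahead of the generic quantization passes.
void AddQuantLegalization(mlir::PassManager& pm) {
  pm.addPass(mlir::TF::CreateLegalizeTFToQuantPass());
}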
@ -0,0 +1,19 @@
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")

package(licenses = ["notice"])

glob_lit_tests(
    data = [":test_utilities"],
    driver = "@llvm-project//mlir:run_lit.sh",
    test_file_exts = ["mlir"],
)

# Bundle together all of the test utilities that are used by tests.
filegroup(
    name = "test_utilities",
    testonly = True,
    data = [
        "//tensorflow/compiler/mlir:tf-opt",
        "@llvm-project//llvm:FileCheck",
    ],
)
@ -0,0 +1,148 @@
// RUN: tf-opt -tf-to-quant %s | FileCheck %s

// CHECK-LABEL: fakeQuantPerChannelForActivation
func @fakeQuantPerChannelForActivation(%arg0: tensor<8x3xf32>) -> (tensor<8x3xf32>) {
  %arg1 = constant dense<[0.0, -1.0, 1.0]> : tensor<3xf32>
  %arg2 = constant dense<[255.0, 254.0, 256.0]> : tensor<3xf32>
  %0 = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8x3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<8x3xf32>
  return %0 : tensor<8x3xf32>

// CHECK: %[[fq:.*]] = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %cst, %cst_0)
// CHECK: %[[q:.*]] = "quant.qcast"(%[[fq]]) : (tensor<8x3xf32>) -> tensor<8x3x!quant.uniform<i8:f32:1, {1.000000e+00:-128,1.000000e+00:-127,1.000000e+00:-128}>>
// CHECK: %[[dq:.*]] = "quant.dcast"(%[[q]])
// CHECK: return %[[dq]]
}

// CHECK-LABEL: fakeQuantForActivation
func @fakeQuantForActivation(tensor<8xf32>) -> (tensor<8xf32>) {
^bb0(%arg0: tensor<8xf32>):
  %arg1 = constant dense<0.0> : tensor<f32>
  %arg2 = constant dense<255.0> : tensor<f32>
  %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
  return %0 : tensor<8xf32>

// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0)
// CHECK: %1 = "quant.qcast"(%0) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
// CHECK: %2 = "quant.dcast"(%1)
// CHECK: return %2
}

// CHECK-LABEL: fakeQuantForActivationNoDuplication
func @fakeQuantForActivationNoDuplication(tensor<8xf32>) -> (tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>) {
^bb0(%arg0: tensor<8xf32>):
  %arg1 = constant dense<0.0> : tensor<f32>
  %arg2 = constant dense<255.0> : tensor<f32>
  %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
  %1 = "quant.qcast"(%0) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
  return %1 : tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>

// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) {narrow_range = false, num_bits = 3 : i64}
// CHECK: %1 = "quant.qcast"(%0) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
// CHECK: return %1
}

// CHECK-LABEL: fakeQuantFolded
func @fakeQuantFolded() -> (tensor<8xf32>) {
  %in = constant dense<0.0> : tensor<8xf32>
  %min = constant dense<0.0> : tensor<f32>
  %max = constant dense<255.0> : tensor<f32>
  %mini = "tf.Identity"(%min) : (tensor<f32>) -> tensor<f32>
  %maxi = "tf.Identity"(%max) : (tensor<f32>) -> tensor<f32>
  %rst = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
  return %rst : tensor<8xf32>

// CHECK: %[[CONSTANT:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<8xf32>}
// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT]]) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
// CHECK: return %[[DEQUANTIZE]] : tensor<8xf32>
}

// CHECK-LABEL: fakeQuantNotFolded
func @fakeQuantNotFolded(tensor<8xf32>, tensor<f32>, tensor<f32>) -> (tensor<8xf32>) {
^bb0(%arg0: tensor<8xf32>, %arg3: tensor<f32>, %arg4: tensor<f32>):
  %1 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg3, %arg4) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
  return %1 : tensor<8xf32>

// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2)
// CHECK: return %0 : tensor<8xf32>
}

// CHECK-LABEL: fakeQuantWithConv2D
func @fakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
^bb0(%arg: tensor<256x32x32x3xf32>) :
  %in = constant dense<0.0> : tensor<3x3x3x16xf32>
  %min = constant dense<0.0> : tensor<f32>
  %max = constant dense<255.0> : tensor<f32>
  %mini = "tf.Identity"(%min) : (tensor<f32>) -> tensor<f32>
  %maxi = "tf.Identity"(%max) : (tensor<f32>) -> tensor<f32>
  %fq = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<f32>, tensor<f32>) -> tensor<3x3x3x16xf32>
  %rst = "tf.Conv2D"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
  return %rst : tensor<256x30x30x16xf32>

// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32, 1.000000e+00:-128>>
// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
// CHECK: %[[CONV:.*]] = "tf.Conv2D"(%arg0, %[[DEQUANTIZE]])
// CHECK: return %[[CONV]]
}

// CHECK-LABEL: perChannelFakeQuantWithConv2D
func @perChannelFakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
^bb0(%arg: tensor<256x32x32x3xf32>) :
  %in = constant dense<0.0> : tensor<3x3x3x16xf32>
  %min = constant dense<0.0> : tensor<16xf32>
  %max = constant dense<255.0> : tensor<16xf32>
  %mini = "tf.Identity"(%min) : (tensor<16xf32>) -> tensor<16xf32>
  %maxi = "tf.Identity"(%max) : (tensor<16xf32>) -> tensor<16xf32>
  %fq = "tf.FakeQuantWithMinMaxVarsPerChannel"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<3x3x3x16xf32>
  %rst = "tf.Conv2D"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
  return %rst : tensor<256x30x30x16xf32>

// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32:3,
// CHECK-SAME: {1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,
// CHECK-SAME: 1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128}>>
// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
// CHECK: %[[CONV:.*]] = "tf.Conv2D"(%arg0, %[[DEQUANTIZE]])
// CHECK: return %[[CONV]] : tensor<256x30x30x16xf32>
}

// CHECK-LABEL: fakeQuantWithDepthwiseConv2D
func @fakeQuantWithDepthwiseConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
^bb0(%arg: tensor<256x32x32x3xf32>) :
  %in = constant dense<0.0> : tensor<3x3x3x16xf32>
  %min = constant dense<0.0> : tensor<f32>
  %max = constant dense<255.0> : tensor<f32>
  %mini = "tf.Identity"(%min) : (tensor<f32>) -> tensor<f32>
  %maxi = "tf.Identity"(%max) : (tensor<f32>) -> tensor<f32>
  %fq = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<f32>, tensor<f32>) -> tensor<3x3x3x16xf32>
  %rst = "tf.DepthwiseConv2dNative"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
  return %rst : tensor<256x30x30x16xf32>

// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32, 1.000000e+00:-128>>
// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
// CHECK: %[[CONV:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[DEQUANTIZE]])
// CHECK: return %[[CONV]]
}

// CHECK-LABEL: perChannelFakeQuantWithDepthwiseConv2D
func @perChannelFakeQuantWithDepthwiseConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
^bb0(%arg: tensor<256x32x32x3xf32>) :
  %in = constant dense<0.0> : tensor<3x3x3x16xf32>
  %min = constant dense<0.0> : tensor<16xf32>
  %max = constant dense<255.0> : tensor<16xf32>
  %mini = "tf.Identity"(%min) : (tensor<16xf32>) -> tensor<16xf32>
  %maxi = "tf.Identity"(%max) : (tensor<16xf32>) -> tensor<16xf32>
  %fq = "tf.FakeQuantWithMinMaxVarsPerChannel"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<3x3x3x16xf32>
  %rst = "tf.DepthwiseConv2dNative"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
  return %rst : tensor<256x30x30x16xf32>

// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32:3,
// CHECK-SAME: {1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,
// CHECK-SAME: 1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128}>>
// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
// CHECK: %[[CONV:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[DEQUANTIZE]])
// CHECK: return %[[CONV]]
}
@ -0,0 +1,162 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:llvm-project
#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TF {

//===----------------------------------------------------------------------===//
// The pass to legalize the quantization emulation ops from TF.
//
namespace {

// Legalize TF quantization emulation ops to those in the Quant ops dialect.
struct LegalizeTFToQuant : public FunctionPass<LegalizeTFToQuant> {
  explicit LegalizeTFToQuant() = default;
  LegalizeTFToQuant(const LegalizeTFToQuant &) {}

  /// Performs the lowering to Quant ops dialect.
  void runOnFunction() override;
};

// TODO(fengliuai): move this rule to PreparePatterns.td
// TODO(b/140968741): propagate the sign from the command line. Currently all
// the FakeQuant ops are assumed to target UINT8, but the per-channel kernel is
// actually INT8.
// Inserts a "quant.qcast" and "quant.dcast" op pair (QDQs) after the
// "tf.FakeQuantWithMinMaxVarsOp" to be constant folded. Since the constant
// folding logic will use a "std.constant" op to replace the
// "tf.FakeQuantWithMinMaxVarsOp", the "quant.qcast" op is used to preserve
// the quantization parameters in its result type, and the "quant.dcast" op is
// used to convert the output type back for the next op. Here are the
// transformations:
//
// input   min cst       max cst          input   min cst       max cst
//  \       |             |                \       |             |
//   \  (tf.Identity) (tf.Identity)   =>    \  (tf.Identity) (tf.Identity)
//    \      |             |                  \      |             |
//       tf.FakeQuantWithMinMaxVars       tf.FakeQuantWithMinMaxVars
//                   |                                 |
//                                                quant.qcast
//                                                     |
//                                                quant.dcast
//                                                     |
// If the input is a constant, the result pattern will eventually be converted
// to:
//
//            quant-emulated input
//                   |
//               quant.qcast
//                   |
//               quant.dcast
//                   |
template <typename TFFakeQuantOp, bool PerAxis>
struct InsertQuantOpsAfterTFFakeQuantOp
    : public OpRewritePattern<TFFakeQuantOp> {
  using BaseType = InsertQuantOpsAfterTFFakeQuantOp<TFFakeQuantOp, PerAxis>;

  explicit InsertQuantOpsAfterTFFakeQuantOp<TFFakeQuantOp, PerAxis>(
      MLIRContext *ctx)
      : OpRewritePattern<TFFakeQuantOp>(ctx) {}

  PatternMatchResult matchAndRewrite(TFFakeQuantOp tf_op,
                                     PatternRewriter &rewriter) const override {
    // We don't want to insert quantize/dequantize if the quantize op exists.
    auto res = tf_op.outputs();
    if (!res.hasOneUse() || isa<quant::QuantizeCastOp>(*res.user_begin()))
      return this->matchFailure();

    // Extract the min/max constant values from the operands. We also consider
    // a special case that there are tf.Identity ops between the min/max
    // constants and the tf.FakeQuantWithMinMaxVarsOp.
    Value min = tf_op.min(), max = tf_op.max();
    DenseFPElementsAttr min_value, max_value;
    if (auto id1 = dyn_cast_or_null<TF::IdentityOp>(min.getDefiningOp())) {
      id1.replaceAllUsesWith(id1.input());
      min = tf_op.min();
      rewriter.eraseOp(id1);
    }
    if (auto id2 = dyn_cast_or_null<TF::IdentityOp>(max.getDefiningOp())) {
      id2.replaceAllUsesWith(id2.input());
      max = tf_op.max();
      rewriter.eraseOp(id2);
    }
    if (!matchPattern(min, m_Constant(&min_value))) return this->matchFailure();
    if (!matchPattern(max, m_Constant(&max_value))) return this->matchFailure();

    int quant_dim = -1;
    if (PerAxis) {
      // This is a special case: the quant_dim is the last dimension, as
      // defined by tf.FakeQuantWithMinMaxVarsPerChannel.
      quant_dim = res.getType().template cast<ShapedType>().getRank() - 1;
    }
    // Use the min/max from the operands and the num_bits and narrow_range
    // attribute to create the quantization parameter for the new quantize op.
    rewriter.setInsertionPointAfter(tf_op);
    IntegerAttr num_bits =
        rewriter.getI64IntegerAttr(tf_op.num_bits().getSExtValue());
    BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.narrow_range());
    Type res_type = tf_op.getType();
    TypeAttr qtype = quant::GetQuantizedTypeAttr(
        rewriter, res_type, min_value, max_value, quant_dim, num_bits,
        narrow_range, /*is_signed=*/true);
    if (!qtype) return this->matchFailure();

    // Finally, use the quantization parameter to create the quantize and
    // dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp
    // and its users.
    Value value = tf_op.outputs();
    auto quantize = rewriter.create<quant::QuantizeCastOp>(
        tf_op.getLoc(), qtype.getValue(), value);
    auto dequantize = rewriter.create<quant::DequantizeCastOp>(
        tf_op.getLoc(), res_type, quantize.getResult());
    value.replaceAllUsesWith(dequantize);
    quantize.getOperation()->replaceUsesOfWith(dequantize, value);

    return this->matchSuccess();
  }
};

using PreparePerTensorFakeQuant =
    InsertQuantOpsAfterTFFakeQuantOp<TF::FakeQuantWithMinMaxVarsOp, false>;

using PreparePerChannelFakeQuant =
    InsertQuantOpsAfterTFFakeQuantOp<TF::FakeQuantWithMinMaxVarsPerChannelOp,
                                     true>;

// TODO(fengliuai): add the support of the tf.QuantizeAndDequantize*
// legalization.

void LegalizeTFToQuant::runOnFunction() {
  OwningRewritePatternList patterns;
  auto func = getFunction();
  auto *ctx = func.getContext();
  patterns.insert<PreparePerTensorFakeQuant, PreparePerChannelFakeQuant>(ctx);
  applyPatternsGreedily(func, patterns);
}
}  // namespace

// Creates an instance of the TensorFlow dialect to QuantOps dialect pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFToQuantPass() {
  return std::make_unique<LegalizeTFToQuant>();
}

static PassRegistration<LegalizeTFToQuant> pass(
    "tf-to-quant", "Legalize TF to quant ops dialect");

}  // namespace TF
}  // namespace mlir
@ -3,7 +3,8 @@

// CHECK-LABEL: import_stats_skip
func @import_stats_skip(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
  %0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "skip"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
  %0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
    loc(fused["skip1", "skip2.cc":10:8, callsite("op" at "skip3.cc":10:8)])
  return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>

// CHECK-NEXT: "tfl.split"
@ -12,7 +13,8 @@ func @import_stats_skip(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf3

// CHECK-LABEL: import_stats_name
func @import_stats_name(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
  %0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "op"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
  %0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
    loc(fused["skip1.cc":10:8, "op", callsite("skip2" at "skip3.cc":10:8)])
  return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>

// CHECK-NEXT: %[[split:.*]]:2 = "tfl.split"
@ -23,7 +25,8 @@ func @import_stats_name(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf3

// CHECK-LABEL: import_stats_name_port
func @import_stats_name_port(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
  %0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "op_0"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
  %0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
    loc(fused["skip1.cc":10:8, "op_0", callsite("skip2" at "skip3.cc":10:8)])
  return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>

// CHECK-NEXT: %[[split:.*]]:2 = "tfl.split"
@ -34,6 +37,7 @@ func @import_stats_name_port(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor
// CHECK-LABEL: import_stats_name_regex
func @import_stats_name_regex(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
  %0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "op_regex"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
    loc(fused["skip1.cc":10:8, "op_regex", callsite("skip2" at "skip3.cc":10:8)])
  return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>

// CHECK-NEXT: %[[split:.*]]:2 = "tfl.split"
@ -46,9 +46,9 @@ static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) {
  std::vector<Record *> defs = records.getAllDerivedDefinitions("Op");
  llvm::sort(defs, LessRecord());

  OUT(0) << "static std::unique_ptr<OpQuantSpec> "
  OUT(0) << "static std::unique_ptr<quant::OpQuantSpec> "
            "GetOpQuantSpec(mlir::Operation *op) {\n";
  OUT(2) << "auto spec = absl::make_unique<OpQuantSpec>();\n";
  OUT(2) << "auto spec = absl::make_unique<quant::OpQuantSpec>();\n";
  llvm::SmallVector<llvm::StringRef, 3> matches;
  for (auto *def : defs) {
    Operator op(def);
@ -74,7 +74,7 @@ static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) {
    if (acc_uniform_trait_regex.match(trait_str, &matches)) {
      OUT(4) << "spec->biases_params.emplace(std::make_pair(" << matches[1]
             << ", std::make_pair(tfl.GetAllNonBiasOperands(),"
             << "GetUniformQuantizedTypeForBias)));\n";
             << "quant::GetUniformQuantizedTypeForBias)));\n";
      matches.clear();
    }
    // There is a "QuantChannelDim" trait, set the quantization dimension.
36
tensorflow/compiler/mlir/lite/quantization/xla/BUILD
Normal file
@ -0,0 +1,36 @@
package(
    default_visibility = [
        ":friends",
    ],
    licenses = ["notice"],  # Apache 2.0
)

package_group(
    name = "friends",
    includes = ["//third_party/mlir:subpackages"],
    packages = [
        "//tensorflow/compiler/mlir/...",
        "//tensorflow/compiler/mlir/lite/...",
    ],
)

cc_library(
    name = "hlo_xla_quantization_passes",
    srcs = [
        "op_quant_spec.inc",
        "propagate.cc",
    ],
    hdrs = [
        "passes.h",
    ],
    deps = [
        "//tensorflow/compiler/mlir/lite/quantization:quantization_config",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
        "@com_google_absl//absl/memory",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:QuantOps",
    ],
    alwayslink = 1,
)
@ -0,0 +1,7 @@
// TODO(fengliuai): automatically generate this file
// TODO(fengliuai): add all the xla_hlo ops

static std::unique_ptr<quant::OpQuantSpec> GetOpQuantSpec(mlir::Operation *op) {
  auto spec = absl::make_unique<quant::OpQuantSpec>();
  return spec;
}
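Note: as a .inc file listed in srcs, this table is meant to be textually included by the library's translation unit. A sketch of one plausible consumer, assuming the getter matches the OpQuantSpecGetter signature from quantization_utils.h (the function below is hypothetical, not part of this commit):

// Hypothetical consumer (e.g. in propagate.cc): include the generated spec
// table and hand GetOpQuantSpec to a propagation entry point as the getter.
#include "tensorflow/compiler/mlir/lite/quantization/xla/op_quant_spec.inc"

void RunStatsCleanup(mlir::FuncOp func) {
  mlir::quant::RemoveRedundantStatsOps(func, GetOpQuantSpec);
}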