Merge branch 'master' into identity_in_constant_value

Commit 8ca5bed715 by Harry Slatyer, 2020-02-03 11:04:24 +11:00
2860 changed files with 129738 additions and 45504 deletions

==== .bazelrc ====

@@ -69,6 +69,7 @@
 # rbe_linux_py3: Linux Python 3 RBE config
 #
 # rbe_win_py37: Windows Python 3.7 RBE config
+# rbe_win_py38: Windows Python 3.8 RBE config
 #
 # tensorflow_testing_rbe_linux: RBE options to use RBE with tensorflow-testing project on linux
 # tensorflow_testing_rbe_win: RBE options to use RBE with tensorflow-testing project on windows
@@ -279,7 +280,6 @@ build:windows --host_linkopt=/OPT:REF
 build:windows --linkopt=/OPT:ICF
 build:windows --host_linkopt=/OPT:ICF
 build:windows --experimental_strict_action_env=true
-build:windows --incompatible_windows_native_test_wrapper
 # Verbose failure logs when something goes wrong
 build:windows --verbose_failures
@@ -344,6 +344,7 @@ build:rbe_linux --config=avx_linux
 build:rbe_linux --config=short_logs
 # TODO(gunan): Check why we need this specified in rbe, but not in other builds.
 build:rbe_linux --linkopt=-lrt
+build:rbe_linux --linkopt=-lm
 build:rbe_cpu_linux --config=rbe_linux
 build:rbe_cpu_linux --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain"
@@ -392,6 +393,7 @@ build:rbe_win --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe
 # TODO(gunan): Remove once we use MSVC 2019 with latest patches.
 build:rbe_win --define=override_eigen_strong_inline=true
+build:rbe_win --jobs=500
 build:rbe_win_py37 --config=rbe
 build:rbe_win_py37 --repo_env=PYTHON_BIN_PATH=C:\\Python37\\python.exe
@@ -399,6 +401,12 @@ build:rbe_win_py37 --repo_env=PYTHON_LIB_PATH=C:\\Python37\\lib\\site-packages
 build:rbe_win_py37 --repo_env=TF_PYTHON_CONFIG_REPO=@org_tensorflow//third_party/toolchains/preconfig/win_1803/py37
 build:rbe_win_py37 --python_path=C:\\Python37\\python.exe
+build:rbe_win_py38 --config=rbe
+build:rbe_win_py38 --repo_env=PYTHON_BIN_PATH=C:\\Python38\\python.exe
+build:rbe_win_py38 --repo_env=PYTHON_LIB_PATH=C:\\Python38\\lib\\site-packages
+build:rbe_win_py38 --repo_env=TF_PYTHON_CONFIG_REPO=@org_tensorflow//third_party/toolchains/preconfig/win_1803/py38
+build:rbe_win_py38 --python_path=C:\\Python38\\python.exe
 # These you may need to change for your own GCP project.
 build:tensorflow_testing_rbe --project_id=tensorflow-testing
 common:tensorflow_testing_rbe_linux --remote_instance_name=projects/tensorflow-testing/instances/default_instance

==== .bazelversion ====

@@ -1 +1 @@
-1.1.0
+1.2.1

==== README.md ====

@@ -29,20 +29,6 @@ to
 [announce@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/announce).
 See all the [mailing lists](https://www.tensorflow.org/community/forums).
 
-## Feature Prioritization Survey
-
-The TensorFlow team is working on building/improving features, and understands
-that it is very important to prioritize these efforts based on what TF users
-need.
-
-The goal of this short, < 5 minute
-[survey](https://google.qualtrics.com/jfe/form/SV_d5nqhCEbkDkQ7ad), is to help
-the TensorFlow team better understand what features to prioritize based on your
-feedback. Participation is of course optional.
-
-Take the survey
-[HERE](https://google.qualtrics.com/jfe/form/SV_d5nqhCEbkDkQ7ad).
-
 ## Install
 
 See the [TensorFlow install guide](https://www.tensorflow.org/install) for the
@@ -164,4 +150,3 @@ Learn more about the
 ## License
 
 [Apache License 2.0](LICENSE)

File diff suppressed because one or more lines are too long

==== SECURITY.md ====

@@ -245,4 +245,4 @@ v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc=
 ### Known Vulnerabilities
 
 For a list of known vulnerabilities and security advisories for TensorFlow,
-[click here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/index.md).
+[click here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/README.md).

==== WORKSPACE ====

@@ -1,11 +1,13 @@
 workspace(name = "org_tensorflow")
 
-load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_file")
+load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
+load("//third_party:repo.bzl", "tf_http_archive")
 
-http_archive(
+tf_http_archive(
     name = "io_bazel_rules_closure",
     sha256 = "5b00383d08dd71f28503736db0500b6fb4dda47489ff5fc6bed42557c07c6ba9",
     strip_prefix = "rules_closure-308b05b2419edb5c8ee0471b67a40403df940149",
+    patch_file = "@org_tensorflow//third_party:rules_closure.patch",
     urls = [
         "https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz",
         "https://github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz",  # 2019-06-13
@@ -48,38 +50,6 @@ load("//third_party/toolchains/preconfig/generate:workspace.bzl",
 remote_config_workspace()
 
-# Apple and Swift rules.
-http_archive(
-    name = "build_bazel_rules_apple",
-    sha256 = "a045a436b642c70fb0c10ca84ff0fd2dcbd59cc89100d597a61e8374afafb366",
-    urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.18.0/rules_apple.0.18.0.tar.gz"],
-)  # https://github.com/bazelbuild/rules_apple/releases
-http_archive(
-    name = "build_bazel_rules_swift",
-    sha256 = "18cd4df4e410b0439a4935f9ca035bd979993d42372ba79e7f2d4fafe9596ef0",
-    urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.12.1/rules_swift.0.12.1.tar.gz"],
-)  # https://github.com/bazelbuild/rules_swift/releases
-http_archive(
-    name = "build_bazel_apple_support",
-    sha256 = "122ebf7fe7d1c8e938af6aeaee0efe788a3a2449ece5a8d6a428cb18d6f88033",
-    urls = ["https://github.com/bazelbuild/apple_support/releases/download/0.7.1/apple_support.0.7.1.tar.gz"],
-)  # https://github.com/bazelbuild/apple_support/releases
-http_archive(
-    name = "bazel_skylib",
-    sha256 = "1dde365491125a3db70731e25658dfdd3bc5dbdfd11b840b3e987ecf043c7ca0",
-    urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/0.9.0/bazel-skylib.0.9.0.tar.gz"],
-)  # https://github.com/bazelbuild/bazel-skylib/releases
-http_archive(
-    name = "com_github_apple_swift_swift_protobuf",
-    type = "zip",
-    strip_prefix = "swift-protobuf-1.6.0/",
-    urls = ["https://github.com/apple/swift-protobuf/archive/1.6.0.zip"],
-)  # https://github.com/apple/swift-protobuf/releases
-http_file(
-    name = "xctestrunner",
-    executable = 1,
-    urls = ["https://github.com/google/xctestrunner/releases/download/0.2.9/ios_test_runner.par"],
-)  # https://github.com/google/xctestrunner/releases
-
 # Use `swift_rules_dependencies` to fetch the toolchains. With the
 # `git_repository` rules above, the following call will skip redefining them.
 load("@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies")

==== configure.py ====

@@ -49,8 +49,8 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
 _TF_WORKSPACE_ROOT = ''
 _TF_BAZELRC = ''
 _TF_CURRENT_BAZEL_VERSION = None
-_TF_MIN_BAZEL_VERSION = '1.0.0'
-_TF_MAX_BAZEL_VERSION = '1.1.0'
+_TF_MIN_BAZEL_VERSION = '1.2.1'
+_TF_MAX_BAZEL_VERSION = '1.2.1'
 
 NCCL_LIB_PATHS = [
     'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''
@@ -1221,7 +1221,7 @@ def is_reduced_optimize_huge_functions_available(environ_cp):
   only, as of 2019-11-19). TensorFlow needs this flag to massively reduce
   compile times, but until 16.4 is officially released, we can't depend on it.
 
-  See also https://groups.google.com/a/tensorflow.org/g/build/c/SsW98Eo7l3o
+  See also https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
 
   Because it's very annoying to check this manually (to check the MSVC installed
   versions, you need to use the registry, and it's not clear if Bazel will be

==== tensorflow/BUILD ====

@@ -2,6 +2,7 @@
 # TensorFlow is a computational framework, primarily for use in machine
 # learning applications.
 
+load("@bazel_skylib//lib:selects.bzl", "selects")
 load("//tensorflow:tensorflow.bzl", "VERSION", "tf_cc_shared_object", "tf_custom_op_library_additional_deps_impl", "tf_native_cc_binary")
 load(
     "//tensorflow/core/platform:build_config.bzl",
@@ -478,6 +479,7 @@ bzl_library(
     visibility = ["//visibility:public"],
     deps = [
         "//tensorflow/core/platform:build_config_root_bzl",
+        "//tensorflow/core/platform:rules_cc_bzl",
         "//tensorflow/core/platform/default:cuda_build_defs_bzl",
         "//third_party/mkl:build_defs_bzl",
         "//third_party/mkl_dnn:build_defs_bzl",

==== (file name not shown) ====

@@ -23,10 +23,6 @@ from __future__ import print_function
 
 # pylint: disable=g-bad-import-order
 from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
-from tensorflow.python.util.lazy_loader import LazyLoader
-contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')
-del LazyLoader
-
 from tensorflow.python.platform import flags  # pylint: disable=g-import-not-at-top
 from tensorflow.python.platform import app  # pylint: disable=g-import-not-at-top
 app.flags = flags

==== tensorflow/c/BUILD ====

@@ -54,9 +54,10 @@ filegroup(
 )
 
 filegroup(
-    name = "pywrap_eager_hdrs",
+    name = "pywrap_required_hdrs",
     srcs = [
         "c_api_internal.h",
+        "python_api.h",
         "tf_status_helper.h",
         "tf_status_internal.h",
         "tf_tensor_internal.h",
@@ -98,6 +99,17 @@ tf_cuda_library(
     ],
 )
 
+filegroup(
+    name = "pywrap_tf_session_hdrs",
+    srcs = [
+        "python_api.h",
+    ],
+    visibility = [
+        "//tensorflow/core:__pkg__",
+        "//tensorflow/python:__pkg__",
+    ],
+)
+
 cc_library(
     name = "tf_attrtype",
     hdrs = ["tf_attrtype.h"],
@@ -302,6 +314,7 @@ tf_cuda_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/common_runtime/eager:attr_builder",
+        "//tensorflow/core/common_runtime/eager:context",
         "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
         "//tensorflow/core/platform",
         "@com_google_absl//absl/strings",
@@ -639,7 +652,7 @@ tf_cuda_cc_test(
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core/kernels:ops_testutil",
-        "//third_party/eigen3",
+        "@com_google_absl//absl/container:inlined_vector",
     ],
 )

==== tensorflow/c/c_api.cc ====

@@ -458,7 +458,7 @@ static void TF_Run_Helper(
             EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
         continue;
       }
-      c_outputs[i] = TF_TensorFromTensor(src, status);
+      c_outputs[i] = TF_TensorFromTensor(src, &status->status);
       if (!status->status.ok()) return;
     }
   }
@@ -1493,7 +1493,7 @@ void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
   Tensor t;
   status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
   if (!status->status.ok()) return;
-  *value = TF_TensorFromTensor(t, status);
+  *value = TF_TensorFromTensor(t, &status->status);
 }
 
 void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
@@ -1504,7 +1504,7 @@ void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
   if (!status->status.ok()) return;
   const auto len = std::min(max_values, static_cast<int>(ts.size()));
   for (int i = 0; i < len; ++i) {
-    values[i] = TF_TensorFromTensor(ts[i], status);
+    values[i] = TF_TensorFromTensor(ts[i], &status->status);
   }
 }
 
@@ -2398,7 +2398,7 @@ unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output,
       graph->graph.versions().producer(), &evaluated, &result_tensor);
   if (evaluated) {
     DCHECK(status->status.ok());
-    *result = TF_TensorFromTensor(result_tensor, status);
+    *result = TF_TensorFromTensor(result_tensor, &status->status);
     if (!status->status.ok()) evaluated = false;
   }
   return evaluated;

==== tensorflow/c/c_api_experimental.cc ====

@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/c/eager/c_api_internal.h"
 #include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
+#include "tensorflow/core/common_runtime/eager/context.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/shape_inference.h"
@@ -549,7 +550,7 @@ TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(TFE_Op* op,
                                                     TF_Status* status) {
   TFE_ExecuteOpNotification* n = new TFE_ExecuteOpNotification;
 
-  n->thread.reset(op->operation.EagerContext()->TFEnv()->StartThread(
+  n->thread.reset(op->operation.EagerContext().TFEnv()->StartThread(
       tensorflow::ThreadOptions(), "ExecuteOpThread",
       [op, retvals, num_retvals, n]() {
         TFE_Execute(op, retvals, num_retvals, n->status.get());
@@ -634,7 +635,7 @@ TF_Tensor* TF_CheckpointReaderGetTensor(TF_CheckpointReader* reader,
   std::unique_ptr<tensorflow::Tensor> tensor;
   reader->GetTensor(name, &tensor, status);
   if (!status->status.ok()) return nullptr;
-  return tensorflow::TF_TensorFromTensor(*tensor, status);
+  return tensorflow::TF_TensorFromTensor(*tensor, &status->status);
 }
 
 void TF_CheckpointReaderGetVariableShape(TF_CheckpointReader* reader,
@@ -767,8 +768,9 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
   } while (0);
 
   // New server created for new server_def. Unused if updating server_def.
+  tensorflow::EagerContext* context = ctx->context;
   tensorflow::GrpcServer* grpc_server =
-      dynamic_cast<tensorflow::GrpcServer*>(ctx->context->GetServer());
+      dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
   if (grpc_server == nullptr) {
     std::unique_ptr<tensorflow::ServerInterface> new_server;
     LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
@@ -779,12 +781,12 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
     }
     LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
 
-    LOG_AND_RETURN_IF_ERROR(ctx->context->StoreCollectiveOpsServer(
+    LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
         std::move(new_server), grpc_server->worker_env()->device_mgr,
        grpc_server->worker_env()->collective_executor_mgr));
   } else {
     LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
-    LOG_AND_RETURN_IF_ERROR(ctx->context->StoreCollectiveOpsServer(
+    LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
         /*new_server=*/nullptr, grpc_server->worker_env()->device_mgr,
         grpc_server->worker_env()->collective_executor_mgr));
   }

==== tensorflow/c/c_api_function_test.cc ====

@@ -1260,11 +1260,10 @@ TEST_F(CApiFunctionTest, GraphToFunctionDefWithPlaceholderAttr) {
   NodeWithPlaceholderAttrHelper(func_graph.get(), s.get(), "node3", "v2",
                                 &node3);
 
-  TF_Output inputs[] = {};
   TF_Output outputs[] = {{node1, 0}, {node2, 0}, {node3, 0}};
   func_ = TF_GraphToFunction(
       func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
-      /*opers=*/nullptr, 0, inputs, 3, outputs,
+      /*opers=*/nullptr, 0, nullptr, 3, outputs,
       /*output_names=*/nullptr,
       /*opts=*/nullptr, /*description=*/nullptr, s.get());
   ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
@@ -1300,10 +1299,9 @@ TEST_F(CApiFunctionTest, GraphToFunctionDefWithArgAttr) {
                                 &node);
 
   TF_Output inputs[] = {{node, 0}};
-  TF_Output outputs[] = {};
   func_ = TF_GraphToFunction(
       func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
-      /*opers=*/nullptr, 1, inputs, 0, outputs,
+      /*opers=*/nullptr, 1, inputs, 0, nullptr,
       /*output_names=*/nullptr,
       /*opts=*/nullptr, /*description=*/nullptr, s.get());
   ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
@@ -1603,11 +1601,10 @@ void DefineStatefulFunction(const char* name, TF_Function** func) {
   TF_Operation* random =
       RandomUniform(shape, TF_FLOAT, func_graph.get(), s.get());
 
-  TF_Output inputs[] = {};
   TF_Output outputs[] = {{random, 0}};
   *func = TF_GraphToFunction(func_graph.get(), name,
                              /*append_hash_to_fn_name=*/false, -1,
                             /*opers=*/nullptr, 0, nullptr, 1, outputs,
                              /*output_names=*/nullptr,
                              /*opts=*/nullptr, "", s.get());
  ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
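
Note: these test changes rely on TF_GraphToFunction accepting a null pointer wherever a zero-length array was previously passed. A minimal sketch of that calling convention (graph, op, and s stand in for a valid TF_Graph*, TF_Operation*, and TF_Status* created elsewhere):

// Sketch only: build a function with no inputs by passing nullptr with count 0.
TF_Output outputs[] = {{op, 0}};
TF_Function* fn = TF_GraphToFunction(
    graph, "fn", /*append_hash_to_fn_name=*/false, /*num_opers=*/-1,
    /*opers=*/nullptr, /*ninputs=*/0, /*inputs=*/nullptr,
    /*noutputs=*/1, outputs, /*output_names=*/nullptr,
    /*opts=*/nullptr, /*description=*/nullptr, s);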

==== tensorflow/c/c_api_internal.h ====

@@ -188,7 +188,7 @@ namespace tensorflow {
 
 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
 
-TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
+TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
 
 Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
                        TF_Buffer* out);
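
With this signature change, internal callers report errors through tensorflow::Status directly, while public entry points unwrap their TF_Status via &status->status (as in the c_api.cc hunks above). A minimal sketch of the new internal call pattern; the helper name is made up for illustration:

// Sketch only: TF_TensorFromTensor now reports through tensorflow::Status.
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/framework/tensor.h"

TF_Tensor* ToCTensor(const tensorflow::Tensor& src) {
  tensorflow::Status status;
  TF_Tensor* t = tensorflow::TF_TensorFromTensor(src, &status);
  if (!status.ok()) return nullptr;  // status.error_message() has details
  return t;
}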

==== tensorflow/c/c_api_test.cc ====

@@ -51,7 +51,7 @@ limitations under the License.
 #include "tensorflow/core/util/equal_graph_def.h"
 
 namespace tensorflow {
-TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
+TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
 
 namespace {
@@ -227,7 +227,7 @@ TEST(CAPI, LibraryLoadFunctions) {
 
 void TestEncodeDecode(int line, const std::vector<string>& data) {
   const tensorflow::int64 n = data.size();
-  TF_Status* status = TF_NewStatus();
+  Status status;
   for (const std::vector<tensorflow::int64>& dims :
        std::vector<std::vector<tensorflow::int64>>{
            {n}, {1, n}, {n, 1}, {n / 2, 2}}) {
@@ -236,8 +236,8 @@ void TestEncodeDecode(int line, const std::vector<string>& data) {
     for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) {
       src.flat<tstring>()(i) = data[i];
     }
-    TF_Tensor* dst = TF_TensorFromTensor(src, status);
-    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    TF_Tensor* dst = TF_TensorFromTensor(src, &status);
+    ASSERT_TRUE(status.ok()) << status.error_message();
 
     // Convert back to a C++ Tensor and ensure we get expected output.
     Tensor output;
@@ -249,7 +249,6 @@ void TestEncodeDecode(int line, const std::vector<string>& data) {
     TF_DeleteTensor(dst);
   }
-  TF_DeleteStatus(status);
 }
 
 TEST(CAPI, TensorEncodeDecodeStrings) {
@@ -1394,8 +1393,9 @@ TEST(CAPI, SavedModel) {
     TF_Operation* input_op =
         TF_GraphOperationByName(graph, input_op_name.c_str());
     ASSERT_TRUE(input_op != nullptr);
-    csession.SetInputs({{input_op, TF_TensorFromTensor(input, s)}});
-    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+    Status status;
+    csession.SetInputs({{input_op, TF_TensorFromTensor(input, &status)}});
+    ASSERT_TRUE(status.ok()) << status.error_message();
 
     const tensorflow::string output_op_name(
         tensorflow::ParseTensorName(output_name).first);
@@ -2522,12 +2522,11 @@ TEST(CAPI, TestTensorIsNotAligned) {
   // Take an unaligned slice.
   Tensor y = x.Slice(1, 13);
 
-  TF_Status* status = TF_NewStatus();
-  TF_Tensor* a = TF_TensorFromTensor(y, status);
+  Status status;
+  TF_Tensor* a = TF_TensorFromTensor(y, &status);
   if (EIGEN_MAX_ALIGN_BYTES > 0) {
     EXPECT_FALSE(TF_TensorIsAligned(a));
   }
-  TF_DeleteStatus(status);
   TF_DeleteTensor(a);
 }

==== (file name not shown) ====

@@ -17,7 +17,7 @@ limitations under the License.
 #include <memory.h>
 #include <stdio.h>
 #include <stdlib.h>
-#include <sys/time.h>
+#include <time.h>
 #include <unistd.h>
 
 #include "tensorflow/c/c_api.h"
@@ -58,12 +58,8 @@ int main(int argc, char** argv) {
   }
 
   char file_name[100];
-  struct timeval t;
-  if (gettimeofday(&t, NULL)) {
-    perror("gettimeofday failed");
-    return 1;
-  }
-  snprintf(file_name, sizeof(file_name), "test-%d-%ld.txt", getpid(), t.tv_sec);
+  time_t t = time(NULL);
+  snprintf(file_name, sizeof(file_name), "test-%d-%ld.txt", getpid(), t);
 
   size_t length = 2 + strlen(path) + strlen(file_name);
   char* full_path = malloc(length);
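
One caveat in the simplified version: %ld assumes time_t is a long, which the C standard does not guarantee on every platform. A fully portable variant of the same pattern casts explicitly (a sketch, not part of this commit; the helper name is made up):

// Sketch only: pid + wall-clock time as a unique-ish file-name suffix,
// with time_t cast explicitly so the format string stays portable.
#include <stdio.h>
#include <time.h>
#include <unistd.h>

static void make_test_file_name(char* buf, size_t len) {
  snprintf(buf, len, "test-%d-%ld.txt", (int)getpid(), (long)time(NULL));
}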

==== tensorflow/c/eager/BUILD ====

@@ -26,8 +26,8 @@ tf_cuda_library(
         "c_api.cc",
         "c_api_debug.cc",
         "c_api_experimental.h",
-        "c_api_internal.cc",
         "c_api_internal.h",
+        "tensor_handle_interface.h",
     ],
     hdrs = ["c_api.h"],
     copts = tf_copts() + tfe_xla_copts(),
@@ -89,10 +89,11 @@ tf_cuda_library(
 )
 
 filegroup(
-    name = "pywrap_eager_hdrs",
+    name = "pywrap_required_hdrs",
     srcs = [
         "c_api_experimental.h",
         "c_api_internal.h",
+        "tensor_handle_interface.h",
     ],
     visibility = [
         "//tensorflow/core:__pkg__",
@@ -102,7 +103,10 @@ filegroup(
 
 tf_cuda_library(
     name = "c_api_internal",
-    srcs = ["c_api_experimental.h"],
+    srcs = [
+        "c_api_experimental.h",
+        "tensor_handle_interface.h",
+    ],
     hdrs = ["c_api_internal.h"],
     visibility = [
         "//learning/deepmind/courier:__subpackages__",
@@ -125,18 +129,6 @@ tf_cuda_library(
         "//tensorflow/core/common_runtime/eager:eager_operation",
         "//tensorflow/core/common_runtime/eager:kernel_and_device",
         "//tensorflow/core/common_runtime/eager:tensor_handle",
-        "//tensorflow/core/distributed_runtime:remote_device",
-        "//tensorflow/core/distributed_runtime:server_lib",
-        "//tensorflow/core/distributed_runtime:worker_env",
-        "//tensorflow/core/distributed_runtime/eager:eager_client",
-        "//tensorflow/core/distributed_runtime/eager:remote_tensor_handle",
-        "//tensorflow/core/distributed_runtime/rpc:grpc_channel",
-        "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
-        "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
-        "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
-        "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
-        "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
-        "//tensorflow/core/profiler/lib:profiler_lib",
         "//tensorflow/core/profiler/lib:profiler_session",
     ],
 )

==== tensorflow/c/eager/c_api.cc ====

@@ -31,6 +31,7 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "tensorflow/c/c_api.h"
 #include "tensorflow/c/c_api_internal.h"
+#include "tensorflow/c/eager/tensor_handle_interface.h"
 #include "tensorflow/c/tf_tensor_internal.h"
 #include "tensorflow/c/eager/c_api_experimental.h"
 #include "tensorflow/c/eager/c_api_internal.h"
@@ -43,6 +44,7 @@ limitations under the License.
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/platform.h"  // NOLINT
 #include "tensorflow/core/protobuf/error_codes.pb.h"
+#include "tensorflow/core/protobuf/device_filters.pb.h"
 #include "tensorflow/core/util/device_name_utils.h"
 #ifdef TENSORFLOW_EAGER_USE_XLA
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -81,6 +83,7 @@ limitations under the License.
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/casts.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/thread_annotations.h"
@@ -93,10 +96,8 @@ using tensorflow::string;
 namespace {
 
 const tensorflow::OpDef* GetOpDef(TFE_Op* op, TF_Status* status) {
-  if (op->inference_ctx) {
-    return op->inference_ctx->op_def;
-  }
-  const tensorflow::OpDef* op_def;
+  const tensorflow::OpDef* op_def = op->operation.OpDef();
+  if (op_def) return op_def;
   status->status =
       tensorflow::OpDefForOp(op->operation.Name().c_str(), &op_def);
   return op_def;
@@ -265,9 +266,9 @@ tensorflow::Status GetReplacedFromExistingWorkers(
 }
 
 tensorflow::Status CreateRemoteContexts(
-    const std::vector<string>& remote_workers, tensorflow::uint64 context_id,
-    tensorflow::uint64 context_view_id, int keep_alive_secs,
-    const tensorflow::ServerDef& server_def,
+    TFE_Context* ctx, const std::vector<string>& remote_workers,
+    tensorflow::uint64 context_id, tensorflow::uint64 context_view_id,
+    int keep_alive_secs, const tensorflow::ServerDef& server_def,
     tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
     const bool lazy_copy_remote_function_inputs,
     const tensorflow::eager::CreateContextRequest& base_request) {
@@ -296,7 +297,7 @@ tensorflow::Status CreateRemoteContexts(
       continue;
     }
 
-    tensorflow::eager::CreateContextRequest request(base_request);
+    tensorflow::eager::CreateContextRequest request;
     tensorflow::eager::CreateContextResponse* response =
         new tensorflow::eager::CreateContextResponse();
     request.set_context_id(context_id);
@@ -304,6 +305,21 @@ tensorflow::Status CreateRemoteContexts(
     *request.mutable_server_def() = server_def;
     request.mutable_server_def()->set_job_name(parsed_name.job);
     request.mutable_server_def()->set_task_index(parsed_name.task);
+    request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
+        server_def.default_session_config());
+
+    std::vector<bool> filtered_device_mask;
+    ctx->context->FilterDevicesForRemoteWorkers(
+        remote_worker, base_request.cluster_device_attributes(),
+        &filtered_device_mask);
+    DCHECK_EQ(filtered_device_mask.size(),
+              base_request.cluster_device_attributes_size());
+    for (int i = 0; i < filtered_device_mask.size(); i++) {
+      if (filtered_device_mask[i]) {
+        const auto& da = base_request.cluster_device_attributes(i);
+        *request.add_cluster_device_attributes() = da;
+      }
+    }
     request.set_async(async);
     request.set_keep_alive_secs(keep_alive_secs);
     request.set_lazy_copy_remote_function_inputs(
@@ -325,13 +341,34 @@ tensorflow::Status CreateRemoteContexts(
 }
 
 tensorflow::Status UpdateRemoteContexts(
-    const std::vector<string>& remote_workers, tensorflow::uint64 context_id,
+    TFE_Context* ctx, const std::vector<string>& remote_workers,
+    const std::vector<string>& added_workers,
+    const std::vector<string>& removed_workers, tensorflow::uint64 context_id,
     tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
     tensorflow::eager::EagerClientCache* remote_eager_workers,
    const tensorflow::eager::CreateContextRequest& base_request) {
   int num_remote_workers = remote_workers.size();
   tensorflow::BlockingCounter counter(num_remote_workers);
   std::vector<tensorflow::Status> statuses(num_remote_workers);
 
+  int cluster_device_count = base_request.cluster_device_attributes_size();
+  std::unordered_set<string> added_or_removed(added_workers.begin(),
+                                              added_workers.end());
+  std::copy(removed_workers.begin(), removed_workers.end(),
+            std::inserter(added_or_removed, added_or_removed.end()));
+  // Whether each device is in the updated (added or removed) workers
+  std::vector<bool> device_added_or_removed(cluster_device_count);
+  for (int i = 0; i < base_request.cluster_device_attributes_size(); i++) {
+    const auto& da = base_request.cluster_device_attributes().at(i);
+    tensorflow::DeviceNameUtils::ParsedName pn;
+    tensorflow::DeviceNameUtils::ParseFullName(da.name(), &pn);
+    string task_name;
+    tensorflow::DeviceNameUtils::GetTaskName(pn, &task_name);
+    if (added_or_removed.find(task_name) != added_or_removed.end()) {
+      device_added_or_removed[i] = true;
+    }
+  }
+
   for (int i = 0; i < num_remote_workers; i++) {
     const string& remote_worker = remote_workers[i];
     tensorflow::DeviceNameUtils::ParsedName parsed_name;
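
The full_update_request computation in the next hunk reduces two boolean masks with an elementwise AND followed by an OR-reduction. A standalone sketch of that idiom, with made-up mask values:

// Sketch only: true iff any device matching this worker's filters belongs
// to an added or removed worker.
#include <algorithm>
#include <functional>
#include <numeric>
#include <vector>

bool AnyFilteredDeviceChanged(const std::vector<bool>& filtered_device_mask,
                              const std::vector<bool>& device_added_or_removed) {
  std::vector<bool> both(filtered_device_mask.size());
  std::transform(device_added_or_removed.begin(), device_added_or_removed.end(),
                 filtered_device_mask.begin(), both.begin(),
                 std::logical_and<bool>());
  return std::accumulate(both.begin(), both.end(), false,
                         std::logical_or<bool>());
}
// AnyFilteredDeviceChanged({true, false, true}, {false, false, true}) == true
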
@@ -354,17 +391,42 @@ tensorflow::Status UpdateRemoteContexts(
       continue;
     }
 
+    std::vector<bool> filtered_device_mask;
+    ctx->context->FilterDevicesForRemoteWorkers(
+        remote_worker, base_request.cluster_device_attributes(),
+        &filtered_device_mask);
+    DCHECK_EQ(filtered_device_mask.size(), cluster_device_count);
+
+    // If any of the devices that match the device filters are in the set of
+    // added or removed workers, we must send a complete UpdateContextRequest.
+    // Otherwise, only send a simple request to increment context view ID.
+    std::vector<bool> added_or_removed_filtered_devices(cluster_device_count);
+    std::transform(device_added_or_removed.begin(),
+                   device_added_or_removed.end(), filtered_device_mask.begin(),
+                   added_or_removed_filtered_devices.begin(),
+                   std::logical_and<bool>());
+    const bool full_update_request =
+        std::accumulate(added_or_removed_filtered_devices.begin(),
+                        added_or_removed_filtered_devices.end(), false,
+                        std::logical_or<bool>());
+
     tensorflow::eager::UpdateContextRequest request;
     auto* response = new tensorflow::eager::UpdateContextResponse();
-    *request.mutable_server_def() = server_def;
-    request.mutable_server_def()->set_job_name(parsed_name.job);
-    request.mutable_server_def()->set_task_index(parsed_name.task);
-    for (const auto& da : base_request.cluster_device_attributes()) {
-      *request.add_cluster_device_attributes() = da;
-    }
     request.set_context_id(context_id);
     request.set_context_view_id(context_view_id);
+    if (full_update_request) {
+      *request.mutable_server_def() = server_def;
+      request.mutable_server_def()->set_job_name(parsed_name.job);
+      request.mutable_server_def()->set_task_index(parsed_name.task);
+      request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
+          server_def.default_session_config());
+      for (int i = 0; i < cluster_device_count; i++) {
+        if (filtered_device_mask[i]) {
+          const auto& da = base_request.cluster_device_attributes(i);
+          *request.add_cluster_device_attributes() = da;
+        }
+      }
+    }
 
     eager_client->UpdateContextAsync(
         &request, response,
@@ -409,6 +471,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
 
   // New server created for new server_def. Unused if updating server_def.
   std::unique_ptr<tensorflow::ServerInterface> new_server;
+  tensorflow::EagerContext* context = ctx->context;
   tensorflow::GrpcServer* grpc_server;
   if (reset_context) {
     LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
@@ -416,26 +479,25 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
     LOG_AND_RETURN_IF_ERROR(
         ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
   } else {
-    LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(
-        ctx->context->GetServer(), worker_name, &curr_remote_workers));
+    LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name,
+                                              &curr_remote_workers));
     // No need to check the cast here, since `ListRemoteWorkers` already checks
     // if the server is a GRPC server or not.
-    grpc_server =
-        dynamic_cast<tensorflow::GrpcServer*>(ctx->context->GetServer());
+    grpc_server = dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
     LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
     LOG_AND_RETURN_IF_ERROR(
         ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
   }
 
-  tensorflow::uint64 context_id = ctx->context->GetContextId();
-  tensorflow::uint64 context_view_id = ctx->context->GetContextViewId();
+  tensorflow::uint64 context_id = context->GetContextId();
+  tensorflow::uint64 context_view_id = context->GetContextViewId();
   if (reset_context) {
     context_id = tensorflow::EagerContext::NewContextId();
     context_view_id = 0;
     // Make master eager context accessible by local eager service, which might
     // receive send tensor requests from remote workers.
-    LOG_AND_RETURN_IF_ERROR(grpc_server->AddMasterEagerContextToEagerService(
-        context_id, ctx->context));
+    LOG_AND_RETURN_IF_ERROR(
+        grpc_server->AddMasterEagerContextToEagerService(context_id, context));
   }
 
   std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
@@ -464,11 +526,11 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
         &new_remote_device_mgr));
     remote_device_mgr = new_remote_device_mgr.get();
   } else {
-    ctx->context->ClearCaches();
+    context->ClearCachesAndDefaultExecutor();
    // TODO(b/143914772): Potential memory leak if rendezvous has pending
    // tensors for removed / replaced workers.
-    remote_device_mgr = ctx->context->GetOwnedRemoteDeviceMgr();
+    remote_device_mgr = context->GetOwnedRemoteDeviceMgr();
     if (remote_device_mgr == nullptr) {
       LOG_AND_RETURN_IF_ERROR(tensorflow::errors::InvalidArgument(
           "Updating context with an invalid set of remote devices."));
@@ -479,8 +541,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
                            &added_workers, &removed_workers,
                            &existing_workers);
     LOG_AND_RETURN_IF_ERROR(GetReplacedFromExistingWorkers(
-        &existing_workers, context_id, ctx->context->GetContextViewId(),
-        server_def, remote_eager_workers.get(), &replaced_workers));
+        &existing_workers, context_id, context->GetContextViewId(), server_def,
+        remote_eager_workers.get(), &replaced_workers));
     if (VLOG_IS_ON(1)) {
       VLOG(1) << "Updating cluster with following changes";
       for (const string& w : added_workers) VLOG(1) << "  Added worker " << w;
@@ -516,7 +578,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
   grpc_server->worker_env()->device_mgr->ListDeviceAttributes(
       &local_device_attributes);
 
-  // This request make sure that we can create Rendevzous properly between
+  // This request make sure that we can create Rendezvous properly between
   // Local and Remote context.
   tensorflow::eager::CreateContextRequest base_request;
   for (const auto& da : cluster_device_attributes) {
@@ -525,18 +587,14 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
   for (const auto& da : local_device_attributes) {
     *base_request.add_cluster_device_attributes() = da;
   }
-  base_request.mutable_server_def()
-      ->mutable_default_session_config()
-      ->MergeFrom(server_def.default_session_config());
 
   // Initialize remote eager workers.
   // TODO(b/138847548) Create remote eager contexts in async mode by default.
   if (reset_context) {
     LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
-        remote_workers, context_id, context_view_id, keep_alive_secs,
-        server_def, remote_eager_workers.get(),
-        ctx->context->Executor().Async(),
-        ctx->context->LazyCopyFunctionRemoteInputs(), base_request));
+        ctx, remote_workers, context_id, context_view_id, keep_alive_secs,
+        server_def, remote_eager_workers.get(), context->Executor().Async(),
+        context->LazyCopyFunctionRemoteInputs(), base_request));
   } else {
     // The master's context_view_id will be incremented by one
     // the UpdateRemoteMaster call later. We want all new workers and
@@ -544,10 +602,9 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
     // we must set their context_view_id to the existing master's
     // context_view_id + 1.
     LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
-        added_workers, context_id, context_view_id + 1, keep_alive_secs,
-        server_def, remote_eager_workers.get(),
-        ctx->context->Executor().Async(),
-        ctx->context->LazyCopyFunctionRemoteInputs(), base_request));
+        ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs,
+        server_def, remote_eager_workers.get(), context->Executor().Async(),
+        context->LazyCopyFunctionRemoteInputs(), base_request));
     if (!existing_workers.empty()) {
       if (VLOG_IS_ON(1)) {
         for (const string& w : existing_workers) {
@@ -555,8 +612,9 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
         }
       }
       LOG_AND_RETURN_IF_ERROR(UpdateRemoteContexts(
-          existing_workers, context_id, context_view_id + 1, server_def,
-          remote_eager_workers.get(), base_request));
+          ctx, existing_workers, added_workers, removed_workers, context_id,
+          context_view_id + 1, server_def, remote_eager_workers.get(),
+          base_request));
     }
   }
@@ -578,12 +636,12 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
   TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
 
   tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
-      tensorflow::eager::CreateClusterFLR(context_id, ctx->context,
+      tensorflow::eager::CreateClusterFLR(context_id, context,
                                           worker_session.get());
   auto remote_mgr = absl::make_unique<tensorflow::eager::RemoteMgr>(
-      /*is_master=*/true, ctx->context);
+      /*is_master=*/true, context);
 
-  LOG_AND_RETURN_IF_ERROR(ctx->context->InitializeRemoteMaster(
+  LOG_AND_RETURN_IF_ERROR(context->InitializeRemoteMaster(
      std::move(new_server), grpc_server->worker_env(), worker_session,
      std::move(remote_eager_workers), std::move(new_remote_device_mgr),
      remote_workers, context_id, r, device_mgr, keep_alive_secs, cluster_flr,
@@ -601,9 +659,9 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
       grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
           session_name, &worker_session));
   tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
-      tensorflow::eager::CreateClusterFLR(context_id, ctx->context,
+      tensorflow::eager::CreateClusterFLR(context_id, context,
                                           worker_session.get());
-  LOG_AND_RETURN_IF_ERROR(ctx->context->UpdateRemoteMaster(
+  LOG_AND_RETURN_IF_ERROR(context->UpdateRemoteMaster(
      grpc_server->worker_env(), std::move(remote_eager_workers),
      added_workers, removed_workers, context_id, r, device_mgr,
      keep_alive_secs, cluster_flr));
@@ -614,77 +672,6 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
 }
 #endif  // !IS_MOBILE_PLATFORM
-tensorflow::Status OpInferSingleInputAttrs(TFE_Op* op,
-                                           TFE_TensorHandle* input) {
-  TFE_OpInferenceContext* ictx = op->inference_ctx.get();
-  const auto& input_def = ictx->op_def->input_arg(ictx->input_arg_idx++);
-  if (!input_def.number_attr().empty() || !input_def.type_list_attr().empty()) {
-    // Some clients that are still setting their input attributes manually are
-    // adding input list to their op by calling `TFE_OpAddInput` for each of
-    // its elements instead of calling `TFE_OpAddInputList`. When this happens,
-    // we cannot detect the end of such list, thus lose track of the input
-    // arguments in the op definition. To guarantee backward compatibility with
-    // those clients, disable automatic inference in this case.
-    op->inference_ctx.reset(nullptr);
-    return tensorflow::Status::OK();
-  }
-  const std::string& type_attr = input_def.type_attr();
-  if (!type_attr.empty() && ictx->attrs.find(type_attr) == ictx->attrs.end()) {
-    op->operation.MutableAttrs()->Set(type_attr, input->handle->dtype);
-    ictx->attrs.insert(type_attr);
-  }
-  return tensorflow::Status::OK();
-}
-
-void OpInferSingleTypeInputListAttrs(TFE_Op* op,
-                                     const tensorflow::OpDef::ArgDef& input_def,
-                                     TFE_TensorHandle** inputs,
-                                     int num_inputs) {
-  TFE_OpInferenceContext* ictx = op->inference_ctx.get();
-  if (ictx->attrs.find(input_def.number_attr()) == ictx->attrs.end()) {
-    op->operation.MutableAttrs()->Set(input_def.number_attr(), num_inputs);
-    ictx->attrs.insert(input_def.number_attr());
-  }
-  if (ictx->attrs.find(input_def.type_attr()) == ictx->attrs.end()) {
-    op->operation.MutableAttrs()->Set(input_def.type_attr(),
-                                      inputs[0]->handle->dtype);
-    ictx->attrs.insert(input_def.type_attr());
-  }
-}
-
-void OpInferMixedTypeInputListAttrs(TFE_Op* op,
-                                    const tensorflow::OpDef::ArgDef& input_def,
-                                    TFE_TensorHandle** inputs, int num_inputs) {
-  TFE_OpInferenceContext* ictx = op->inference_ctx.get();
-  if (ictx->attrs.find(input_def.type_list_attr()) == ictx->attrs.end()) {
-    std::unique_ptr<tensorflow::DataType[]> dtypes(
-        new tensorflow::DataType[num_inputs]);
-    for (int i = 0; i < num_inputs; ++i) {
-      dtypes[i] = inputs[i]->handle->dtype;
-    }
-    op->operation.MutableAttrs()->Set(
-        input_def.type_list_attr(),
-        tensorflow::gtl::ArraySlice<const tensorflow::DataType>(dtypes.get(),
-                                                                num_inputs));
-    ictx->attrs.insert(input_def.type_list_attr());
-  }
-}
-
-tensorflow::Status OpInferInputListAttrs(TFE_Op* op, TFE_TensorHandle** inputs,
-                                         int num_inputs) {
-  TFE_OpInferenceContext* ictx = op->inference_ctx.get();
-  const auto& input_def = ictx->op_def->input_arg(ictx->input_arg_idx++);
-  if (!input_def.type_list_attr().empty()) {
-    OpInferMixedTypeInputListAttrs(op, input_def, inputs, num_inputs);
-  } else if (!input_def.type_attr().empty() &&
-             !input_def.number_attr().empty()) {
-    OpInferSingleTypeInputListAttrs(op, input_def, inputs, num_inputs);
-  } else {
-    return tensorflow::errors::InvalidArgument("Invalid input list definition");
-  }
-  return tensorflow::Status::OK();
-}
-
 }  // namespace
 
 extern "C" {
@@ -720,12 +707,14 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
   tensorflow::Rendezvous* r =
       new tensorflow::IntraProcessRendezvous(device_mgr.get());
 
-  return new TFE_Context(opts->session_options.options,
-                         opts->device_placement_policy, opts->mirroring_policy,
-                         opts->async, opts->lazy_remote_inputs_copy,
-                         device_mgr.release(),
-                         /*device_mgr_owned*/ true, r,
-                         tensorflow::GetDefaultCustomKernelCreator());
+  return new TFE_Context{new tensorflow::EagerContext(
+      opts->session_options.options,
+      static_cast<tensorflow::ContextDevicePlacementPolicy>(
+          opts->device_placement_policy),
+      static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
+      opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
+      /*device_mgr_owned*/ true, r,
+      tensorflow::GetDefaultCustomKernelCreator())};
 }
 
 TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
@@ -736,25 +725,33 @@ TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
   tensorflow::Rendezvous* r =
       new tensorflow::IntraProcessRendezvous(device_mgr);
 
-  return new TFE_Context(opts->session_options.options,
-                         opts->device_placement_policy, opts->mirroring_policy,
-                         opts->async, opts->lazy_remote_inputs_copy, device_mgr,
-                         /*device_mgr_owned*/ false, r,
-                         tensorflow::GetDefaultCustomKernelCreator());
+  return new TFE_Context{new tensorflow::EagerContext(
+      opts->session_options.options,
+      static_cast<tensorflow::ContextDevicePlacementPolicy>(
+          opts->device_placement_policy),
+      static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
+      opts->async, opts->lazy_remote_inputs_copy, device_mgr,
+      /*device_mgr_owned*/ false, r,
+      tensorflow::GetDefaultCustomKernelCreator())};
 }
 
-void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; }
+void TFE_DeleteContext(TFE_Context* ctx) {
+  // context->RefCountIsOne() should be true here.
+  // TODO(iga): Remove EagerContext refcounting.
+  ctx->context->Unref();
+
+  delete ctx;
+}
 
 TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
-  TF_DeviceList* list = new TF_DeviceList;
-  ctx->context->local_device_mgr()->ListDeviceAttributes(&list->response);
-  if (ctx->context->remote_device_mgr()) {
-    ctx->context->remote_device_mgr()->ListDeviceAttributes(&list->response);
-  }
-  return list;
+  TF_DeviceList* l = new TF_DeviceList;
+  ctx->context->ListDevices(&l->response);
+  return l;
 }
 
-void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context->ClearCaches(); }
+void TFE_ContextClearCaches(TFE_Context* ctx) {
+  ctx->context->ClearCachesAndThreadExecutors();
+}
 
 // Set server_def on the context, possibly updating it.
 TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
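
For reference, the public contract around context creation and deletion is unchanged by the refactor above; only the internals now allocate and unref a tensorflow::EagerContext. A minimal sketch of the caller-visible lifecycle through the C API (error handling trimmed):

// Sketch only: the create/use/delete cycle looks the same as before.
#include "tensorflow/c/eager/c_api.h"

int main() {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);
  // ... run eager ops against ctx ...
  TFE_DeleteContext(ctx);  // now unrefs the owned EagerContext internally
  TF_DeleteStatus(status);
  return 0;
}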
@@ -772,6 +769,22 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
         "Invalid tensorflow.ServerDef protocol buffer");
     return;
   }
+  if (server_def.has_cluster_device_filters()) {
+    const auto& cdf = server_def.cluster_device_filters();
+    for (const auto& jdf : cdf.jobs()) {
+      const string& remote_prefix = "/job:" + jdf.name() + "/task:";
+      for (const auto& tdf : jdf.tasks()) {
+        const int32_t task_index = tdf.first;
+        std::vector<string> device_filters(tdf.second.device_filters_size());
+        for (int i = 0; i < tdf.second.device_filters_size(); i++) {
+          device_filters[i] = tdf.second.device_filters(i);
+        }
+        const string remote_worker = remote_prefix + std::to_string(task_index);
+        status->status =
+            ctx->context->SetRemoteDeviceFilters(remote_worker, device_filters);
+      }
+    }
+  }
   status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
                                                   ctx, /*reset_context=*/true);
 #endif  // !IS_MOBILE_PLATFORM
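
To exercise this new path, a client populates cluster_device_filters on the ServerDef before handing it to TFE_ContextSetServerDef. A hedged sketch inferred from the parsing code above; the job name, task index, and filter strings are illustrative:

// Sketch only: restrict task 0 of job "worker" to ps devices and its own.
#include <string>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/protobuf/device_filters.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

void SetFilteredServerDef(TFE_Context* ctx, tensorflow::ServerDef server_def,
                          TF_Status* status) {
  auto* job = server_def.mutable_cluster_device_filters()->add_jobs();
  job->set_name("worker");
  auto& task = (*job->mutable_tasks())[0];
  task.add_device_filters("/job:ps");
  task.add_device_filters("/job:worker/task:0");
  const std::string proto = server_def.SerializeAsString();
  TFE_ContextSetServerDef(ctx, /*keep_alive_secs=*/0, proto.data(),
                          proto.size(), status);
}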
@@ -796,6 +809,11 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
     status->status = tensorflow::errors::InvalidArgument(
         "Trying to update a context with invalid context id.");
   }
+  if (server_def.has_cluster_device_filters()) {
+    LOG(WARNING) << "Device filters can only be specified when initializing "
+                    "the cluster. Any changes in device filters are ignored "
+                    "when updating the server def.";
+  }
   // TODO(haoyuzhang): Check server_def compatibility before the update
   status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
                                                   ctx, /*reset_context=*/false);
@@ -810,8 +828,9 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
         "TFE_ContextSetServerDef not supported on mobile");
     return false;
 #else   // !defined(IS_MOBILE_PLATFORM)
+  tensorflow::EagerContext* context = ctx->context;
   tensorflow::GrpcServer* grpc_server =
-      static_cast<tensorflow::GrpcServer*>(ctx->context->GetServer());
+      static_cast<tensorflow::GrpcServer*>(context->GetServer());

   std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
   status->status = grpc_server->master_env()->worker_cache->GetEagerClientCache(
@@ -830,7 +849,7 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
   // Send a rpc request to the worker to check aliveness.
   tensorflow::eager::KeepAliveRequest request;
-  request.set_context_id(ctx->context->GetContextId());
+  request.set_context_id(context->GetContextId());
   tensorflow::eager::KeepAliveResponse response;

   tensorflow::Status keep_alive_status;
@@ -885,108 +904,180 @@ void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
   if (h == nullptr) return;
   tensorflow::profiler::TraceMe activity(
       "TFE_DeleteTensorHandle", tensorflow::profiler::TraceMeLevel::kInfo);
-  VLOG(1) << "Deleting tensor handle " << h << " with internal handle "
-          << h->handle;
-  if (h->handle) {
-    h->handle->Unref();
-  }
   delete h;
 }

+tensorflow::TensorHandleInterface::~TensorHandleInterface() {
+  VLOG(1) << "Deleting tensor handle " << this << " with internal handle "
+          << handle_;
+  if (handle_) {
+    handle_->Unref();
+  }
+}
+
+bool tensorflow::TensorHandleInterface::IsValid(Status* status) const {
+  if (handle_ == nullptr) {
+    *status = tensorflow::errors::InvalidArgument(
+        "The passed in handle is a nullptr");
+    return false;
+  }
+
+  return true;
+}
+
 TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
-  return static_cast<TF_DataType>(h->handle->dtype);
+  return h->handle->DataType();
+}
+
+TF_DataType tensorflow::TensorHandleInterface::DataType() const {
+  return static_cast<TF_DataType>(handle_->dtype);
 }

 int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
-  if (h == nullptr || h->handle == nullptr) {
+  if (h == nullptr) {
     status->status = tensorflow::errors::InvalidArgument(
         "The passed in handle is a nullptr");
     return -1;
   }
+
+  return h->handle->NumDims(&status->status);
+}
+
+int tensorflow::TensorHandleInterface::NumDims(Status* status) const {
+  if (!IsValid(status)) {
+    return -1;
+  }
+
   int result;
-  status->status = h->handle->NumDims(&result);
+  *status = handle_->NumDims(&result);
   return result;
 }

 int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
-  if (h == nullptr || h->handle == nullptr) {
+  if (h == nullptr) {
     status->status = tensorflow::errors::InvalidArgument(
         "The passed in handle is a nullptr");
     return -1;
   }
+
+  return h->handle->NumElements(&status->status);
+}
+
+int64_t tensorflow::TensorHandleInterface::NumElements(Status* status) const {
+  if (!IsValid(status)) {
+    return -1;
+  }
+
   tensorflow::int64 result;
-  status->status = h->handle->NumElements(&result);
+  *status = handle_->NumElements(&result);
   return result;
 }

 int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
                             TF_Status* status) {
-  if (h == nullptr || h->handle == nullptr) {
+  if (h == nullptr) {
     status->status = tensorflow::errors::InvalidArgument(
         "The passed in handle is a nullptr");
     return -1;
   }
+
+  return h->handle->Dim(dim_index, &status->status);
+}
+
+int64_t tensorflow::TensorHandleInterface::Dim(int dim_index,
+                                               Status* status) const {
+  if (!IsValid(status)) {
+    return -1;
+  }
+
   tensorflow::int64 result;
-  status->status = h->handle->Dim(dim_index, &result);
+  *status = handle_->Dim(dim_index, &result);
   return result;
 }
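
Put together, these accessors are how a C API client inspects a handle's shape without resolving the tensor. A minimal sketch (assumes a live TFE_TensorHandle* h and TF_Status* status):

int rank = TFE_TensorHandleNumDims(h, status);
if (TF_GetCode(status) == TF_OK) {
  int64_t total = TFE_TensorHandleNumElements(h, status);
  (void)total;  // total equals the product of all dimension extents
  for (int i = 0; i < rank && TF_GetCode(status) == TF_OK; ++i) {
    int64_t extent = TFE_TensorHandleDim(h, i, status);
    (void)extent;  // e.g. accumulate or log each extent
  }
}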
 const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
-  if (h == nullptr || h->handle == nullptr) {
+  if (h == nullptr) {
     status->status = tensorflow::errors::InvalidArgument(
         "The passed in handle is a nullptr");
     return nullptr;
   }
-  tensorflow::Device* d = h->handle->op_device();
+  return h->handle->DeviceName(&status->status);
+}
+
+const char* tensorflow::TensorHandleInterface::DeviceName(
+    Status* status) const {
+  if (!IsValid(status)) {
+    return nullptr;
+  }
+
+  tensorflow::Device* d = handle_->op_device();
   return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
                         : d->name().c_str();
 }

 const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
                                               TF_Status* status) {
-  if (h == nullptr || h->handle == nullptr) {
+  if (h == nullptr) {
     status->status = tensorflow::errors::InvalidArgument(
         "The passed in handle is a nullptr");
     return nullptr;
   }
-  tensorflow::Device* d = h->handle->device();
+  return h->handle->BackingDeviceName(&status->status);
+}
+
+const char* tensorflow::TensorHandleInterface::BackingDeviceName(
+    Status* status) const {
+  if (!IsValid(status)) {
+    return nullptr;
+  }
+
+  tensorflow::Device* d = handle_->device();
   return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
                         : d->name().c_str();
 }
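
Note the distinction the two accessors encode: one reports the device the producing op was placed on (op_device), the other the device that actually holds the tensor's memory. A short sketch of querying both; the divergence shown in the comment is an illustrative assumption:

// For a handle produced by an op placed on a GPU whose output actually
// lives in host memory (e.g. a shape tensor), the two names can differ:
const char* op_device = TFE_TensorHandleDeviceName(h, status);
const char* backing_device = TFE_TensorHandleBackingDeviceName(h, status);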
 TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
     TFE_TensorHandle* h, TF_Status* status) {
-  if (h == nullptr || h->handle == nullptr) {
+  if (h == nullptr || !h->handle->IsValid(&status->status)) {
     status->status = tensorflow::errors::InvalidArgument(
         "The passed in handle is a nullptr");
     return nullptr;
   }

-  h->handle->Ref();
+  return new TFE_TensorHandle{
+      std::unique_ptr<AbstractTensorHandleInterface>(h->handle->Copy())};
+}

-  return new TFE_TensorHandle(h->handle);
+AbstractTensorHandleInterface* tensorflow::TensorHandleInterface::Copy() {
+  handle_->Ref();
+  return new TensorHandleInterface(handle_);
 }

 TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
-  if (h == nullptr || h->handle == nullptr) {
+  if (h == nullptr) {
     status->status = tensorflow::errors::InvalidArgument(
         "The passed in handle is a nullptr");
     return nullptr;
   }
-  tensorflow::TensorHandle* handle = h->handle;

+  return h->handle->Resolve(&status->status);
+}
+
+TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
+  if (!IsValid(status)) {
+    return nullptr;
+  }
+
   // TODO(agarwal): move this implementation inside TFE_TensorHandle.
-  if (handle->IsRemote()) {
+  if (handle_->IsRemote()) {
     const tensorflow::Tensor* t = nullptr;
     tensorflow::TensorHandle* h_cpu = nullptr;
-    status->status = EagerCopyToDevice(
-        handle, handle->Context(), &handle->Context()->Executor(),
-        handle->Context()->HostCPU(), false, &h_cpu);
-    if (!status->status.ok()) {
+    *status = EagerCopyToDevice(handle_, handle_->Context(),
+                                &handle_->Context()->Executor(),
+                                handle_->Context()->HostCPU(), false, &h_cpu);
+    if (!status->ok()) {
       return nullptr;
     }
-    status->status = h_cpu->Tensor(&t);
-    if (!status->status.ok()) {
+    *status = h_cpu->Tensor(&t);
+    if (!status->ok()) {
       h_cpu->Unref();
       return nullptr;
     }
@@ -995,28 +1086,30 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
     return retval;
   } else {
     tensorflow::Tensor tensor;
-    if (IsCPU(handle->device())) {
+    if (IsCPU(handle_->device())) {
       const tensorflow::Tensor* src = nullptr;
-      status->status = handle->Tensor(&src);
-      if (!status->status.ok()) return nullptr;
+      *status = handle_->Tensor(&src);
+      if (!status->ok()) return nullptr;
       tensor = *src;
     } else {
-      tensorflow::EagerContext* ctx = handle->Context();
+      tensorflow::EagerContext* ctx = handle_->Context();
       CHECK_NE(ctx, nullptr);
-      status->status = h->handle->CopyToDevice(ctx, ctx->HostCPU(), &tensor);
-      if (!status->status.ok()) return nullptr;
+      *status = handle_->CopyToDevice(*ctx, ctx->HostCPU(), &tensor);
+      if (!status->ok()) return nullptr;
     }
     return tensorflow::TF_TensorFromTensor(tensor, status);
   }
 }
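
From the client's side, Resolve is what turns a (possibly remote or device-resident) handle into host-readable bytes. A minimal usage sketch, assuming a live h and status:

TF_Tensor* t = TFE_TensorHandleResolve(h, status);
if (TF_GetCode(status) == TF_OK) {
  void* data = TF_TensorData(t);       // host-accessible buffer
  size_t nbytes = TF_TensorByteSize(t);
  (void)data;
  (void)nbytes;  // ... read the buffer ...
  TF_DeleteTensor(t);
}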
 void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
-  if (h == nullptr || h->handle == nullptr) {
+  if (h == nullptr || !h->handle->IsValid(&status->status)) {
     status->status = tensorflow::errors::InvalidArgument(
         "The passed in handle is a nullptr");
     return nullptr;
   }
-  tensorflow::TensorHandle* handle = h->handle;
+  tensorflow::TensorHandle* handle =
+      tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
+          ->Handle();

   if (handle->IsRemote()) {
     status->status = tensorflow::errors::InvalidArgument(
@@ -1045,7 +1138,8 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
     void (*deallocator)(void* data, size_t len, void* arg),
     void* deallocator_arg, TF_Status* status) {
   tensorflow::Device* device;
-  status->status = ctx->context->FindDeviceFromName(device_name, &device);
+  tensorflow::EagerContext* context = ctx->context;
+  status->status = context->FindDeviceFromName(device_name, &device);
   if (!status->status.ok()) {
     deallocator(data, len, deallocator_arg);
     return nullptr;
@@ -1073,11 +1167,12 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
   buf->Unref();
   tensorflow::TensorHandle* ret_handle;
   status->status = tensorflow::TensorHandle::CreateLocalHandle(
-      t, device, ctx->context, &ret_handle);
+      t, device, context, &ret_handle);
   if (!status->status.ok()) {
     return nullptr;
   }
-  return new TFE_TensorHandle(ret_handle);
+  return new TFE_TensorHandle{
+      std::make_unique<tensorflow::TensorHandleInterface>(ret_handle)};
 }
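
A hedged sketch of wrapping caller-owned memory via this entry point; the buffer, shape, and no-op deallocator are illustrative assumptions (real callers must keep the memory alive until the deallocator runs):

static void NoOpDeallocator(void* data, size_t len, void* arg) {}

static float buffer[4] = {1.f, 2.f, 3.f, 4.f};
int64_t dims[1] = {4};
TFE_TensorHandle* wrapped = TFE_NewTensorHandleFromDeviceMemory(
    ctx, "/job:localhost/replica:0/task:0/device:CPU:0", TF_FLOAT, dims,
    /*num_dims=*/1, buffer, sizeof(buffer), NoOpDeallocator,
    /*deallocator_arg=*/nullptr, status);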
 // This function will block till the operation that produces `h` has
@@ -1085,12 +1180,14 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
 // bytes of the memory pointed to by the device pointer returned above.
 size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
                                         TF_Status* status) {
-  if (h == nullptr || h->handle == nullptr) {
+  if (h == nullptr || !h->handle->IsValid(&status->status)) {
     status->status = tensorflow::errors::InvalidArgument(
         "The passed in handle is a nullptr");
     return 0;
   }
-  tensorflow::TensorHandle* handle = h->handle;
+  tensorflow::TensorHandle* handle =
+      tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
+          ->Handle();

   if (handle->IsRemote()) {
     status->status = tensorflow::errors::InvalidArgument(
@@ -1108,8 +1205,14 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,

 TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
                   TF_Status* status) {
-  return NewOrResetOp(ctx, op_or_function_name, nullptr, status,
-                      /* op_to_reset= */ nullptr);
+  std::unique_ptr<TFE_Op> new_op(
+      new TFE_Op{tensorflow::EagerOperation(ctx->context)});
+  status->status =
+      new_op->operation.Reset(op_or_function_name, nullptr, false, nullptr);
+  if (!status->status.ok()) {
+    new_op.reset();
+  }
+  return new_op.release();
 }

 void TFE_DeleteOp(TFE_Op* op) { delete op; }
@@ -1120,7 +1223,7 @@ void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {

 const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
   tensorflow::Device* device = (op->operation.Device() == nullptr)
-                                   ? op->operation.EagerContext()->HostCPU()
+                                   ? op->operation.EagerContext().HostCPU()
                                    : op->operation.Device();
   return device->name().c_str();
 }
@@ -1134,20 +1237,23 @@ void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
 }

 void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
-  op->operation.AddInput(input->handle);
-  if (op->inference_ctx) {
-    status->status = OpInferSingleInputAttrs(op, input);
-  }
+  tensorflow::TensorHandle* h =
+      tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
+          input->handle.get())
+          ->Handle();
+  op->operation.AddInput(h);
+  status->status = op->operation.MaybeInferSingleInputAttrs(h);
 }

 void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
                         TF_Status* status) {
   for (int i = 0; i < num_inputs; ++i) {
-    op->operation.AddInput(inputs[i]->handle);
-  }
-  if (op->inference_ctx) {
-    status->status = OpInferInputListAttrs(op, inputs, num_inputs);
+    op->operation.AddInput(
+        tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
+            inputs[i]->handle.get())
+            ->Handle());
   }
+  status->status = op->operation.InferInputListAttrs(num_inputs);
 }

 TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
@@ -1380,15 +1486,16 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,

 void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
                  TF_Status* status) {
-  VLOG(1) << "Calling TFE_Execute() on op " << op;
   absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
+  VLOG(1) << "Calling TFE_Execute() on op " << op;
   status->status = tensorflow::EagerExecute(&op->operation,
                                             handle_retvals.data(), num_retvals);
   if (!status->status.ok()) {
     return;
   }
   for (int i = 0; i < *num_retvals; ++i) {
-    retvals[i] = new TFE_TensorHandle{
+    retvals[i] = new TFE_TensorHandle{
+        std::make_unique<tensorflow::TensorHandleInterface>(handle_retvals[i])};
   }
 }
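
The op-building and execution entry points above combine as follows; a minimal sketch assuming two existing handles a and b (say 2x2 float matrices):

TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);
TFE_OpAddInput(matmul, a, status);  // also triggers single-input attr inference
TFE_OpAddInput(matmul, b, status);
TFE_OpSetAttrBool(matmul, "transpose_a", 0);

TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, retvals, &num_retvals, status);
TFE_DeleteOp(matmul);
// retvals[0] now wraps the result in a TensorHandleInterface.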
@@ -1398,15 +1505,18 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
                                                TF_Status* status) {
   tensorflow::TensorHandle* handle = nullptr;
   tensorflow::Device* device;
-  status->status = ctx->context->FindDeviceFromName(device_name, &device);
+  tensorflow::EagerContext* context = ctx->context;
+  status->status = context->FindDeviceFromName(device_name, &device);
   if (!status->status.ok()) {
     return nullptr;
   }
-  status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context,
-                                                 &ctx->context->Executor(),
-                                                 device, false, &handle);
+  status->status = tensorflow::EagerCopyToDevice(
+      tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
+          ->Handle(),
+      context, &context->Executor(), device, false, &handle);
   if (status->status.ok()) {
-    return new TFE_TensorHandle(handle);
+    return new TFE_TensorHandle{
+        std::make_unique<tensorflow::TensorHandleInterface>(handle)};
   }
   return nullptr;
 }
@@ -1454,11 +1564,12 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,

 void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
                                   TF_Status* status) {
-  status->status = ctx->context->Executor().WaitForAllPendingNodes();
+  tensorflow::EagerContext* context = ctx->context;
+  status->status = context->Executor().WaitForAllPendingNodes();
   if (!status->status.ok()) return;
-  tensorflow::mutex_lock ml(*ctx->context->MetadataMu());
-  status->status = MessageToBuffer(*ctx->context->RunMetadataProto(), buf);
-  ctx->context->ClearRunMetadata();
+  tensorflow::mutex_lock ml(*context->MetadataMu());
+  status->status = MessageToBuffer(*context->RunMetadataProto(), buf);
+  context->ClearRunMetadata();
 }
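
A minimal sketch of retrieving the exported metadata on the client side (the caller owns the buffer):

TF_Buffer* buf = TF_NewBuffer();
TFE_ContextExportRunMetadata(ctx, buf, status);
if (TF_GetCode(status) == TF_OK) {
  // buf->data / buf->length now hold a serialized tensorflow.RunMetadata
  // proto, ready to be parsed or written out.
}
TF_DeleteBuffer(buf);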
 namespace {

View File

@@ -206,14 +206,14 @@ typedef struct TFE_TensorDebugInfo TFE_TensorDebugInfo;
 // error and nullptr is returned. This function can block till the operation
 // that produces `handle` has completed.
 TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
-    TFE_TensorHandle* handle, TF_Status* status);
+    TFE_TensorHandle* h, TF_Status* status);

 // Deletes `debug_info`.
 TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo(
     TFE_TensorDebugInfo* debug_info);

 // Returns the number of dimensions used to represent the tensor on its device.
-// The number of dimensions used to reprensent the tensor on device can be
+// The number of dimensions used to represent the tensor on device can be
 // different from the number returned by TFE_TensorHandleNumDims.
 // The return value was current at the time of TFE_TensorDebugInfo creation.
 TF_CAPI_EXPORT extern int TFE_TensorDebugInfoOnDeviceNumDims(

View File

@@ -28,19 +28,22 @@ using tensorflow::string;

 namespace {

-std::vector<int64> TensorShapeAsVector(TFE_TensorHandle* handle,
-                                       TF_Status* status) {
+std::vector<int64> TensorShapeAsVector(const tensorflow::TensorHandle& handle,
+                                       tensorflow::Status* status) {
   std::vector<int64> shape;
-  int rank = TFE_TensorHandleNumDims(handle, status);
-  if (TF_GetCode(status) != TF_OK) {
+  int rank = -1;
+  *status = handle.NumDims(&rank);
+  if (!status->ok()) {
     return shape;
   }
   shape.reserve(rank);
   for (int i = 0; i < rank; ++i) {
-    shape.push_back(TFE_TensorHandleDim(handle, i, status));
-    if (TF_GetCode(status) != TF_OK) {
+    tensorflow::int64 dim;
+    *status = handle.Dim(i, &dim);
+    if (!status->ok()) {
       return shape;
     }
+    shape.push_back(dim);
   }
   return shape;
 }
@@ -50,15 +53,20 @@ std::vector<int64> TensorShapeAsVector(TFE_TensorHandle* handle,
 extern "C" {

 TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
-    TFE_TensorHandle* handle, TF_Status* status) {
+    TFE_TensorHandle* h, TF_Status* status) {
+  return h->handle->TensorDebugInfo(&status->status);
+}
+
+TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
+    Status* status) {
   const tensorflow::Tensor* tensor;
-  status->status = handle->handle->Tensor(&tensor);
-  if (TF_GetCode(status) != TF_OK) {
+  *status = handle_->Tensor(&tensor);
+  if (!status->ok()) {
     return nullptr;
   }

 #ifdef TENSORFLOW_EAGER_USE_XLA
-  tensorflow::Device* device = handle->handle->device();
+  tensorflow::Device* device = handle_->device();

   // If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
   tensorflow::XlaDevice* xla_device =
@@ -67,15 +75,15 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
     tensorflow::XlaDevice::PaddedShapeFn shape_fn =
         xla_device->metadata().padded_shape_fn();
     xla::Shape padded_shape;
-    status->status = shape_fn(*tensor, &padded_shape);
-    if (!status->status.ok()) {
+    *status = shape_fn(*tensor, &padded_shape);
+    if (!status->ok()) {
       return nullptr;
     }
     if (VLOG_IS_ON(3)) {
-      std::vector<int64> shape_to_log = TensorShapeAsVector(handle, status);
-      if (!status->status.ok()) {
+      std::vector<int64> shape_to_log = TensorShapeAsVector(*handle_, status);
+      if (!status->ok()) {
         // Ignore the status here as we are simply logging.
-        status->status = tensorflow::Status::OK();
+        *status = tensorflow::Status::OK();
       } else {
         VLOG(3) << "Fully padded shape of ["
                 << absl::StrJoin(shape_to_log, ", ") << "] is "
@@ -88,7 +96,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
       // Currently, the only case of XlaTensor containing a tuple shape is to
       // represent 64 bit ints, doubles, and complex numbers (we don't support
       // 64bit complex numbers).
-      status->status = tensorflow::errors::InvalidArgument(
+      *status = tensorflow::errors::InvalidArgument(
           "XlaTensors should only contain tuples of size 2. Shape: ",
           padded_shape.DebugString());
       return nullptr;
@@ -100,13 +108,13 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
     const xla::Shape& shape1 =
         xla::ShapeUtil::GetTupleElementShape(padded_shape, 1);
     if (shape0.IsTuple() || shape1.IsTuple()) {
-      status->status = tensorflow::errors::InvalidArgument(
+      *status = tensorflow::errors::InvalidArgument(
           "XlaTensors should not contain nested tuples. Shape: ",
          padded_shape.DebugString());
       return nullptr;
     }
     if (!xla::ShapeUtil::Equal(shape0, shape1)) {
-      status->status = tensorflow::errors::InvalidArgument(
+      *status = tensorflow::errors::InvalidArgument(
          "Subshapes of XlaTensors should be the same. Shape: ",
          padded_shape.DebugString());
       return nullptr;
@@ -131,15 +139,15 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
         dev_dims.push_back(padded_shape.dimensions(dim_index));
       }
     }
-    status->status = tensorflow::Status::OK();
+    *status = tensorflow::Status::OK();
     return new TFE_TensorDebugInfo(dev_dims);
   }
 #endif  // TENSORFLOW_EAGER_USE_XLA

   // If the tensor is not an XLA tensor, the device shape is
   // the same as regular tensor shape.
-  std::vector<int64> dev_dims = TensorShapeAsVector(handle, status);
-  if (TF_GetCode(status) != TF_OK) {
+  std::vector<int64> dev_dims = TensorShapeAsVector(*handle_, status);
+  if (!status->ok()) {
     return nullptr;
   }
   return new TFE_TensorDebugInfo(dev_dims);

View File

@@ -18,22 +18,23 @@ limitations under the License.
 #include "tensorflow/c/c_api.h"
 #include "tensorflow/c/eager/c_api_internal.h"
 #include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/lib/monitoring/counter.h"
 #include "tensorflow/core/lib/monitoring/gauge.h"
 #include "tensorflow/core/lib/monitoring/sampler.h"
 #include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/casts.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/profiler/rpc/client/capture_profile.h"
 #include "tensorflow/core/profiler/rpc/profiler_server.h"

 using tensorflow::string;

-void TFE_OpReset(TFE_Context* ctx, const char* op_or_function_name,
-                 const char* raw_device_name, TF_Status* status,
-                 TFE_Op* op_to_reset) {
+void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
+                 const char* raw_device_name, TF_Status* status) {
   if (op_to_reset) {
-    NewOrResetOp(ctx, op_or_function_name, raw_device_name, status,
-                 op_to_reset);
+    status->status = op_to_reset->operation.Reset(
+        op_or_function_name, raw_device_name, false, nullptr);
   } else {
     TF_SetStatus(status, TF_INVALID_ARGUMENT,
                  "op_to_reset should not be nullptr");
@@ -41,7 +42,9 @@ void TFE_OpReset(TFE_Context* ctx, const char* op_or_function_name,
 }

 void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
-  op->operation.ConsumeInput(h->handle);
+  op->operation.ConsumeInput(
+      tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
+          ->Handle());
 }
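
A hedged sketch of the op-reuse pattern TFE_OpReset enables: instead of deleting and re-creating an op object every iteration, rewind it in place (the op name and loop are illustrative):

TFE_Op* op = TFE_NewOp(ctx, "Identity", status);
// ... add inputs, execute once ...
// Rewind the same object for the next step; passing nullptr for
// raw_device_name keeps the previously parsed device.
TFE_OpReset(op, "Identity", /*raw_device_name=*/nullptr, status);
// ... add fresh inputs, execute again ...
TFE_DeleteOp(op);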
 TFE_Profiler* TFE_NewProfiler() { return new TFE_Profiler(); }
@@ -85,14 +88,14 @@ bool TFE_ProfilerClientStartTracing(const char* service_addr,
                                     int num_tracing_attempts,
                                     TF_Status* status) {
   tensorflow::Status s =
-      tensorflow::profiler::client::ValidateHostPortPair(service_addr);
+      tensorflow::profiler::ValidateHostPortPair(service_addr);
   if (!s.ok()) {
     Set_TF_Status_from_Status(status, s);
     return false;
   }
-  s = tensorflow::profiler::client::StartTracing(
-      service_addr, logdir, worker_list, include_dataset_ops, duration_ms,
-      num_tracing_attempts);
+  s = tensorflow::profiler::Trace(service_addr, logdir, worker_list,
+                                  include_dataset_ops, duration_ms,
+                                  num_tracing_attempts);
   tensorflow::Set_TF_Status_from_Status(status, s);
   return s.ok();
 }
@@ -101,14 +104,14 @@ void TFE_ProfilerClientMonitor(const char* service_addr, int duration_ms,
                                int monitoring_level, bool display_timestamp,
                                TF_Buffer* result, TF_Status* status) {
   tensorflow::Status s =
-      tensorflow::profiler::client::ValidateHostPortPair(service_addr);
+      tensorflow::profiler::ValidateHostPortPair(service_addr);
   if (!s.ok()) {
     Set_TF_Status_from_Status(status, s);
     return;
   }
   string content;
-  s = tensorflow::profiler::client::Monitor(
-      service_addr, duration_ms, monitoring_level, display_timestamp, &content);
+  s = tensorflow::profiler::Monitor(service_addr, duration_ms, monitoring_level,
+                                    display_timestamp, &content);
   void* data = tensorflow::port::Malloc(content.length());
   content.copy(static_cast<char*>(data), content.length(), 0);
   result->data = data;
@@ -616,3 +619,16 @@ void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {

 TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
   return new TFE_Executor(&ctx->context->Executor());
 }
+
+void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
+  auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
+      ctx->context->HostCPU()->parsed_name());
+  auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
+  void* data = tensorflow::port::Malloc(str.length());
+  str.copy(static_cast<char*>(data), str.length(), 0);
+  buf->data = data;
+  buf->length = str.length();
+  buf->data_deallocator = [](void* data, size_t length) {
+    tensorflow::port::Free(data);
+  };
+}
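
A minimal caller-side sketch; since the function installs its own data_deallocator, deleting the buffer frees the string (the example output is an assumption based on the default local device name):

TF_Buffer* buf = TF_NewBuffer();
TFE_HostAddressSpace(ctx, buf);
std::string address_space(static_cast<const char*>(buf->data), buf->length);
// address_space now looks like "/job:localhost/replica:0/task:0".
TF_DeleteBuffer(buf);  // runs the deallocator installed above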

View File

@@ -29,10 +29,10 @@ extern "C" {
 // and set the device name. It's effectively `TFE_OpSetDevice`, but it is faster
 // than separately calling it because if the existing op has the same
 // `raw_device_name`, it skips parsing and just leaves it as is.
-TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Context* ctx,
+TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Op* op_to_reset,
                                        const char* op_or_function_name,
                                        const char* raw_device_name,
-                                       TF_Status* status, TFE_Op* op_to_reset);
+                                       TF_Status* status);

 TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,
                                               TF_Status* status);
@@ -458,6 +458,11 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
     void (*deallocator)(void* data, size_t len, void* arg),
     void* deallocator_arg, TF_Status* status);

+// Retrieves the address space (i.e. job, replica, task) of the local host and
+// saves it in the buffer.
+TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
+                                                TF_Buffer* buf);
+
 #ifdef __cplusplus
 } /* end extern "C" */
 #endif

View File

@@ -1,66 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/host_info.h"
TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status,
TFE_Op* op_to_reset) {
const char* name = op_or_function_name; // Shorthand
const tensorflow::AttrTypeMap* types;
bool is_function = false;
status->status = tensorflow::AttrTypeMapForOp(name, &types, &is_function);
if (!status->status.ok()) {
return nullptr;
}
if (op_to_reset && op_to_reset->ctx != ctx) {
status->status = tensorflow::errors::Internal(
"Cannot reset a TFE_Op from another TFE_Context");
return nullptr;
}
std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
if (!is_function) {
const tensorflow::OpDef* op_def;
status->status = tensorflow::OpDefForOp(op_or_function_name, &op_def);
if (!status->status.ok()) {
return nullptr;
}
inference_ctx.reset(new TFE_OpInferenceContext(op_def));
} else if (!ctx->context->FindFunctionByName(name)) {
status->status = tensorflow::errors::NotFound(
"'", name,
"' is neither a type of a primitive operation nor a name "
"of a function registered in binary running on ",
tensorflow::port::Hostname(),
". Make sure the operation or function is "
"registered in the binary running in this process.");
return nullptr;
}
if (op_to_reset) {
status->status = op_to_reset->Reset(
name, is_function, types, raw_device_name, std::move(inference_ctx));
return op_to_reset;
}
TFE_Op* new_op =
new TFE_Op(ctx, name, is_function, types, std::move(inference_ctx));
status->status = new_op->operation.SetDeviceName(raw_device_name);
return new_op;
}

View File

@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/c/c_api_internal.h"
 #include "tensorflow/c/eager/c_api.h"
 #include "tensorflow/c/eager/c_api_experimental.h"
+#include "tensorflow/c/eager/tensor_handle_interface.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
 #include "tensorflow/core/common_runtime/eager/context.h"
@@ -62,36 +63,10 @@ struct TFE_ContextOptions {
 };

 struct TFE_Context {
-  TFE_Context(const tensorflow::SessionOptions& opts,
-              TFE_ContextDevicePlacementPolicy default_device_placement_policy,
-              TFE_ContextMirroringPolicy default_mirroring_policy, bool async,
-              const bool lazy_remote_inputs_copy,
-              const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned,
-              tensorflow::Rendezvous* rendezvous,
-              const tensorflow::CustomKernelCreator* custom_kernel_creator)
-      : context(new tensorflow::EagerContext(
-            opts,
-            static_cast<tensorflow::ContextDevicePlacementPolicy>(
-                default_device_placement_policy),
-            static_cast<tensorflow::ContextMirroringPolicy>(
-                default_mirroring_policy),
-            async, lazy_remote_inputs_copy, device_mgr, device_mgr_owned,
-            rendezvous, custom_kernel_creator)) {}
-
-  ~TFE_Context() {
-    // TODO(iga): Add a separate API method to shutdown TFE_Context so that we
-    // don't send RPCs and block in destructor.
-    context->WaitForAndCloseRemoteContexts();
-    // context->RefCountIsOne() should be true here.
-    // TODO(iga): Remove EagerContext refcounting.
-    context->Unref();
-  }
-
   tensorflow::EagerContext* context;
 };

 struct TFE_TensorHandle {
-  explicit TFE_TensorHandle(tensorflow::TensorHandle* h) : handle(h) {}
-
   static TFE_TensorHandle* CreateLocalHandle(const class tensorflow::Tensor& t,
                                              TF_Status* s) {
     tensorflow::TensorHandle* handle;
@@ -99,10 +74,11 @@ struct TFE_TensorHandle {
     if (!s->status.ok()) {
       return nullptr;
     }
-    return new TFE_TensorHandle(handle);
+    return new TFE_TensorHandle{
+        std::make_unique<tensorflow::TensorHandleInterface>(handle)};
   }

-  tensorflow::TensorHandle* handle;
+  std::unique_ptr<AbstractTensorHandleInterface> handle;
 };

 struct TFE_TensorDebugInfo {
@@ -113,46 +89,10 @@ struct TFE_TensorDebugInfo {
   std::vector<tensorflow::int64> dev_dims;
 };

-struct TFE_OpInferenceContext {
-  explicit TFE_OpInferenceContext(const tensorflow::OpDef* op_def)
-      : op_def(op_def) {}
-
-  const tensorflow::OpDef* op_def;  // op definition from protobuf
-  int input_arg_idx = 0;  // arg definition index for the next input to be added
-  tensorflow::gtl::FlatSet<std::string> attrs;  // attributes inferred so far
-};
-
 struct TFE_Op {
-  TFE_Op(TFE_Context* ctx, const char* op, bool is_function,
-         const tensorflow::AttrTypeMap* t,
-         std::unique_ptr<TFE_OpInferenceContext> inference_ctx)
-      : ctx(ctx),
-        operation(ctx->context, op, is_function, t),
-        inference_ctx(std::move(inference_ctx)) {}
-
-  void Clear() {
-    operation.Clear();
-    inference_ctx.reset();
-  }
-
-  tensorflow::Status Reset(const char* op, bool is_function,
-                           const tensorflow::AttrTypeMap* t,
-                           const char* raw_device_name,
-                           std::unique_ptr<TFE_OpInferenceContext> infer_ctx) {
-    inference_ctx = std::move(infer_ctx);
-    return operation.Reset(ctx->context, op, is_function, t, raw_device_name,
-                           nullptr);
-  }
-
-  TFE_Context* ctx;
   tensorflow::EagerOperation operation;
-  std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
 };

-TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
-                     const char* raw_device_name, TF_Status* status,
-                     TFE_Op* op_to_reset = nullptr);
-
 struct TFE_Profiler {
   explicit TFE_Profiler() { profiler = tensorflow::ProfilerSession::Create(); }
View File

@@ -1362,10 +1362,11 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) {
   TFE_TensorHandle* inputs[] = {input1, input2};
   TFE_OpAddInput(concatOp, dim, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  CHECK(concatOp->inference_ctx);
+  CHECK(concatOp->operation.OpDef());
   TFE_OpAddInput(concatOp, inputs[0], status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  EXPECT_FALSE(concatOp->inference_ctx) << "Inference context is still present";
+  EXPECT_FALSE(concatOp->operation.OpDef())
+      << "Inference context is still present";
   TFE_OpAddInput(concatOp, inputs[1], status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
View File

@@ -284,7 +284,7 @@ class ForwardAccumulator {
   // Temporarily push or pop transient state for this accumulator.
   //
   // Allows an accumulator which is currently processing an operation to
-  // temporarily reset its state. Without pushing and poping, accumulators
+  // temporarily reset its state. Without pushing and popping, accumulators
   // ignore operations executed as a direct result of their own jvp
   // computations.
   void PushState() { call_state_.emplace(nullptr, false); }

View File

@@ -0,0 +1,90 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
#define TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
// Abstract interface to a TensorHandle.
//
// A TensorHandle is a management class around a Tensor which may track
// metadata and synchronization.
//
// This allows us to hide concrete implementations of TensorHandle from header
// files. The interface lists the common functionality that must be provided by
// any concrete implementation. However, in cases where the true concrete class
// is needed a static_cast can be applied.
class AbstractTensorHandleInterface {
public:
virtual ~AbstractTensorHandleInterface() {}
// Check if the handle is in a valid initialized state.
virtual bool IsValid(tensorflow::Status* status) const = 0;
// Returns tensor dtype.
virtual TF_DataType DataType() const = 0;
// Returns number of dimensions.
virtual int NumDims(tensorflow::Status* status) const = 0;
// Returns number of elements across all dimensions.
virtual int64_t NumElements(tensorflow::Status* status) const = 0;
// Returns the size of the specified dimension.
virtual int64_t Dim(int dim_index, tensorflow::Status* status) const = 0;
// Returns the device which created the handle.
virtual const char* DeviceName(tensorflow::Status* status) const = 0;
// Returns the device where the tensor was placed.
virtual const char* BackingDeviceName(tensorflow::Status* status) const = 0;
// Returns a tensor for the handle. If tensor is remote, it will be copied.
virtual TF_Tensor* Resolve(tensorflow::Status* status) = 0;
// Returns debug information about the tensor.
virtual TFE_TensorDebugInfo* TensorDebugInfo(tensorflow::Status* status) = 0;
// Return a copy of the handle.
virtual AbstractTensorHandleInterface* Copy() = 0;
};
namespace tensorflow {
class TensorHandleInterface : public AbstractTensorHandleInterface {
public:
explicit TensorHandleInterface(TensorHandle* h) : handle_(h) {}
~TensorHandleInterface() override;
bool IsValid(Status* status) const override;
TF_DataType DataType() const override;
int NumDims(Status* status) const override;
int64_t NumElements(Status* status) const override;
int64_t Dim(int dim_index, Status* status) const override;
const char* DeviceName(Status* status) const override;
const char* BackingDeviceName(Status* status) const override;
TF_Tensor* Resolve(Status* status) override;
TFE_TensorDebugInfo* TensorDebugInfo(Status* status) override;
AbstractTensorHandleInterface* Copy() override;
// TODO(gjn): This is not a very generic interface, but is needed for specific
// use cases.
TensorHandle* Handle() { return handle_; }
private:
TensorHandle* handle_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
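
To see what the abstraction buys, here is a hypothetical second implementation, a test fake that is not part of this commit, sketched purely to show that the C API above only needs the interface, not the concrete TensorHandle:

// Hypothetical fake for tests; every returned value is a fixed assumption.
class FakeTensorHandle : public AbstractTensorHandleInterface {
 public:
  bool IsValid(tensorflow::Status* status) const override { return true; }
  TF_DataType DataType() const override { return TF_FLOAT; }
  int NumDims(tensorflow::Status* status) const override { return 1; }
  int64_t NumElements(tensorflow::Status* status) const override { return 4; }
  int64_t Dim(int dim_index, tensorflow::Status* status) const override {
    return 4;
  }
  const char* DeviceName(tensorflow::Status* status) const override {
    return "/job:localhost/replica:0/task:0/device:CPU:0";
  }
  const char* BackingDeviceName(tensorflow::Status* status) const override {
    return DeviceName(status);
  }
  TF_Tensor* Resolve(tensorflow::Status* status) override { return nullptr; }
  TFE_TensorDebugInfo* TensorDebugInfo(tensorflow::Status* status) override {
    return nullptr;
  }
  AbstractTensorHandleInterface* Copy() override {
    return new FakeTensorHandle(*this);
  }
};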

View File

@@ -18,37 +18,23 @@ cc_library(
     ],
 )

-# Core TensorFlow depends on this, this will be included in main library
-cc_library(
-    name = "filesystem_interface_impl",
-    srcs = ["filesystem_interface.cc"],
-    hdrs = ["filesystem_interface.h"],
-    deps = [
-        ":modular_filesystem",
-        "//tensorflow/c:tf_file_statistics",
-        "//tensorflow/c:tf_status",
-        "//tensorflow/c:tf_status_internal",
-        "//tensorflow/core:ptr_util",
-        "//tensorflow/core/platform:env",
-        "//tensorflow/core/platform:logging",
-        "//tensorflow/core/platform:strcat",
-        "//tensorflow/core/platform:stringpiece",
-    ],
-    alwayslink = 1,
-)
-
 # Core TensorFlow depends on this, will be included in main library
 cc_library(
     name = "modular_filesystem",
-    srcs = ["modular_filesystem.cc"],
+    srcs = [
+        "modular_filesystem.cc",
+        "modular_filesystem_registration.cc",
+        "modular_filesystem_registration.h",
+    ],
     hdrs = ["modular_filesystem.h"],
     deps = [
         ":filesystem_interface",
         "//tensorflow/c:tf_status_helper",
-        "//tensorflow/core:lib",
+        "//tensorflow/c:tf_status_internal",
         "//tensorflow/core:ptr_util",
         "//tensorflow/core/platform:env",
-        "//tensorflow/core/platform:strcat",
+        "//tensorflow/core/platform:errors",
+        "//tensorflow/core/platform:status",
     ],
 )

@@ -63,16 +49,12 @@ tf_cc_test(
         "notap",  # b/139060984, requires implementing modular support for Google filesystem
     ],
     deps = [
-        ":filesystem_interface_impl",
-        "//tensorflow/c:tf_status",
-        "//tensorflow/c:tf_status_internal",
+        ":modular_filesystem",
         "//tensorflow/core:framework_internal",
         "//tensorflow/core/lib/io:path",
         "//tensorflow/core/platform:env",
         "//tensorflow/core/platform:error",
         "//tensorflow/core/platform:stacktrace_handler",
-        "//tensorflow/core/platform:str_util",
-        "//tensorflow/core/platform:strcat",
         "//tensorflow/core/platform:test",
     ],
 )

View File

@@ -1,366 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/util/ptr_util.h"
/// This translation unit is linked in core TensorFlow and provides the
/// functionality needed for plugin registration to check ABI/API compatibility,
/// to ensure required methods are present, to ensure plugins are not allowed to
/// change functionality after being loaded and to register the filesystems
/// provided by a plugin. Consult the header file for more information about
/// how this is achieved.
namespace tensorflow {
namespace {
// Checks if the plugin and core ABI numbers match, filling in `status`.
//
// If the numbers don't match, plugin cannot be loaded.
static bool CheckABIHelper(int pluginABI, int coreABI, StringPiece where,
TF_Status* status) {
if (pluginABI != coreABI) {
TF_SetStatus(
status, TF_FAILED_PRECONDITION,
strings::StrCat("Plugin ABI (", pluginABI, ") for ", where,
" operations doesn't match expected core ABI (",
coreABI, "). Plugin cannot be loaded.")
.c_str());
return false;
}
return true;
}
// Checks if the plugin and core ABI numbers match, for all operations.
//
// If the numbers don't match, plugin cannot be loaded.
//
// Uses the simpler `CheckABIHelper(int, int, StringPiece, TF_Status*)`
static bool CheckABI(
int plugin_filesystem_ops_ABI,
const TF_RandomAccessFileOps* plugin_random_access_file_ops,
int plugin_random_access_file_ops_ABI,
const TF_WritableFileOps* plugin_writable_file_ops,
int plugin_writable_file_ops_ABI,
const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
int plugin_read_only_memory_region_ops_ABI, TF_Status* status) {
if (!CheckABIHelper(plugin_filesystem_ops_ABI, TF_FILESYSTEM_OPS_ABI,
"filesystem", status))
return false;
if (plugin_random_access_file_ops != nullptr &&
!CheckABIHelper(plugin_random_access_file_ops_ABI,
TF_RANDOM_ACCESS_FILE_OPS_ABI, "random access file",
status))
return false;
if (plugin_writable_file_ops != nullptr &&
!CheckABIHelper(plugin_writable_file_ops_ABI, TF_WRITABLE_FILE_OPS_ABI,
"writable file", status))
return false;
if (plugin_read_only_memory_region_ops != nullptr &&
!CheckABIHelper(plugin_read_only_memory_region_ops_ABI,
TF_READ_ONLY_MEMORY_REGION_OPS_ABI,
"read only memory region", status))
return false;
return true;
}
// Checks if the plugin and core API numbers match, logging mismatches.
static void CheckAPIHelper(int plugin_API, int core_API, StringPiece where) {
if (plugin_API != core_API) {
VLOG(0) << "Plugin API (" << plugin_API << ") for " << where
<< " operations doesn't match expected core API (" << core_API
<< "). Plugin will be loaded but functionality might be missing.";
}
}
// Checks if the plugin and core API numbers match, for all operations.
//
// Uses the simpler `CheckAPIHelper(int, int, StringPiece)`.
static void CheckAPI(
int plugin_filesystem_ops_API,
const TF_RandomAccessFileOps* plugin_random_access_file_ops,
int plugin_random_access_file_ops_API,
const TF_WritableFileOps* plugin_writable_file_ops,
int plugin_writable_file_ops_API,
const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
int plugin_read_only_memory_region_ops_API) {
CheckAPIHelper(plugin_filesystem_ops_API, TF_FILESYSTEM_OPS_API,
"filesystem");
if (plugin_random_access_file_ops != nullptr)
CheckAPIHelper(plugin_random_access_file_ops_API,
TF_RANDOM_ACCESS_FILE_OPS_API, "random access file");
if (plugin_writable_file_ops != nullptr)
CheckAPIHelper(plugin_writable_file_ops_API, TF_WRITABLE_FILE_OPS_API,
"writable file");
if (plugin_read_only_memory_region_ops != nullptr)
CheckAPIHelper(plugin_read_only_memory_region_ops_API,
TF_READ_ONLY_MEMORY_REGION_OPS_API,
"read only memory region");
}
// Validates the filesystem operations supplied by the plugin.
static bool ValidateHelper(const TF_FilesystemOps* ops, TF_Status* status) {
if (ops == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without operations");
return false;
}
if (ops->init == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `init` operation");
return false;
}
if (ops->cleanup == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `cleanup` operation");
return false;
}
return true;
}
// Validates the random access file operations supplied by the plugin.
static bool ValidateHelper(const TF_RandomAccessFileOps* ops,
TF_Status* status) {
if (ops == nullptr) {
// We allow filesystems where files can only be written to (from TF code)
return true;
}
if (ops->cleanup == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `cleanup` operation on "
"random access files");
return false;
}
return true;
}
// Validates the writable file operations supplied by the plugin.
static bool ValidateHelper(const TF_WritableFileOps* ops, TF_Status* status) {
if (ops == nullptr) {
// We allow read-only filesystems
return true;
}
if (ops->cleanup == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `cleanup` operation on "
"writable files");
return false;
}
return true;
}
// Validates the read only memory region operations given by the plugin.
static bool ValidateHelper(const TF_ReadOnlyMemoryRegionOps* ops,
TF_Status* status) {
if (ops == nullptr) {
// read only memory region support is always optional
return true;
}
if (ops->cleanup == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `cleanup` operation on "
"read only memory regions");
return false;
}
if (ops->data == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `data` operation on "
"read only memory regions");
return false;
}
if (ops->length == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `length` operation on "
"read only memory regions");
return false;
}
return true;
}
// Validates the operations supplied by the plugin.
//
// Uses the 4 simpler `ValidateHelper(const TF_..., TF_Status*)` to validate
// each individual function table and then checks that the function table for a
// specific file type exists if the plugin offers support for creating that
// type of files.
static bool Validate(
const TF_FilesystemOps* plugin_filesystem_ops,
const TF_RandomAccessFileOps* plugin_random_access_file_ops,
const TF_WritableFileOps* plugin_writable_file_ops,
const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
TF_Status* status) {
if (!ValidateHelper(plugin_filesystem_ops, status)) return false;
if (!ValidateHelper(plugin_random_access_file_ops, status)) return false;
if (!ValidateHelper(plugin_writable_file_ops, status)) return false;
if (!ValidateHelper(plugin_read_only_memory_region_ops, status)) return false;
if (plugin_filesystem_ops->new_random_access_file != nullptr &&
plugin_random_access_file_ops == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Filesystem allows creation of random access files but no "
"operations on them have been supplied.");
return false;
}
if ((plugin_filesystem_ops->new_writable_file != nullptr ||
plugin_filesystem_ops->new_appendable_file != nullptr) &&
plugin_writable_file_ops == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Filesystem allows creation of writable files but no "
"operations on them have been supplied.");
return false;
}
if (plugin_filesystem_ops->new_read_only_memory_region_from_file != nullptr &&
plugin_read_only_memory_region_ops == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Filesystem allows creation of readonly memory regions but no "
"operations on them have been supplied.");
return false;
}
return true;
}
// Copies a function table from plugin memory space to core memory space.
//
// This has three benefits:
// * allows having newer plugins than the current core TensorFlow: the
// additional entries in the plugin's table are just discarded;
// * allows having older plugins than the current core TensorFlow (though
// we are still warning users): the entries that core TensorFlow expects
// but plugins didn't provide will be set to `nullptr` values and core
// TensorFlow will know to not call these on behalf of users;
// * increased security as plugins will not be able to alter function table
// after loading up. Thus, malicious plugins can't alter functionality to
// probe for gadgets inside core TensorFlow. We can even protect the area
// of memory where the copies reside to not allow any more writes to it
// after all copies are created.
template <typename T>
static std::unique_ptr<const T> CopyToCore(const T* plugin_ops,
size_t plugin_size) {
if (plugin_ops == nullptr) return nullptr;
size_t copy_size = sizeof(T);
if (plugin_size < copy_size) {
copy_size = plugin_size;
}
auto core_ops = tensorflow::MakeUnique<T>();
memcpy(const_cast<T*>(core_ops.get()), plugin_ops, copy_size);
return core_ops;
}
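// For illustration (hypothetical types, not part of the API): if a plugin was
// compiled against an older header whose table is a prefix of the current one,
// e.g.
//
//   struct OpsV1 { void (*f1)(); };                // the plugin's view
//   struct OpsV2 { void (*f1)(); void (*f2)(); };  // the view core expects
//
// then `plugin_size == sizeof(OpsV1)` and the `memcpy` above transfers only
// `f1`. Since `MakeUnique<T>()` value-initializes the table, `f2` stays
// `nullptr`, which core TensorFlow treats as "operation not provided".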
} // namespace
} // namespace tensorflow
void RegisterFilesystemPlugin(
int plugin_filesystem_ops_ABI, int plugin_filesystem_ops_API,
size_t plugin_filesystem_ops_size, int plugin_random_access_file_ops_ABI,
int plugin_random_access_file_ops_API,
size_t plugin_random_access_file_ops_size, int plugin_writable_file_ops_ABI,
int plugin_writable_file_ops_API, size_t plugin_writable_file_ops_size,
int plugin_read_only_memory_region_ops_ABI,
int plugin_read_only_memory_region_ops_API,
size_t plugin_read_only_memory_region_ops_size, const char* scheme,
const TF_FilesystemOps* plugin_filesystem_ops,
const TF_RandomAccessFileOps* plugin_random_access_file_ops,
const TF_WritableFileOps* plugin_writable_file_ops,
const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
TF_Status* status) {
if (scheme == nullptr) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"`scheme` argument must not be `nullptr`.");
return;
}
// ABI numbers must match exactly for plugin to be loaded
if (!tensorflow::CheckABI(
plugin_filesystem_ops_ABI, plugin_random_access_file_ops,
plugin_random_access_file_ops_ABI, plugin_writable_file_ops,
plugin_writable_file_ops_ABI, plugin_read_only_memory_region_ops,
plugin_read_only_memory_region_ops_ABI, status)) {
return;
}
// API numbers should match but mismatch doesn't block plugin load
tensorflow::CheckAPI(plugin_filesystem_ops_API, plugin_random_access_file_ops,
plugin_random_access_file_ops_API,
plugin_writable_file_ops, plugin_writable_file_ops_API,
plugin_read_only_memory_region_ops,
plugin_read_only_memory_region_ops_API);
// Plugin can only be loaded if all supplied ops are valid
if (!tensorflow::Validate(plugin_filesystem_ops,
plugin_random_access_file_ops,
plugin_writable_file_ops,
plugin_read_only_memory_region_ops, status)) {
return;
}
// Copy all the function tables to core TensorFlow memory space
auto core_filesystem_ops = tensorflow::CopyToCore<TF_FilesystemOps>(
plugin_filesystem_ops, plugin_filesystem_ops_size);
auto core_random_access_file_ops =
tensorflow::CopyToCore<TF_RandomAccessFileOps>(
plugin_random_access_file_ops, plugin_random_access_file_ops_size);
auto core_writable_file_ops = tensorflow::CopyToCore<TF_WritableFileOps>(
plugin_writable_file_ops, plugin_writable_file_ops_size);
auto core_read_only_memory_region_ops =
tensorflow::CopyToCore<TF_ReadOnlyMemoryRegionOps>(
plugin_read_only_memory_region_ops,
plugin_read_only_memory_region_ops_size);
// Initialize the opaque filesystem structure
auto filesystem = tensorflow::MakeUnique<TF_Filesystem>();
core_filesystem_ops->init(filesystem.get(), status);
if (!status->status.ok()) {
core_filesystem_ops->cleanup(filesystem.get());
return;
}
// Register new filesystem
status->status = tensorflow::Env::Default()->RegisterFileSystem(
scheme, tensorflow::MakeUnique<tensorflow::ModularFileSystem>(
std::move(filesystem), std::move(core_filesystem_ops),
std::move(core_random_access_file_ops),
std::move(core_writable_file_ops),
std::move(core_read_only_memory_region_ops)));
}
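For reference, a plugin written against this (pre-rework) entry point would register itself roughly as in the following sketch; the "myfs" scheme and the `MyInit`/`MyCleanup` functions are hypothetical stand-ins, and a real table must at least provide `init` and `cleanup` to pass the validation above:

// Hypothetical plugin-side sketch for the old registration API.
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"

static void MyInit(TF_Filesystem* fs, TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
}
static void MyCleanup(TF_Filesystem* fs) {}

void TF_InitPlugin(TF_Status* status) {
  TF_FilesystemOps filesystem_ops = {};  // zero entries mean "not provided"
  filesystem_ops.init = MyInit;
  filesystem_ops.cleanup = MyCleanup;
  // No per-file-type tables: such a filesystem can be queried but files on it
  // cannot be read or written, which the validation above allows.
  TF_REGISTER_FILESYSTEM_PLUGIN("myfs", &filesystem_ops,
                                /*pluginRandomAccessFileOps=*/nullptr,
                                /*pluginWritableFileOps=*/nullptr,
                                /*pluginReadOnlyMemoryRegionOps=*/nullptr,
                                status);
}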

View File

@@ -56,7 +56,7 @@ extern "C" {
 /// Lifetime: The wrapper data structures are owned by core TensorFlow. The data
 /// pointed to by the `void*` members is always owned by the plugin. The plugin
 /// will provide functions to call to allocate and deallocate this data (see
-/// next section) and core TensorFlow ensures to call these at the proper time.
+/// next sections) and core TensorFlow ensures to call these at the proper time.
 ///
 /// Plugins will never receive a `TF_*` pointer that is `nullptr`. Core
 /// TensorFlow will never touch the `void*` wrapped by these structures, except
@@ -529,7 +529,7 @@ typedef struct TF_FilesystemOps {
 /// If `statuses` is not null, plugins must fill each element with detailed
 /// status for each file, as if calling `path_exists` on each one. Core
 /// TensorFlow initializes the `statuses` array and plugins must use
-/// `TF_SetStatus` to set each element instead of dirrectly assigning.
+/// `TF_SetStatus` to set each element instead of directly assigning.
 ///
 /// DEFAULT IMPLEMENTATION: Checks existence of every file. Needs
 /// `path_exists`.
@@ -601,6 +601,10 @@ typedef struct TF_FilesystemOps {
 ///
 /// Plugins must not return `nullptr`. Returning empty strings is allowed.
 ///
+/// The allocation and freeing of memory must happen via the functions sent to
+/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
+/// structure in Section 4).
+///
 /// This function will be called by core TensorFlow to clean up all path
 /// arguments for all other methods in the filesystem API.
 ///
@@ -618,6 +622,10 @@ typedef struct TF_FilesystemOps {
 /// In case of error, plugins must set `status` to a value different than
 /// `TF_OK`, free memory allocated for `entries` and return -1.
 ///
+/// The allocation and freeing of memory must happen via the functions sent to
+/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
+/// structure in Section 4).
+///
 /// Plugins:
 ///   * Must set `status` to `TF_OK` if all children were returned.
 ///   * Must set `status` to `TF_NOT_FOUND` if `path` doesn't point to a
@@ -654,6 +662,10 @@ typedef struct TF_FilesystemOps {
 /// different than `TF_OK`, free any memory that might have been allocated for
 /// `entries` and return -1.
 ///
+/// The allocation and freeing of memory must happen via the functions sent to
+/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
+/// structure in Section 4).
+///
 /// Plugins:
 ///   * Must set `status` to `TF_OK` if all matches were returned.
 ///   * Might use any other error value for `status` to signal other errors.
@@ -736,95 +748,132 @@ constexpr size_t TF_FILESYSTEM_OPS_SIZE = sizeof(TF_FilesystemOps);
 /// SECTION 4. Plugin registration and initialization
 /// ----------------------------------------------------------------------------
 ///
-/// In this section we define two functions:
-/// * `TF_InitPlugin`: must be present in the plugin shared object as it will
-///   be called by core TensorFlow when the filesystem plugin is loaded;
-/// * `RegisterFilesystemPlugin`: it is implemented by core TensorFlow but
-///   plugins must call it in their `TF_InitPlugin`, usually using the macro
-///   `TF_REGISTER_FILESYSTEM_PLUGIN`.
+/// In this section we define the API used by core TensorFlow to initialize a
+/// filesystem provided by a plugin. That is, we define the following:
+/// * `TF_InitPlugin` function: must be present in the plugin shared object as
+///   it will be called by core TensorFlow when the filesystem plugin is
+///   loaded;
+/// * `TF_FilesystemPluginOps` struct: used to transfer information between
+///   plugins and core TensorFlow about the operations provided and metadata;
+/// * `TF_FilesystemPluginInfo` struct: similar to the above structure, but
+///   collects information about all the file schemes that the plugin provides
+///   support for, as well as about the plugin's memory handling routines;
+/// * `TF_SetFilesystemVersionMetadata` function: must be called by plugins in
+///   their `TF_InitPlugin` to record the versioning information the plugins
+///   are compiled against.
 ///
 /// The `TF_InitPlugin` function is used by plugins to set up the data
-/// structures that implement this interface, as presented in Section 2.
-///
-/// The `RegisterFilesystemPlugin` is used by core TensorFlow to check that
-/// plugins satisfy the requirements expected by core TensorFlow, as follows:
-///   1. If ABI numbers don't match we don't load the plugin, else we continue.
-///   2. If the API numbers are mismatched, we warn the user and continue
-///   loading the plugin.
-///   3. If any required operation is missing, we stop loading the plugin.
-///
-/// If all these checks succeed, we copy the plugin operations to a different
-/// memory location so that core TensorFlow has the guarantee that they won't be
-/// changed by plugins at a later time. Finally, we initialize the opaque
-/// pointer of `TF_Filesystem` by calling the required `init` function of
-/// `TF_FilesystemOps` and if that succeeds we register the filesystem.
-
-// Initializes a TensorFlow plugin.
-//
-// Must be implemented by the plugin DSO. It is called by TensorFlow runtime.
-//
-// Filesystem plugins can be loaded on demand by users via
-// `Env::LoadLibrary` or during TensorFlow's startup if they are on certain
-// paths (although this has a security risk if two plugins register for the
-// same filesystem and the malicious one loads before the legitimate one -
-// but we consider this to be something that users should care about and
-// manage themselves). In both of these cases, core TensorFlow looks for
-// the `TF_InitPlugin` symbol and calls that function.
-//
-// A plugin is loaded only if this `status` is `TF_OK` after the call.
-TF_CAPI_EXPORT extern void TF_InitPlugin(TF_Status* status);
-
-/// Registers a filesystem plugin so that core TensorFlow can use it.
-///
-/// Must be called by the plugin during `TF_InitPlugin`, usually by using the
-/// convenience `TF_REGISTER_FILESYSTEM_PLUGIN` macro.
-///
-/// Arguments (grouped by category):
-/// * `..ABI`: ABI compatibility numbers (see Section 3.).
-/// * `..API`: API compatibility numbers (see Section 3.).
-/// * `..Size`: Sizes of the operation tables (see Section 3.).
-/// * `scheme`: The URI scheme that plugin is registering filesystems for.
-///   Must be of the form "fs" for URIs of form "fs:///path/to/file". For
-///   local filesystems (i.e., when the URI is "/path/to/file"), `scheme`
-///   must be "". Must never be `nullptr`.
-/// * `..Ops`: The function tables provided by the plugin. Owned by the
-///   plugin, but core TensorFlow makes a copy of these.
-/// * `status`: The output variable for representing success/failure.
-///
-/// Sets `status` to `TF_OK` if plugin was registered and filesystem operations
-/// can be invoked from anywhere during TensorFlow's runtime. Any other value of
-/// `status` means that plugin failed to load properly and as such the
-/// operations it provides cannot be used at all (i.e., core TensorFlow will
-/// never run them, returning early with `TF_UNIMPLEMENTED` or similar error
-/// values).
-TF_CAPI_EXPORT extern void RegisterFilesystemPlugin(
-    int pluginFilesystemOpsABI, int pluginFilesystemOpsAPI,
-    size_t pluginFilesystemOpsSize, int pluginRandomAccessFileOpsABI,
-    int pluginRandomAccessFileOpsAPI, size_t pluginRandomAccessFileOpsSize,
-    int pluginWritableFileOpsABI, int pluginWritableFileOpsAPI,
-    size_t pluginWritableFileOpsSize, int pluginReadOnlyMemoryRegionOpsABI,
-    int pluginReadOnlyMemoryRegionOpsAPI,
-    size_t pluginReadOnlyMemoryRegionOpsSize, const char* scheme,
-    const TF_FilesystemOps* pluginFilesystemOps,
-    const TF_RandomAccessFileOps* pluginRandomAccessFileOps,
-    const TF_WritableFileOps* pluginWritableFileOps,
-    const TF_ReadOnlyMemoryRegionOps* pluginReadOnlyMemoryRegionOps,
-    TF_Status* status);
-
-/// This macro is just a convenience wrapper around `RegisterFilesystemPlugin`.
-/// Plugins should prefer using this macro instead of a direct call.
-#define TF_REGISTER_FILESYSTEM_PLUGIN(                                        \
-    scheme, pluginFilesystemOps, pluginRandomAccessFileOps,                   \
-    pluginWritableFileOps, pluginReadOnlyMemoryRegionOps, status)             \
-  RegisterFilesystemPlugin(                                                   \
-      TF_FILESYSTEM_OPS_ABI, TF_FILESYSTEM_OPS_API, TF_FILESYSTEM_OPS_SIZE,   \
-      TF_RANDOM_ACCESS_FILE_OPS_ABI, TF_RANDOM_ACCESS_FILE_OPS_API,           \
-      TF_RANDOM_ACCESS_FILE_OPS_SIZE, TF_WRITABLE_FILE_OPS_ABI,               \
-      TF_WRITABLE_FILE_OPS_API, TF_WRITABLE_FILE_OPS_SIZE,                    \
-      TF_READ_ONLY_MEMORY_REGION_OPS_ABI, TF_READ_ONLY_MEMORY_REGION_OPS_API, \
-      TF_READ_ONLY_MEMORY_REGION_OPS_SIZE, scheme, pluginFilesystemOps,       \
-      pluginRandomAccessFileOps, pluginWritableFileOps,                       \
-      pluginReadOnlyMemoryRegionOps, status)
+/// structures that implement this interface, as presented in Section 2. In
+/// order to not have plugin shared objects call back symbols defined in core
+/// TensorFlow, `TF_InitPlugin` has a `TF_FilesystemPluginInfo` argument which
+/// the plugin must fill (using the `TF_SetFilesystemVersionMetadata` for the
+/// metadata and setting up all the supported operations and the URI schemes
+/// that are supported).
+
+/// This structure incorporates the operations defined in Section 2 and the
+/// metadata defined in section 3, allowing plugins to define different ops
+/// for different URI schemes.
+///
+/// Every URI scheme is of the form "fs" for URIs of form "fs:///path/to/file".
+/// For local filesystems (i.e., when the URI is "/path/to/file"), the scheme
+/// must be "". The scheme must never be `nullptr`.
+///
+/// Every plugin fills this in `TF_InitPlugin`, using the allocator passed as
+/// argument to allocate memory. After `TF_InitPlugin` finishes, core
+/// TensorFlow uses the information present in this to initialize filesystems
+/// for the URI schemes that the plugin requests.
+///
+/// All pointers defined in this structure point to memory allocated by the DSO
+/// using an allocator provided by core TensorFlow when calling `TF_InitPlugin`.
+///
+/// IMPORTANT: To maintain binary compatibility, the layout of this structure
+/// must not change! In the unlikely case that a new type of file needs to be
+/// supported, add the new ops and metadata at the end of the structure.
+typedef struct TF_FilesystemPluginOps {
+  char* scheme;
+  int filesystem_ops_abi;
+  int filesystem_ops_api;
+  size_t filesystem_ops_size;
+  TF_FilesystemOps* filesystem_ops;
+  int random_access_file_ops_abi;
+  int random_access_file_ops_api;
+  size_t random_access_file_ops_size;
+  TF_RandomAccessFileOps* random_access_file_ops;
+  int writable_file_ops_abi;
+  int writable_file_ops_api;
+  size_t writable_file_ops_size;
+  TF_WritableFileOps* writable_file_ops;
+  int read_only_memory_region_ops_abi;
+  int read_only_memory_region_ops_api;
+  size_t read_only_memory_region_ops_size;
+  TF_ReadOnlyMemoryRegionOps* read_only_memory_region_ops;
+} TF_FilesystemPluginOps;
+
+/// This structure gathers together all the operations provided by the plugin.
+///
+/// Plugins must provide exactly `num_schemes` elements in the `ops` array.
+///
+/// Since memory that is allocated by the DSO gets transferred to core
+/// TensorFlow, we need to provide a way for the allocation and deallocation to
+/// match. This is why this structure also defines `plugin_memory_allocate` and
+/// `plugin_memory_free` members.
+///
+/// All memory allocated by the plugin that will be owned by core TensorFlow
+/// must be allocated using the allocator in this structure. Core TensorFlow
+/// will use the deallocator to free this memory once it no longer needs it.
+///
+/// IMPORTANT: To maintain binary compatibility, the layout of this structure
+/// must not change! In the unlikely case that new global operations must be
+/// provided, add them at the end of the structure.
+typedef struct TF_FilesystemPluginInfo {
+  size_t num_schemes;
+  TF_FilesystemPluginOps* ops;
+  void* (*plugin_memory_allocate)(size_t size);
+  void (*plugin_memory_free)(void* ptr);
+} TF_FilesystemPluginInfo;
+
+/// Convenience function for setting the versioning metadata.
+///
+/// The argument is guaranteed to not be `nullptr`.
+///
+/// We want this to be defined in the plugin's memory space and we guarantee
+/// that core TensorFlow will never call this.
+static inline void TF_SetFilesystemVersionMetadata(
+    TF_FilesystemPluginOps* ops) {
+  ops->filesystem_ops_abi = TF_FILESYSTEM_OPS_ABI;
+  ops->filesystem_ops_api = TF_FILESYSTEM_OPS_API;
+  ops->filesystem_ops_size = TF_FILESYSTEM_OPS_SIZE;
+  ops->random_access_file_ops_abi = TF_RANDOM_ACCESS_FILE_OPS_ABI;
+  ops->random_access_file_ops_api = TF_RANDOM_ACCESS_FILE_OPS_API;
+  ops->random_access_file_ops_size = TF_RANDOM_ACCESS_FILE_OPS_SIZE;
+  ops->writable_file_ops_abi = TF_WRITABLE_FILE_OPS_ABI;
+  ops->writable_file_ops_api = TF_WRITABLE_FILE_OPS_API;
+  ops->writable_file_ops_size = TF_WRITABLE_FILE_OPS_SIZE;
+  ops->read_only_memory_region_ops_abi = TF_READ_ONLY_MEMORY_REGION_OPS_ABI;
+  ops->read_only_memory_region_ops_api = TF_READ_ONLY_MEMORY_REGION_OPS_API;
+  ops->read_only_memory_region_ops_size = TF_READ_ONLY_MEMORY_REGION_OPS_SIZE;
+}
+
+/// Initializes a TensorFlow plugin.
+///
+/// Must be implemented by the plugin DSO. It is called by TensorFlow runtime.
+///
+/// Filesystem plugins can be loaded on demand by users via
+/// `Env::LoadLibrary` or during TensorFlow's startup if they are on certain
+/// paths (although this has a security risk if two plugins register for the
+/// same filesystem and the malicious one loads before the legitimate one -
+/// but we consider this to be something that users should care about and
+/// manage themselves). In both of these cases, core TensorFlow looks for
+/// the `TF_InitPlugin` symbol and calls this function.
+///
+/// For every filesystem URI scheme that this plugin supports, the plugin must
+/// add one `TF_FilesystemPluginInfo` entry in `plugin_info->ops` and call
+/// `TF_SetFilesystemVersionMetadata` for that entry.
+///
+/// Plugins must also initialize `plugin_info->plugin_memory_allocate` and
+/// `plugin_info->plugin_memory_free` to ensure memory allocated by plugin is
+/// freed in a compatible way.
+TF_CAPI_EXPORT extern void TF_InitPlugin(TF_FilesystemPluginInfo* plugin_info);
 #ifdef __cplusplus
 }  // end extern "C"
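To make the flow concrete, a minimal plugin against this new interface might fill the info structure as sketched below; the "myfs" scheme is illustrative, and a real plugin must also populate the op tables (as the POSIX plugin further down does):

// Hypothetical plugin-side sketch for the new registration API.
#include <stdlib.h>
#include <string.h>

#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"

static void* my_allocate(size_t size) { return calloc(1, size); }
static void my_free(void* ptr) { free(ptr); }

void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
  info->plugin_memory_allocate = my_allocate;
  info->plugin_memory_free = my_free;
  info->num_schemes = 1;
  info->ops = static_cast<TF_FilesystemPluginOps*>(
      my_allocate(info->num_schemes * sizeof(info->ops[0])));
  TF_SetFilesystemVersionMetadata(&info->ops[0]);
  info->ops[0].scheme = strdup("myfs");
  // Fill info->ops[0].filesystem_ops (and the per-file-type tables) here,
  // allocating each table with my_allocate so core can free it later.
}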

View File

@@ -18,11 +18,10 @@ limitations under the License.
 #include <string>
 #include <utility>
 
+#include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h"
 #include "tensorflow/c/tf_status_helper.h"
-#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/file_system_helper.h"
-#include "tensorflow/core/platform/strcat.h"
 #include "tensorflow/core/util/ptr_util.h"
 
 // TODO(mihaimaruseac): After all filesystems are converted, all calls to
@@ -165,16 +164,18 @@ Status ModularFileSystem::GetChildren(const std::string& dir,
   UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
   std::string translated_name = TranslateName(dir);
 
-  char** children;
+  // Note that `children` is allocated by the plugin and freed by core
+  // TensorFlow, so we need to use `plugin_memory_free_` here.
+  char** children = nullptr;
   const int num_children =
       ops_->get_children(filesystem_.get(), translated_name.c_str(), &children,
                          plugin_status.get());
   if (num_children >= 0) {
     for (int i = 0; i < num_children; i++) {
       result->push_back(std::string(children[i]));
-      free(children[i]);
+      plugin_memory_free_(children[i]);
     }
-    free(children);
+    plugin_memory_free_(children);
   }
 
   return StatusFromTF_Status(plugin_status.get());
@@ -186,15 +187,17 @@ Status ModularFileSystem::GetMatchingPaths(const std::string& pattern,
     return internal::GetMatchingPaths(this, Env::Default(), pattern, result);
 
   UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
-  char** matches;
+  // Note that `matches` is allocated by the plugin and freed by core
+  // TensorFlow, so we need to use `plugin_memory_free_` here.
+  char** matches = nullptr;
   const int num_matches = ops_->get_matching_paths(
       filesystem_.get(), pattern.c_str(), &matches, plugin_status.get());
   if (num_matches >= 0) {
     for (int i = 0; i < num_matches; i++) {
       result->push_back(std::string(matches[i]));
-      free(matches[i]);
+      plugin_memory_free_(matches[i]);
     }
-    free(matches);
+    plugin_memory_free_(matches);
   }
 
   return StatusFromTF_Status(plugin_status.get());
@@ -358,7 +361,8 @@ std::string ModularFileSystem::TranslateName(const std::string& name) const {
   CHECK(p != nullptr) << "TranslateName(" << name << ") returned nullptr";
 
   std::string ret(p);
-  free(p);
+  // Since `p` is allocated by plugin, free it using plugin's method.
+  plugin_memory_free_(p);
   return ret;
 }
@@ -435,4 +439,8 @@ Status ModularWritableFile::Tell(int64* position) {
   return StatusFromTF_Status(plugin_status.get());
 }
 
+Status RegisterFilesystemPlugin(const std::string& dso_path) {
+  return filesystem_registration::RegisterFilesystemPluginImpl(dso_path);
+}
+
 }  // namespace tensorflow
View File

@@ -32,7 +32,7 @@ namespace tensorflow {
 // TODO(b/143949615): After all filesystems are converted, this file will be
 // moved to core/platform, and this class can become a singleton and replace the
 // need for `Env::Default()`. At that time, we might decide to remove the need
-// for `Env::Default()` altoghether, but that's a different project, not in
+// for `Env::Default()` altogether, but that's a different project, not in
 // scope for now. I'm just mentioning this here as that transition will mean
 // removal of the registration part from `Env` and adding it here instead: we
 // will need tables to hold for each scheme the function tables that implement
@@ -46,12 +46,16 @@ class ModularFileSystem final : public FileSystem {
       std::unique_ptr<const TF_RandomAccessFileOps> random_access_file_ops,
       std::unique_ptr<const TF_WritableFileOps> writable_file_ops,
       std::unique_ptr<const TF_ReadOnlyMemoryRegionOps>
-          read_only_memory_region_ops)
+          read_only_memory_region_ops,
+      std::function<void*(size_t)> plugin_memory_allocate,
+      std::function<void(void*)> plugin_memory_free)
       : filesystem_(std::move(filesystem)),
         ops_(std::move(filesystem_ops)),
         random_access_file_ops_(std::move(random_access_file_ops)),
         writable_file_ops_(std::move(writable_file_ops)),
-        read_only_memory_region_ops_(std::move(read_only_memory_region_ops)) {}
+        read_only_memory_region_ops_(std::move(read_only_memory_region_ops)),
+        plugin_memory_allocate_(std::move(plugin_memory_allocate)),
+        plugin_memory_free_(std::move(plugin_memory_free)) {}
 
   ~ModularFileSystem() override { ops_->cleanup(filesystem_.get()); }
@@ -93,6 +97,8 @@ class ModularFileSystem final : public FileSystem {
   std::unique_ptr<const TF_WritableFileOps> writable_file_ops_;
   std::unique_ptr<const TF_ReadOnlyMemoryRegionOps>
       read_only_memory_region_ops_;
+  std::function<void*(size_t)> plugin_memory_allocate_;
+  std::function<void(void*)> plugin_memory_free_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(ModularFileSystem);
 };
@@ -156,6 +162,9 @@ class ModularReadOnlyMemoryRegion final : public ReadOnlyMemoryRegion {
   TF_DISALLOW_COPY_AND_ASSIGN(ModularReadOnlyMemoryRegion);
 };
 
+// Registers a filesystem plugin so that core TensorFlow can use it.
+Status RegisterFilesystemPlugin(const std::string& dso_path);
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_H_

View File

@ -0,0 +1,346 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
// Checks that all schemes provided by a plugin are valid.
// TODO(mihaimaruseac): More validation could be done here, based on supported
// charset, maximum length, etc. Punting it for later.
static Status ValidateScheme(const char* scheme) {
if (scheme == nullptr)
return errors::InvalidArgument(
"Attempted to register filesystem with `nullptr` URI scheme");
return Status::OK();
}
// Checks if the plugin and core ABI numbers match.
//
// If the numbers don't match, plugin cannot be loaded.
static Status CheckABI(int pluginABI, int coreABI, StringPiece where) {
if (pluginABI != coreABI)
return errors::FailedPrecondition(
strings::StrCat("Plugin ABI (", pluginABI, ") for ", where,
" operations doesn't match expected core ABI (",
coreABI, "). Plugin cannot be loaded."));
return Status::OK();
}
// Checks if the plugin and core ABI numbers match, for all operations.
//
// If the numbers don't match, plugin cannot be loaded.
//
// Uses the simpler `CheckABI(int, int, StringPiece)`.
static Status ValidateABI(const TF_FilesystemPluginOps* ops) {
TF_RETURN_IF_ERROR(
CheckABI(ops->filesystem_ops_abi, TF_FILESYSTEM_OPS_ABI, "filesystem"));
if (ops->random_access_file_ops != nullptr)
TF_RETURN_IF_ERROR(CheckABI(ops->random_access_file_ops_abi,
TF_RANDOM_ACCESS_FILE_OPS_ABI,
"random access file"));
if (ops->writable_file_ops != nullptr)
TF_RETURN_IF_ERROR(CheckABI(ops->writable_file_ops_abi,
TF_WRITABLE_FILE_OPS_ABI, "writable file"));
if (ops->read_only_memory_region_ops != nullptr)
TF_RETURN_IF_ERROR(CheckABI(ops->read_only_memory_region_ops_abi,
TF_READ_ONLY_MEMORY_REGION_OPS_ABI,
"read only memory region"));
return Status::OK();
}
// Checks if the plugin and core API numbers match, logging mismatches.
static void CheckAPI(int plugin_API, int core_API, StringPiece where) {
if (plugin_API != core_API) {
VLOG(0) << "Plugin API (" << plugin_API << ") for " << where
<< " operations doesn't match expected core API (" << core_API
<< "). Plugin will be loaded but functionality might be missing.";
}
}
// Checks if the plugin and core API numbers match, for all operations.
//
// Uses the simpler `CheckAPI(int, int, StringPiece)`.
static void ValidateAPI(const TF_FilesystemPluginOps* ops) {
CheckAPI(ops->filesystem_ops_api, TF_FILESYSTEM_OPS_API, "filesystem");
if (ops->random_access_file_ops != nullptr)
CheckAPI(ops->random_access_file_ops_api, TF_RANDOM_ACCESS_FILE_OPS_API,
"random access file");
if (ops->writable_file_ops != nullptr)
CheckAPI(ops->writable_file_ops_api, TF_WRITABLE_FILE_OPS_API,
"writable file");
if (ops->read_only_memory_region_ops != nullptr)
CheckAPI(ops->read_only_memory_region_ops_api,
TF_READ_ONLY_MEMORY_REGION_OPS_API, "read only memory region");
}
// Validates the filesystem operations supplied by the plugin.
static Status ValidateHelper(const TF_FilesystemOps* ops) {
if (ops == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without operations");
if (ops->init == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `init` operation");
if (ops->cleanup == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `cleanup` operation");
return Status::OK();
}
// Validates the random access file operations supplied by the plugin.
static Status ValidateHelper(const TF_RandomAccessFileOps* ops) {
if (ops == nullptr) {
// We allow filesystems where files can only be written to (from TF code)
return Status::OK();
}
if (ops->cleanup == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `cleanup` operation on random "
"access files");
return Status::OK();
}
// Validates the writable file operations supplied by the plugin.
static Status ValidateHelper(const TF_WritableFileOps* ops) {
if (ops == nullptr) {
// We allow read-only filesystems
return Status::OK();
}
if (ops->cleanup == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `cleanup` operation on writable "
"files");
return Status::OK();
}
// Validates the read only memory region operations given by the plugin.
static Status ValidateHelper(const TF_ReadOnlyMemoryRegionOps* ops) {
if (ops == nullptr) {
// read only memory region support is always optional
return Status::OK();
}
if (ops->cleanup == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `cleanup` operation on read "
"only memory regions");
if (ops->data == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `data` operation on read only "
"memory regions");
if (ops->length == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `length` operation on read only "
"memory regions");
return Status::OK();
}
// Validates the operations supplied by the plugin.
//
// Uses the 4 simpler `ValidateHelper(const TF_...*)` to validate each
// individual function table and then checks that the function table for a
// specific file type exists if the plugin offers support for creating that
// type of files.
static Status ValidateOperations(const TF_FilesystemPluginOps* ops) {
TF_RETURN_IF_ERROR(ValidateHelper(ops->filesystem_ops));
TF_RETURN_IF_ERROR(ValidateHelper(ops->random_access_file_ops));
TF_RETURN_IF_ERROR(ValidateHelper(ops->writable_file_ops));
TF_RETURN_IF_ERROR(ValidateHelper(ops->read_only_memory_region_ops));
if (ops->filesystem_ops->new_random_access_file != nullptr &&
ops->random_access_file_ops == nullptr)
return errors::FailedPrecondition(
"Filesystem allows creation of random access files but no "
"operations on them have been supplied.");
if ((ops->filesystem_ops->new_writable_file != nullptr ||
ops->filesystem_ops->new_appendable_file != nullptr) &&
ops->writable_file_ops == nullptr)
return errors::FailedPrecondition(
"Filesystem allows creation of writable files but no "
"operations on them have been supplied.");
if (ops->filesystem_ops->new_read_only_memory_region_from_file != nullptr &&
ops->read_only_memory_region_ops == nullptr)
return errors::FailedPrecondition(
"Filesystem allows creation of readonly memory regions but no "
"operations on them have been supplied.");
return Status::OK();
}
// Copies a function table from plugin memory space to core memory space.
//
// This has three benefits:
// * allows having newer plugins than the current core TensorFlow: the
// additional entries in the plugin's table are just discarded;
// * allows having older plugins than the current core TensorFlow (though
// we are still warning users): the entries that core TensorFlow expects
// but plugins didn't provide will be set to `nullptr` values and core
// TensorFlow will know to not call these on behalf of users;
// * increased security as plugins will not be able to alter function table
// after loading up. Thus, malicious plugins can't alter functionality to
// probe for gadgets inside core TensorFlow. We can even protect the area
// of memory where the copies reside to not allow any more writes to it
// after all copies are created.
template <typename T>
static std::unique_ptr<const T> CopyToCore(const T* plugin_ops,
size_t plugin_size) {
if (plugin_ops == nullptr) return nullptr;
size_t copy_size = std::min(plugin_size, sizeof(T));
auto core_ops = tensorflow::MakeUnique<T>();
memset(core_ops.get(), 0, sizeof(T));
memcpy(core_ops.get(), plugin_ops, copy_size);
return core_ops;
}
// Registers one filesystem from the plugin.
//
// Must be called only with `index` a valid index in `info->ops`.
static Status RegisterFileSystem(const TF_FilesystemPluginInfo* info,
int index) {
// Step 1: Copy all the function tables to core TensorFlow memory space
auto core_filesystem_ops = CopyToCore<TF_FilesystemOps>(
info->ops[index].filesystem_ops, info->ops[index].filesystem_ops_size);
auto core_random_access_file_ops = CopyToCore<TF_RandomAccessFileOps>(
info->ops[index].random_access_file_ops,
info->ops[index].random_access_file_ops_size);
auto core_writable_file_ops =
CopyToCore<TF_WritableFileOps>(info->ops[index].writable_file_ops,
info->ops[index].writable_file_ops_size);
auto core_read_only_memory_region_ops =
CopyToCore<TF_ReadOnlyMemoryRegionOps>(
info->ops[index].read_only_memory_region_ops,
info->ops[index].read_only_memory_region_ops_size);
// Step 2: Initialize the opaque filesystem structure
auto filesystem = tensorflow::MakeUnique<TF_Filesystem>();
TF_Status* c_status = TF_NewStatus();
Status status = Status::OK();
core_filesystem_ops->init(filesystem.get(), c_status);
status = Status(c_status->status);
TF_DeleteStatus(c_status);
if (!status.ok()) return status;
// Step 3: Actual registration
return Env::Default()->RegisterFileSystem(
info->ops[index].scheme,
tensorflow::MakeUnique<tensorflow::ModularFileSystem>(
std::move(filesystem), std::move(core_filesystem_ops),
std::move(core_random_access_file_ops),
std::move(core_writable_file_ops),
std::move(core_read_only_memory_region_ops),
info->plugin_memory_allocate, info->plugin_memory_free));
}
// Registers filesystem at `index`, if plugin is providing valid information.
//
// Extracted to a separate function so that pointers inside `info` are freed
// by the caller regardless of whether validation/registration failed or not.
//
// Must be called only with `index` a valid index in `info->ops`.
static Status ValidateAndRegisterFilesystems(
const TF_FilesystemPluginInfo* info, int index) {
TF_RETURN_IF_ERROR(ValidateScheme(info->ops[index].scheme));
TF_RETURN_IF_ERROR(ValidateABI(&info->ops[index]));
ValidateAPI(&info->ops[index]); // we just warn on API number mismatch
TF_RETURN_IF_ERROR(ValidateOperations(&info->ops[index]));
TF_RETURN_IF_ERROR(RegisterFileSystem(info, index));
return Status::OK();
}
// Ensures that the plugin provides the required memory management operations.
static Status ValidatePluginMemoryRoutines(
const TF_FilesystemPluginInfo* info) {
if (info->plugin_memory_allocate == nullptr)
return errors::FailedPrecondition(
"Cannot load filesystem plugin which does not provide "
"`plugin_memory_allocate`");
if (info->plugin_memory_free == nullptr)
return errors::FailedPrecondition(
"Cannot load filesystem plugin which does not provide "
"`plugin_memory_free`");
return Status::OK();
}
namespace filesystem_registration {
Status RegisterFilesystemPluginImpl(const std::string& dso_path) {
// Step 1: Load plugin
Env* env = Env::Default();
void* dso_handle;
TF_RETURN_IF_ERROR(env->LoadLibrary(dso_path.c_str(), &dso_handle));
// Step 2: Load symbol for `TF_InitPlugin`
void* dso_symbol;
TF_RETURN_IF_ERROR(
env->GetSymbolFromLibrary(dso_handle, "TF_InitPlugin", &dso_symbol));
// Step 3: Call `TF_InitPlugin`
TF_FilesystemPluginInfo info;
memset(&info, 0, sizeof(info));
auto TF_InitPlugin =
reinterpret_cast<int (*)(TF_FilesystemPluginInfo*)>(dso_symbol);
TF_InitPlugin(&info);
// Step 4: Ensure plugin provides the memory management functions.
TF_RETURN_IF_ERROR(ValidatePluginMemoryRoutines(&info));
// Step 5: Validate and register all filesystems
// Try to register as many filesystems as possible.
// Free memory once we no longer need it
Status status;
for (int i = 0; i < info.num_schemes; i++) {
status.Update(ValidateAndRegisterFilesystems(&info, i));
info.plugin_memory_free(info.ops[i].scheme);
info.plugin_memory_free(info.ops[i].filesystem_ops);
info.plugin_memory_free(info.ops[i].random_access_file_ops);
info.plugin_memory_free(info.ops[i].writable_file_ops);
info.plugin_memory_free(info.ops[i].read_only_memory_region_ops);
}
info.plugin_memory_free(info.ops);
return status;
}
} // namespace filesystem_registration
} // namespace tensorflow
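Given the above, loading a plugin from user code reduces to one call; a minimal sketch (the DSO path is illustrative):

// Hypothetical caller-side sketch: load and register a filesystem plugin DSO.
#include "tensorflow/c/experimental/filesystem/modular_filesystem.h"

tensorflow::Status LoadFilesystemPlugin() {
  return tensorflow::RegisterFilesystemPlugin(
      "/path/to/libposix_filesystem.so");  // illustrative path
}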

View File

@ -0,0 +1,28 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
namespace filesystem_registration {
Status RegisterFilesystemPluginImpl(const std::string& dso_path);
} // namespace filesystem_registration
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_

View File

@@ -1,35 +1,47 @@
 # Experimental posix filesystem plugin.
+load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
 
 package(
+    default_visibility = ["//visibility:private"],
     licenses = ["notice"],  # Apache 2.0
 )
 
-# Although this target results in a shared object that will be loaded at
-# runtime, this target must be a `cc_library` instead of a `cc_binary`. Making
-# it a `cc_binary` requires `linkshared = True`. In turn, this brings in several
-# TensorFlow symbols under `tensorflow::` namespace, for which we have no ABI
-# guarantees. Hence, in order to maintain ABI compatibility, this is marked as a
-# `cc_library` for now and we will revisit in the future.
-# TODO(mihaimaruseac): Determine if `cc_binary` makes more sense (when all
-# filesystems are converted and BUILD files are refactored to be modular).
-# TODO(b/144585140): The helpers should be separated into a different BUILD target
-# but doing that would result in symbols not being visible when loading plugin.
-# Revisit this once POSIX filesystem completely lands. See also the other TODO.
-# This also has the unfortunate effect that both versions of copy_file get
-# compiled, regardless of which one actually gets used!
+# Filesystem implementation for POSIX environments: Linux, MacOS, Android, etc.
+tf_cc_shared_object(
+    name = "libposix_filesystem.so",
+    framework_so = [],
+    linkstatic = False,
+    visibility = ["//visibility:public"],
+    deps = [":posix_filesystem_impl"],
+)
+
+# The real implementation of the filesystem.
 cc_library(
-    name = "posix_filesystem",
-    srcs = [
-        "posix_filesystem.cc",
-        "posix_filesystem_helper.cc",
-        "posix_filesystem_helper.h",
-        "copy_file.h",
-    ] + select({
-        "//tensorflow:linux_x86_64": ["copy_file_linux.cc"],
-        "//conditions:default": ["copy_file_portable.cc"],
-    }),
+    name = "posix_filesystem_impl",
+    srcs = ["posix_filesystem.cc"],
     deps = [
+        ":posix_filesystem_helper",
         "//tensorflow/c:tf_status",
         "//tensorflow/c/experimental/filesystem:filesystem_interface",
     ],
 )
+
+# Library implementing helper functionality, so that the above only contains
+# the API implementation for modular filesystems.
+cc_library(
+    name = "posix_filesystem_helper",
+    srcs = ["posix_filesystem_helper.cc"],
+    hdrs = ["posix_filesystem_helper.h"],
+    deps = [":copy_file"],
+)
+
+# On Linux, we can copy files faster using `sendfile`. But not elsewhere.
+# Hence, this private library to select which implementation to use.
+cc_library(
+    name = "copy_file",
+    srcs = select({
+        "//tensorflow:linux_x86_64": ["copy_file_linux.cc"],
+        "//conditions:default": ["copy_file_portable.cc"],
+    }),
+    hdrs = ["copy_file.h"],
+)

View File

@@ -24,8 +24,6 @@ limitations under the License.
 #include <sys/stat.h>
 #include <unistd.h>
 
-#include <vector>
-
 #include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
 #include "tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem_helper.h"
 #include "tensorflow/c/tf_status.h"
@@ -33,6 +31,9 @@ limitations under the License.
 // Implementation of a filesystem for POSIX environments.
 // This filesystem will support `file://` and empty (local) URI schemes.
 
+static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
+static void plugin_memory_free(void* ptr) { free(ptr); }
+
 // SECTION 1. Implementation for `TF_RandomAccessFile`
 // ----------------------------------------------------------------------------
 namespace tf_random_access_file {
@@ -45,7 +46,9 @@ typedef struct PosixFile {
 static void Cleanup(TF_RandomAccessFile* file) {
   auto posix_file = static_cast<PosixFile*>(file->plugin_file);
   close(posix_file->fd);
-  free(const_cast<char*>(posix_file->filename));
+  // This would be safe to free using `free` directly as it is only opaque.
+  // However, it is better to be consistent everywhere.
+  plugin_memory_free(const_cast<char*>(posix_file->filename));
   delete posix_file;
 }
@@ -100,7 +103,7 @@ typedef struct PosixFile {
 static void Cleanup(TF_WritableFile* file) {
   auto posix_file = static_cast<PosixFile*>(file->plugin_file);
-  free(const_cast<char*>(posix_file->filename));
+  plugin_memory_free(const_cast<char*>(posix_file->filename));
   delete posix_file;
 }
@@ -383,12 +386,13 @@ static int GetChildren(const TF_Filesystem* filesystem, const char* path,
   if (num_entries < 0) {
     TF_SetStatusFromIOError(status, errno, path);
   } else {
-    *entries = static_cast<char**>(calloc(num_entries, sizeof((*entries)[0])));
+    *entries = static_cast<char**>(
+        plugin_memory_allocate(num_entries * sizeof((*entries)[0])));
     for (int i = 0; i < num_entries; i++) {
       (*entries)[i] = strdup(dir_entries[i]->d_name);
-      free(dir_entries[i]);
+      plugin_memory_free(dir_entries[i]);
     }
-    free(dir_entries);
+    plugin_memory_free(dir_entries);
   }
 
   return num_entries;
@@ -396,48 +400,59 @@ static int GetChildren(const TF_Filesystem* filesystem, const char* path,
 }  // namespace tf_posix_filesystem
 
-void TF_InitPlugin(TF_Status* status) {
-  TF_RandomAccessFileOps random_access_file_ops = {
-      tf_random_access_file::Cleanup,
-      tf_random_access_file::Read,
-  };
-  TF_WritableFileOps writable_file_ops = {
-      tf_writable_file::Cleanup, tf_writable_file::Append,
-      tf_writable_file::Tell,    tf_writable_file::Flush,
-      tf_writable_file::Sync,    tf_writable_file::Close,
-  };
-  TF_ReadOnlyMemoryRegionOps read_only_memory_region_ops = {
-      tf_read_only_memory_region::Cleanup,
-      tf_read_only_memory_region::Data,
-      tf_read_only_memory_region::Length,
-  };
-  TF_FilesystemOps filesystem_ops = {
-      tf_posix_filesystem::Init,
-      tf_posix_filesystem::Cleanup,
-      tf_posix_filesystem::NewRandomAccessFile,
-      tf_posix_filesystem::NewWritableFile,
-      tf_posix_filesystem::NewAppendableFile,
-      tf_posix_filesystem::NewReadOnlyMemoryRegionFromFile,
-      tf_posix_filesystem::CreateDir,
-      /*recursively_create_dir=*/nullptr,
-      tf_posix_filesystem::DeleteFile,
-      tf_posix_filesystem::DeleteDir,
-      /*delete_recursively=*/nullptr,
-      tf_posix_filesystem::RenameFile,
-      tf_posix_filesystem::CopyFile,
-      tf_posix_filesystem::PathExists,
-      /*paths_exist=*/nullptr,
-      tf_posix_filesystem::Stat,
-      /*is_directory=*/nullptr,
-      /*get_file_size=*/nullptr,
-      /*translate_name=*/nullptr,
-      tf_posix_filesystem::GetChildren,
-      /*get_matching_paths=*/nullptr,
-      /*flush_caches=*/nullptr,
-  };
-
-  for (const char* scheme : {"", "file"})
-    TF_REGISTER_FILESYSTEM_PLUGIN(scheme, &filesystem_ops,
-                                  &random_access_file_ops, &writable_file_ops,
-                                  &read_only_memory_region_ops, status);
+static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
+                                        const char* uri) {
+  TF_SetFilesystemVersionMetadata(ops);
+  ops->scheme = strdup(uri);
+
+  ops->random_access_file_ops = static_cast<TF_RandomAccessFileOps*>(
+      plugin_memory_allocate(TF_RANDOM_ACCESS_FILE_OPS_SIZE));
+  ops->random_access_file_ops->cleanup = tf_random_access_file::Cleanup;
+  ops->random_access_file_ops->read = tf_random_access_file::Read;
+
+  ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
+      plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
+  ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
+  ops->writable_file_ops->append = tf_writable_file::Append;
+  ops->writable_file_ops->tell = tf_writable_file::Tell;
+  ops->writable_file_ops->flush = tf_writable_file::Flush;
+  ops->writable_file_ops->sync = tf_writable_file::Sync;
+  ops->writable_file_ops->close = tf_writable_file::Close;
+
+  ops->read_only_memory_region_ops = static_cast<TF_ReadOnlyMemoryRegionOps*>(
+      plugin_memory_allocate(TF_READ_ONLY_MEMORY_REGION_OPS_SIZE));
+  ops->read_only_memory_region_ops->cleanup =
+      tf_read_only_memory_region::Cleanup;
+  ops->read_only_memory_region_ops->data = tf_read_only_memory_region::Data;
+  ops->read_only_memory_region_ops->length = tf_read_only_memory_region::Length;
+
+  ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
+      plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
+  ops->filesystem_ops->init = tf_posix_filesystem::Init;
+  ops->filesystem_ops->cleanup = tf_posix_filesystem::Cleanup;
+  ops->filesystem_ops->new_random_access_file =
+      tf_posix_filesystem::NewRandomAccessFile;
+  ops->filesystem_ops->new_writable_file = tf_posix_filesystem::NewWritableFile;
+  ops->filesystem_ops->new_appendable_file =
+      tf_posix_filesystem::NewAppendableFile;
+  ops->filesystem_ops->new_read_only_memory_region_from_file =
+      tf_posix_filesystem::NewReadOnlyMemoryRegionFromFile;
+  ops->filesystem_ops->create_dir = tf_posix_filesystem::CreateDir;
+  ops->filesystem_ops->delete_file = tf_posix_filesystem::DeleteFile;
+  ops->filesystem_ops->delete_dir = tf_posix_filesystem::DeleteDir;
+  ops->filesystem_ops->rename_file = tf_posix_filesystem::RenameFile;
+  ops->filesystem_ops->copy_file = tf_posix_filesystem::CopyFile;
+  ops->filesystem_ops->path_exists = tf_posix_filesystem::PathExists;
+  ops->filesystem_ops->stat = tf_posix_filesystem::Stat;
+  ops->filesystem_ops->get_children = tf_posix_filesystem::GetChildren;
+}
+
+void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
+  info->plugin_memory_allocate = plugin_memory_allocate;
+  info->plugin_memory_free = plugin_memory_free;
+  info->num_schemes = 2;
+  info->ops = static_cast<TF_FilesystemPluginOps*>(
+      plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
+  ProvideFilesystemSupportFor(&info->ops[0], "");
+  ProvideFilesystemSupportFor(&info->ops[1], "file");
 }

View File

@@ -44,7 +44,7 @@ int TransferFileContents(const char* src, const char* dst, mode_t mode,
 }
 
 // Both files have been opened, do the transfer.
-// Since errno would be overriden by `close` below, save it here.
+// Since errno would be overridden by `close` below, save it here.
 int error_code = 0;
 if (CopyFileContents(dst_fd, src_fd, size) < 0) error_code = errno;

View File

@ -0,0 +1,36 @@
# Experimental windows filesystem plugin.
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object")
package(
licenses = ["notice"], # Apache 2.0
)
# Filesystem implementation for Windows environment
tf_cc_shared_object(
name = "windows_filesystem.dll",
framework_so = [],
linkstatic = False,
tags = [
"manual",
"nobuilder",
"notap",
],
visibility = ["//visibility:public"],
deps = [":windows_filesystem_impl"],
)
# The real implementation of the filesystem.
cc_library(
name = "windows_filesystem_impl",
srcs = ["windows_filesystem.cc"],
copts = get_win_copts(),
tags = [
"manual",
"nobuilder",
"notap",
],
deps = [
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
],
)

View File

@ -0,0 +1,73 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdlib.h>
#include <string.h>
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h"
// Implementation of a filesystem for Windows environments.
// This filesystem will support `file://` and empty (local) URI schemes.
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); }
// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
// TODO(mihaimaruseac): Implement later
} // namespace tf_random_access_file
// SECTION 2. Implementation for `TF_WritableFile`
// ----------------------------------------------------------------------------
namespace tf_writable_file {
// TODO(mihaimaruseac): Implement later
} // namespace tf_writable_file
// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion`
// ----------------------------------------------------------------------------
namespace tf_read_only_memory_region {
// TODO(mihaimaruseac): Implement later
} // namespace tf_read_only_memory_region
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
// ----------------------------------------------------------------------------
namespace tf_windows_filesystem {
// TODO(mihaimaruseac): Implement later
} // namespace tf_windows_filesystem
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
const char* uri) {
TF_SetFilesystemVersionMetadata(ops);
ops->scheme = strdup(uri);
}
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
info->plugin_memory_allocate = plugin_memory_allocate;
info->plugin_memory_free = plugin_memory_free;
info->num_schemes = 2;
info->ops = static_cast<TF_FilesystemPluginOps*>(
plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
ProvideFilesystemSupportFor(&info->ops[0], "");
ProvideFilesystemSupportFor(&info->ops[1], "file");
}

View File

@@ -181,7 +181,8 @@ void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor,
     return;
   }
   const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i));
-  TF_Tensor* result = ::tensorflow::TF_TensorFromTensor(cc_tensor, status);
+  TF_Tensor* result =
+      ::tensorflow::TF_TensorFromTensor(cc_tensor, &status->status);
   if (TF_GetCode(status) == TF_OK) {
     *tensor = result;
   }

View File

@@ -18,19 +18,36 @@ limitations under the License.
 #include "tensorflow/c/kernels.h"
 
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include <stddef.h>
+#include <stdint.h>
+#include <string.h>
+
+#include <memory>
+#include <string>
+
+#include "absl/container/inlined_vector.h"
 #include "tensorflow/c/c_api.h"
+#include "tensorflow/c/tf_datatype.h"
+#include "tensorflow/c/tf_status.h"
+#include "tensorflow/c/tf_tensor.h"
+#include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/device_base.h"
 #include "tensorflow/core/framework/kernel_def.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/node_def_builder.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/kernels/ops_testutil.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
 
 struct MyCustomKernel {
   bool created;

View File

@ -133,7 +133,7 @@ TEST(OpsTest, TestShapeInference_VectorizeFunction) {
TEST(OpsTest, AttributeAccessors) { TEST(OpsTest, AttributeAccessors) {
TF_OpDefinitionBuilder* builder = TF_OpDefinitionBuilder* builder =
TF_NewOpDefinitionBuilder("AttributeAccesorsOp"); TF_NewOpDefinitionBuilder("AttributeAccessorsOp");
TF_OpDefinitionBuilderAddAttr(builder, "foo1: int >= 2"); TF_OpDefinitionBuilderAddAttr(builder, "foo1: int >= 2");
TF_OpDefinitionBuilderAddAttr(builder, "foo2: string=\"my string\""); TF_OpDefinitionBuilderAddAttr(builder, "foo2: string=\"my string\"");
TF_OpDefinitionBuilderSetIsCommutative(builder, true); TF_OpDefinitionBuilderSetIsCommutative(builder, true);
@ -151,7 +151,7 @@ TEST(OpsTest, AttributeAccessors) {
op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length); op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length);
bool found = false; bool found = false;
for (const auto& op : op_list.op()) { for (const auto& op : op_list.op()) {
if (op.name() == "AttributeAccesorsOp") { if (op.name() == "AttributeAccessorsOp") {
ASSERT_TRUE(op.is_commutative()); ASSERT_TRUE(op.is_commutative());
ASSERT_TRUE(op.is_aggregate()); ASSERT_TRUE(op.is_aggregate());
ASSERT_TRUE(op.allows_uninitialized_input()); ASSERT_TRUE(op.allows_uninitialized_input());

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/c/tf_tensor.h" #include "tensorflow/c/tf_tensor.h"
#include <memory>
#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h" #include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor_internal.h" #include "tensorflow/c/tf_tensor_internal.h"
@ -103,49 +105,35 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg); buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
} }
TF_Tensor* ret = // TODO(gjn): Make the choice of interface a compile-time configuration.
new TF_Tensor{Tensor(static_cast<tensorflow::DataType>(dtype), tensorflow::TensorInterface ret(
tensorflow::TensorShape(dimvec), buf)}; Tensor(static_cast<tensorflow::DataType>(dtype),
tensorflow::TensorShape(dimvec), buf));
buf->Unref(); buf->Unref();
size_t elem_size = TF_DataTypeSize(dtype); size_t elem_size = TF_DataTypeSize(dtype);
if (elem_size > 0 && len < (elem_size * ret->tensor.NumElements())) { if (elem_size > 0 && len < (elem_size * ret.NumElements())) {
delete ret;
return nullptr; return nullptr;
} }
return ret; return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(ret)};
} }
TF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor) { TF_Tensor* TF_TensorMaybeMove(TF_Tensor* t) {
// It is safe to move the Tensor if and only if we own the unique reference to return t->tensor->CanMove() ? t : nullptr;
// it. In that case, we might as well not delete and reallocate, but a future
// implementation might need to do so.
TensorBuffer* buf = tensorflow::TensorCApi::Buffer(tensor->tensor);
if (buf->RefCountIsOne() && buf->root_buffer()->RefCountIsOne() &&
buf->OwnsMemory()) {
return tensor;
}
return nullptr;
} }
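// Editor's note: a hypothetical caller sketch (helper name and usage are
// illustrative, not part of this change). TF_TensorMaybeMove returns non-null
// only when the buffer reference is unique and memory-owning (see
// TensorInterface::CanMove below), so in-place reuse must be gated on it.
static bool OverwriteFirstFloatIfUnique(TF_Tensor* t, float v) {
  TF_Tensor* owned = TF_TensorMaybeMove(t);  // nullptr if the buffer is shared
  if (owned == nullptr) return false;        // caller should copy instead
  static_cast<float*>(TF_TensorData(owned))[0] = v;
  return true;
}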
void TF_DeleteTensor(TF_Tensor* t) { delete t; } void TF_DeleteTensor(TF_Tensor* t) { delete t; }
TF_DataType TF_TensorType(const TF_Tensor* t) { TF_DataType TF_TensorType(const TF_Tensor* t) { return t->tensor->Type(); }
return static_cast<TF_DataType>(t->tensor.dtype());
}
int TF_NumDims(const TF_Tensor* t) { return t->tensor.dims(); } int TF_NumDims(const TF_Tensor* t) { return t->tensor->NumDims(); }
int64_t TF_Dim(const TF_Tensor* t, int dim_index) { int64_t TF_Dim(const TF_Tensor* t, int dim_index) {
return static_cast<int64_t>(t->tensor.dim_size(dim_index)); return t->tensor->Dim(dim_index);
} }
size_t TF_TensorByteSize(const TF_Tensor* t) { size_t TF_TensorByteSize(const TF_Tensor* t) { return t->tensor->ByteSize(); }
return tensorflow::TensorCApi::Buffer(t->tensor)->size();
}
void* TF_TensorData(const TF_Tensor* t) { void* TF_TensorData(const TF_Tensor* t) { return t->tensor->Data(); }
return tensorflow::TensorCApi::Buffer(t->tensor)->data();
}
int64_t TF_TensorElementCount(const TF_Tensor* t) { int64_t TF_TensorElementCount(const TF_Tensor* t) {
int64_t result = 1; int64_t result = 1;
@ -160,16 +148,69 @@ void TF_TensorBitcastFrom(const TF_Tensor* from, TF_DataType type,
TF_Tensor* to, const int64_t* new_dims, TF_Tensor* to, const int64_t* new_dims,
int num_new_dims, TF_Status* status) { int num_new_dims, TF_Status* status) {
TF_SetStatus(status, TF_OK, ""); TF_SetStatus(status, TF_OK, "");
Status cc_status(
static_cast<tensorflow::TensorInterface*>(to->tensor.get())
->BitcastFrom(*static_cast<const tensorflow::TensorInterface*>(
from->tensor.get()),
type, new_dims, num_new_dims));
Set_TF_Status_from_Status(status, cc_status);
}
namespace tensorflow {
bool TensorInterface::CanMove() const {
// It is safe to move the Tensor if and only if we own the unique reference to
// it. In that case, we might as well not delete and reallocate, but a future
// implementation might need to do so.
TensorBuffer* buf = tensorflow::TensorCApi::Buffer(tensor_);
if (buf->RefCountIsOne() && buf->root_buffer()->RefCountIsOne() &&
buf->OwnsMemory()) {
return true;
}
return false;
}
TF_DataType TensorInterface::Type() const {
return static_cast<TF_DataType>(tensor_.dtype());
}
int TensorInterface::NumDims() const { return tensor_.dims(); }
int64_t TensorInterface::Dim(int dim_index) const {
return static_cast<int64_t>(tensor_.dim_size(dim_index));
}
int64_t TensorInterface::NumElements() const {
return static_cast<int64_t>(tensor_.NumElements());
}
size_t TensorInterface::ByteSize() const {
return tensorflow::TensorCApi::Buffer(tensor_)->size();
}
void* TensorInterface::Data() const {
return tensorflow::TensorCApi::Buffer(tensor_)->data();
}
Status TensorInterface::BitcastFrom(const TensorInterface& from,
TF_DataType type, const int64_t* new_dims,
int num_new_dims) {
tensorflow::TensorShape s; tensorflow::TensorShape s;
for (int i = 0; i < num_new_dims; ++i) { for (int i = 0; i < num_new_dims; ++i) {
s.AddDim(new_dims[i]); s.AddDim(new_dims[i]);
} }
Status cc_status(to->tensor.BitcastFrom( return tensor_.BitcastFrom(from.tensor_,
from->tensor, static_cast<tensorflow::DataType>(type), s)); static_cast<tensorflow::DataType>(type), s);
Set_TF_Status_from_Status(status, cc_status);
} }
} // namespace tensorflow
// -------------------------------------------------------------------------- // --------------------------------------------------------------------------
void StringEncode(const char* src, size_t src_len, char* dst) {
dst = tensorflow::core::EncodeVarint64(dst, src_len);
memcpy(dst, src, src_len);
}
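// Editor's note: a sketch of the flat DT_STRING layout implied by the encode
// and decode logic in this file: N uint64 offsets (relative to the first byte
// after the offset table), then each element as a varint64 length prefix
// followed by its bytes. The helper below mirrors the size arithmetic used
// when serializing; its name is illustrative, and it assumes <string> and
// <vector> are available.
static size_t FlatStringTensorByteSize(const std::vector<std::string>& elems) {
  size_t size = sizeof(tensorflow::uint64) * elems.size();  // offset table
  for (const std::string& s : elems) {
    size += TF_StringEncodedSize(s.size());  // varint prefix + payload bytes
  }
  return size;
}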
size_t TF_StringEncode(const char* src, size_t src_len, char* dst, size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
size_t dst_len, TF_Status* status) { size_t dst_len, TF_Status* status) {
const size_t sz = TF_StringEncodedSize(src_len); const size_t sz = TF_StringEncodedSize(src_len);
@ -185,8 +226,7 @@ size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
src_len, "-byte string")); src_len, "-byte string"));
return 0; return 0;
} }
dst = tensorflow::core::EncodeVarint64(dst, src_len); StringEncode(src, src_len, dst);
memcpy(dst, src, src_len);
return sz; return sz;
} }
@ -245,13 +285,11 @@ static TF_Tensor* EmptyTensor(TF_DataType dtype,
namespace tensorflow { namespace tensorflow {
// Non-static for testing. // Non-static for testing.
TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status) {
TF_Status* status) { *status = tensorflow::Status::OK();
TF_SetStatus(status, TF_OK, "");
if (!src.IsInitialized()) { if (!src.IsInitialized()) {
Set_TF_Status_from_Status( *status = FailedPrecondition(
status, FailedPrecondition( "attempt to use a tensor with an uninitialized value");
"attempt to use a tensor with an uninitialized value"));
return nullptr; return nullptr;
} }
if (src.NumElements() == 0) { if (src.NumElements() == 0) {
@ -259,14 +297,13 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
} }
if (src.dtype() == tensorflow::DT_RESOURCE) { if (src.dtype() == tensorflow::DT_RESOURCE) {
if (src.shape().dims() != 0) { if (src.shape().dims() != 0) {
Set_TF_Status_from_Status( *status = InvalidArgument(
status, InvalidArgument( "Unexpected non-scalar DT_RESOURCE tensor seen (shape: ",
"Unexpected non-scalar DT_RESOURCE tensor seen (shape: ", src.shape().DebugString(),
src.shape().DebugString(), "). Please file a bug at "
"). Please file a bug at " "https://github.com/tensorflow/tensorflow/issues/new, "
"https://github.com/tensorflow/tensorflow/issues/new, " "ideally with a "
"ideally with a " "short code snippet that reproduces this error.");
"short code snippet that reproduces this error."));
return nullptr; return nullptr;
} }
const string str = const string str =
@ -276,12 +313,11 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
return t; return t;
} }
if (src.dtype() != tensorflow::DT_STRING) { if (src.dtype() != tensorflow::DT_STRING) {
auto* result = new TF_Tensor(); Tensor tensor;
if (!result->tensor.CopyFrom(src, src.shape())) { if (!tensor.CopyFrom(src, src.shape())) {
delete result;
return nullptr; return nullptr;
} }
return result; return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(tensor)};
} }
// DT_STRING tensors require a copying since TF_Tensor.buffer expects a flatly // DT_STRING tensors require a copying since TF_Tensor.buffer expects a flatly
// encoded sequence of strings. // encoded sequence of strings.
@ -305,23 +341,15 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
*offsets = (dst - data_start); *offsets = (dst - data_start);
offsets++; offsets++;
const string& s = srcarray(i); const string& s = srcarray(i);
size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status); const size_t consumed = TF_StringEncodedSize(s.size());
if (TF_GetCode(status) != TF_OK) { StringEncode(s.data(), s.size(), dst);
Set_TF_Status_from_Status(
status,
InvalidArgument("invalid string tensor encoding (string #", i, " of ",
srcarray.size(), "): ", TF_Message(status)));
delete[] base;
return nullptr;
}
dst += consumed; dst += consumed;
dst_len -= consumed; dst_len -= consumed;
} }
if (dst != base + size) { if (dst != base + size) {
Set_TF_Status_from_Status( *status = InvalidArgument(
status, InvalidArgument( "invalid string tensor encoding (decoded ", (dst - base),
"invalid string tensor encoding (decoded ", (dst - base), " bytes, but the tensor is encoded in ", size, " bytes");
" bytes, but the tensor is encoded in ", size, " bytes"));
delete[] base; delete[] base;
return nullptr; return nullptr;
} }
@ -339,31 +367,35 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
} }
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) { Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
if (src->tensor.dtype() == DT_RESOURCE) { return static_cast<const tensorflow::TensorInterface*>(src->tensor.get())
if (src->tensor.dims() != 0) { ->ToTensor(dst);
}
Status TensorInterface::ToTensor(Tensor* dst) const {
if (tensor_.dtype() == DT_RESOURCE) {
if (tensor_.dims() != 0) {
return InvalidArgument( return InvalidArgument(
"Malformed TF_RESOURCE tensor: expected a scalar, got a tensor with " "Malformed TF_RESOURCE tensor: expected a scalar, got a tensor with "
"shape ", "shape ",
src->tensor.shape().DebugString()); tensor_.shape().DebugString());
} }
*dst = Tensor(tensorflow::DT_RESOURCE, src->tensor.shape()); *dst = Tensor(tensorflow::DT_RESOURCE, tensor_.shape());
if (!dst->scalar<tensorflow::ResourceHandle>()().ParseFromString( if (!dst->scalar<tensorflow::ResourceHandle>()().ParseFromString(
string(static_cast<const char*>(TF_TensorData(src)), string(static_cast<const char*>(Data()), ByteSize()))) {
TF_TensorByteSize(src)))) {
return InvalidArgument( return InvalidArgument(
"Malformed TF_RESOUCE tensor: unable to parse resource handle"); "Malformed TF_RESOURCE tensor: unable to parse resource handle");
} }
return Status::OK(); return Status::OK();
} }
if (src->tensor.dtype() != DT_STRING) { if (tensor_.dtype() != DT_STRING) {
*dst = src->tensor; *dst = tensor_;
return Status::OK(); return Status::OK();
} }
// TF_STRING tensors require copying since Tensor class expects a sequence of // TF_STRING tensors require copying since Tensor class expects a sequence of
// string objects. // string objects.
const tensorflow::int64 num_elements = src->tensor.NumElements(); const tensorflow::int64 num_elements = tensor_.NumElements();
const char* input = reinterpret_cast<const char*>(TF_TensorData(src)); const char* input = reinterpret_cast<const char*>(Data());
const size_t src_size = TF_TensorByteSize(src); const size_t src_size = ByteSize();
if (static_cast<tensorflow::int64>(src_size / sizeof(tensorflow::uint64)) < if (static_cast<tensorflow::int64>(src_size / sizeof(tensorflow::uint64)) <
num_elements) { num_elements) {
return InvalidArgument( return InvalidArgument(
@ -372,7 +404,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
const char* data_start = input + sizeof(tensorflow::uint64) * num_elements; const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
const char* limit = input + src_size; const char* limit = input + src_size;
*dst = Tensor(src->tensor.dtype(), src->tensor.shape()); *dst = Tensor(tensor_.dtype(), tensor_.shape());
auto dstarray = dst->flat<tstring>(); auto dstarray = dst->flat<tstring>();
for (tensorflow::int64 i = 0; i < num_elements; ++i) { for (tensorflow::int64 i = 0; i < num_elements; ++i) {
tensorflow::uint64 offset = tensorflow::uint64 offset =
@ -391,8 +423,8 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
return Status::OK(); return Status::OK();
} }
bool TensorInterface::IsAligned() const { return tensor_.IsAligned(); }
} // namespace tensorflow } // namespace tensorflow
bool TF_TensorIsAligned(const TF_Tensor* tensor) { bool TF_TensorIsAligned(const TF_Tensor* t) { return t->tensor->IsAligned(); }
return tensor->tensor.IsAligned();
}

View File

@ -16,9 +16,12 @@ limitations under the License.
#ifndef TENSORFLOW_C_TF_TENSOR_INTERNAL_H_ #ifndef TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
#define TENSORFLOW_C_TF_TENSOR_INTERNAL_H_ #define TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
#include <memory>
#include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_datatype.h"
#include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_interface.h"
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.h"
// Internal structures used by the C API. These are likely to change and should // Internal structures used by the C API. These are likely to change and should
@ -28,7 +31,7 @@ limitations under the License.
// passed to or returned from C functions *by pointer*. Otherwise, changes to // passed to or returned from C functions *by pointer*. Otherwise, changes to
// its internal structure will break the C API's binary interface. // its internal structure will break the C API's binary interface.
typedef struct TF_Tensor { typedef struct TF_Tensor {
::tensorflow::Tensor tensor; std::unique_ptr<AbstractTensorInterface> tensor;
} TF_Tensor; } TF_Tensor;
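// Editor's note: holding the tensor behind a unique_ptr keeps the struct's
// layout (a single pointer) stable while AbstractTensorInterface evolves,
// which is exactly the pass-by-pointer ABI constraint described above.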
class TF_ManagedBuffer : public tensorflow::TensorBuffer { class TF_ManagedBuffer : public tensorflow::TensorBuffer {
@ -83,4 +86,5 @@ void* allocate_tensor(const char* operation, size_t len, Allocator* allocator);
// a different Allocator as `arg`. // a different Allocator as `arg`.
void deallocate_buffer(void* data, size_t len, void* arg); void deallocate_buffer(void* data, size_t len, void* arg);
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_C_TF_TENSOR_INTERNAL_H_ #endif // TENSORFLOW_C_TF_TENSOR_INTERNAL_H_

View File

@ -96,7 +96,7 @@ class SymbolicGradientBuilder {
// Used to identify nodes at which to stop backprop. // Used to identify nodes at which to stop backprop.
std::unordered_set<int> GetStopBackpropNodes( std::unordered_set<int> GetStopBackpropNodes(
const std::vector<bool>& reachable_nodes, const std::vector<bool>& reachable_nodes,
const std::unordered_set<int>& output_nodes); const std::unordered_set<int>& output_nodes) const;
const Scope& scope_; const Scope& scope_;
const ops::GradOpRegistry* registry_; const ops::GradOpRegistry* registry_;
@ -190,7 +190,7 @@ std::vector<bool> SymbolicGradientBuilder::GetReachableNodes() {
std::unordered_set<int> SymbolicGradientBuilder::GetStopBackpropNodes( std::unordered_set<int> SymbolicGradientBuilder::GetStopBackpropNodes(
const std::vector<bool>& reachable_nodes, const std::vector<bool>& reachable_nodes,
const std::unordered_set<int>& output_nodes) { const std::unordered_set<int>& output_nodes) const {
// Output nodes that get transitively consumed by other `outputs_` are stored // Output nodes that get transitively consumed by other `outputs_` are stored
// in `internal_outputs`. // in `internal_outputs`.
std::unordered_set<int> internal_outputs; std::unordered_set<int> internal_outputs;
@ -346,8 +346,8 @@ Status SymbolicGradientBuilder::SumGradients(const Output& src, Output* grad) {
"Unable to find backprop list for node.id ", src.node()->name()); "Unable to find backprop list for node.id ", src.node()->name());
} }
const auto& grads = iter->second; const auto& grads = iter->second;
// Filter any backproped 'NoGradient' Outputs from 'grads' (if needed). // Filter any backpropped 'NoGradient' Outputs from 'grads' (if needed).
// Return any valid backproped gradients that remain after filtering, // Return any valid backpropped gradients that remain after filtering,
// or 'NoGradient' otherwise. // or 'NoGradient' otherwise.
std::vector<Output> grads_to_keep; std::vector<Output> grads_to_keep;
for (const Output& o : grads) { for (const Output& o : grads) {
@ -519,7 +519,7 @@ Status SymbolicGradientBuilder::AddGradients() {
// Backprop along the in edges. // Backprop along the in edges.
// TODO(andydavis) Find cleaner way to map each grad output returned by // TODO(andydavis) Find cleaner way to map each grad output returned by
// gradient function to the src node/output to which it should be // gradient function to the src node/output to which it should be
// backproped. Maybe grad functions can return a vector of Output pairs to // backpropped. Maybe grad functions can return a vector of Output pairs to
// make this association explicit. // make this association explicit.
size_t dx_index = 0; size_t dx_index = 0;
for (const Edge* e : n->in_edges()) { for (const Edge* e : n->in_edges()) {

View File

@ -64,7 +64,7 @@ bool IsZero(const Scope& scope, const Output& grad) {
// Multiply after broadcasting vec to match dimensions of mat. // Multiply after broadcasting vec to match dimensions of mat.
// Args: // Args:
// vec: A 1-D tensor of dimension [D0] // vec: A 1-D tensor of dimension [D0]
// mat: A 2-D tensor of dimesnion [D0, D1] // mat: A 2-D tensor of dimension [D0, D1]
// //
// Returns: // Returns:
// A tensor of dimension [D0, D1], the result of vec * mat. // A tensor of dimension [D0, D1], the result of vec * mat.

View File

@ -259,6 +259,9 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) {
RunTest(x, x_init_value, y, y_shape); RunTest(x, x_init_value, y, y_shape);
} }
// TODO(rocm):
// Re-enable this test once 3D pooling is supported on ROCm platform
#ifndef TENSORFLOW_USE_ROCM
TEST_F(NNGradTest, MaxPool3DGradHelper) { TEST_F(NNGradTest, MaxPool3DGradHelper) {
TensorShape x_shape({1, 3, 3, 3, 1}); TensorShape x_shape({1, 3, 3, 3, 1});
TensorShape y_shape({1, 1, 1, 1, 1}); TensorShape y_shape({1, 1, 1, 1, 1});
@ -271,6 +274,7 @@ TEST_F(NNGradTest, MaxPool3DGradHelper) {
SetRandomValuesForMaxPooling<float>(&x_init_value); SetRandomValuesForMaxPooling<float>(&x_init_value);
RunTest(x, x_init_value, y, y_shape); RunTest(x, x_init_value, y, y_shape);
} }
#endif
TEST_F(NNGradTest, AvgPoolGradHelper) { TEST_F(NNGradTest, AvgPoolGradHelper) {
TensorShape x_shape({1, 2, 2, 1}); TensorShape x_shape({1, 2, 2, 1});
@ -283,6 +287,9 @@ TEST_F(NNGradTest, AvgPoolGradHelper) {
RunTest(x, x_shape, y, y_shape); RunTest(x, x_shape, y, y_shape);
} }
// TODO(rocm):
// Re-enable this test once 3D pooling is supported on ROCm platform
#ifndef TENSORFLOW_USE_ROCM
TEST_F(NNGradTest, AvgPool3DGradHelper) { TEST_F(NNGradTest, AvgPool3DGradHelper) {
TensorShape x_shape({1, 3, 3, 3, 1}); TensorShape x_shape({1, 3, 3, 3, 1});
TensorShape y_shape({1, 1, 1, 1, 1}); TensorShape y_shape({1, 1, 1, 1, 1});
@ -293,6 +300,7 @@ TEST_F(NNGradTest, AvgPool3DGradHelper) {
auto y = AvgPool3D(scope_, x, ksize, strides, "SAME"); auto y = AvgPool3D(scope_, x, ksize, strides, "SAME");
RunTest(x, x_shape, y, y_shape); RunTest(x, x_shape, y, y_shape);
} }
#endif
TEST_F(NNGradTest, LRN) { TEST_F(NNGradTest, LRN) {
TensorShape x_shape({1, 1, 2, 1}); TensorShape x_shape({1, 1, 2, 1});

View File

@ -124,13 +124,12 @@ cc_library(
hdrs = ["bundle_v2.h"], hdrs = ["bundle_v2.h"],
deps = [ deps = [
":constants", ":constants",
"@com_google_absl//absl/container:flat_hash_set",
] + if_not_mobile([
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:strcat", "//tensorflow/core/platform:strcat",
"//tensorflow/core/util/tensor_bundle", "//tensorflow/core/util/tensor_bundle",
]), "@com_google_absl//absl/container:flat_hash_set",
],
) )
tf_cc_test( tf_cc_test(

View File

@ -1,5 +1,6 @@
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
load("//tensorflow/core/platform:build_config.bzl", "if_llvm_aarch64_available")
package( package(
default_visibility = ["//visibility:private"], default_visibility = ["//visibility:private"],
@ -27,9 +28,15 @@ cc_library(
"compile.h", "compile.h",
"flags.h", "flags.h",
], ],
defines = if_llvm_aarch64_available(["TF_LLVM_AARCH64_AVAILABLE=1"]),
visibility = ["//tensorflow/python:__pkg__"],
deps = [ deps = [
":aot_only_var_handle_op", ":aot_only_var_handle_op",
":embedded_protocol_buffers", ":embedded_protocol_buffers",
"@com_google_absl//absl/base",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"//tensorflow/compiler/tf2xla", "//tensorflow/compiler/tf2xla",
"//tensorflow/compiler/tf2xla:mlir_tf2xla", "//tensorflow/compiler/tf2xla:mlir_tf2xla",
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc", "//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
@ -53,10 +60,13 @@ cc_library(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/memory", "@llvm-project//llvm:arm_code_gen", # fixdeps: keep
"@com_google_absl//absl/strings", "@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
"@com_google_absl//absl/types:span", "@llvm-project//llvm:target",
], "@llvm-project//llvm:x86_code_gen", # fixdeps: keep
] + if_llvm_aarch64_available([
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
]),
) )
tf_cc_test( tf_cc_test(
@ -86,6 +96,19 @@ tf_cc_binary(
deps = [":tfcompile_main"], deps = [":tfcompile_main"],
) )
cc_library(
name = "llvm_targets",
visibility = ["//tensorflow/python:__pkg__"],
deps = [
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
"@llvm-project//llvm:target",
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
] + if_llvm_aarch64_available([
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
]),
)
cc_library( cc_library(
name = "tfcompile_main", name = "tfcompile_main",
srcs = ["tfcompile_main.cc"], srcs = ["tfcompile_main.cc"],
@ -104,11 +127,6 @@ cc_library(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:aarch64_code_gen", # fixdeps: keep
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
"@llvm-project//llvm:target",
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
], ],
) )
@ -214,8 +232,13 @@ cc_library(
cc_library( cc_library(
name = "aot_only_var_handle_op", name = "aot_only_var_handle_op",
srcs = ["aot_only_var_handle_op.cc"], srcs = ["aot_only_var_handle_op.cc"],
hdrs = ["aot_only_var_handle_op.h"],
visibility = [
"//tensorflow/compiler/tf2xla:__pkg__",
],
deps = [ deps = [
"//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/core:framework",
], ],
alwayslink = 1, alwayslink = 1,
) )

View File

@ -13,9 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/compiler/aot/aot_only_var_handle_op.h"
#include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow { namespace tensorflow {
namespace { namespace {
@ -51,6 +54,31 @@ void XlaAotOnlyVarHandleOp::Compile(XlaOpKernelContext* context) {
} }
} // namespace } // namespace
REGISTER_XLA_OP(Name("VarHandleOp").CompilationOnly(), XlaAotOnlyVarHandleOp); REGISTER_OP(tfcompile::kXlaAotOnlyVarHandleOp)
.Doc(R"doc(
Internal VarHandleOp registration used for XLA AOT compilation.
)doc")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.Attr("dtype: type")
.Attr("shape: shape")
.Output("resource: resource")
.SetIsStateful()
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->Scalar());
DataType t;
TF_RETURN_IF_ERROR(c->GetAttr("dtype", &t));
PartialTensorShape p;
TF_RETURN_IF_ERROR(c->GetAttr("shape", &p));
shape_inference::ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s));
c->set_output_handle_shapes_and_types(
0, std::vector<shape_inference::ShapeAndType>{{s, t}});
return Status::OK();
});
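// Editor's note: the shape function above declares output 0 as a scalar (the
// resource handle itself) and attaches the variable's declared dtype/shape as
// handle metadata, so downstream shape inference can see through the handle.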
REGISTER_XLA_OP(Name(tfcompile::kXlaAotOnlyVarHandleOp).CompilationOnly(),
XlaAotOnlyVarHandleOp);
} // namespace tensorflow } // namespace tensorflow

View File

@ -0,0 +1,27 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_
#define TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_
namespace tensorflow {
namespace tfcompile {
static constexpr const char* const kXlaAotOnlyVarHandleOp =
"_XlaAotOnlyVarHandleOp";
} // namespace tfcompile
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_

View File

@ -74,16 +74,16 @@ void DumpStatsToStdout(const Stats& stats) {
const int kBufSize = 1000; const int kBufSize = 1000;
char buf[kBufSize]; char buf[kBufSize];
snprintf(buf, kBufSize, "Mean with %2.0f%% trimmed:", trim_ratio * 100); snprintf(buf, kBufSize, "Mean with %2.0f%% trimmed:", trim_ratio * 100);
const string label_trimmed(buf); std::string label_trimmed(buf);
snprintf(buf, kBufSize, "Mean of %2.0f%% best:", best_ratio * 100); snprintf(buf, kBufSize, "Mean of %2.0f%% best:", best_ratio * 100);
const string label_best(buf); std::string label_best(buf);
std::vector<std::pair<string, double>> groups = { std::vector<std::pair<std::string, double>> groups = {
{"Best:", sorted_us.front()}, {"Best:", sorted_us.front()},
{"Worst:", sorted_us.back()}, {"Worst:", sorted_us.back()},
{"Median:", sorted_us[count_us / 2]}, {"Median:", sorted_us[count_us / 2]},
{"Mean:", sum_us / count_us}, {"Mean:", sum_us / count_us},
{label_trimmed, sum_us_trimmed / count_us_trimmed}, {std::move(label_trimmed), sum_us_trimmed / count_us_trimmed},
{label_best, sum_us_best / count_us_best}, {std::move(label_best), sum_us_best / count_us_best},
}; };
int max_label_size = 0; int max_label_size = 0;
double max_us = 0; double max_us = 0;
@ -102,7 +102,7 @@ void DumpStatsToStdout(const Stats& stats) {
} }
// Dump stats out. // Dump stats out.
printf("Benchmark ran %zu iterations over %lld us\n", count_us, printf("Benchmark ran %zu iterations over %lld us\n", count_us,
stats.total_us); static_cast<long long>(stats.total_us)); // NOLINT
for (const auto& g : groups) { for (const auto& g : groups) {
printf(" %-*s %*.3f us\n", max_label_size, g.first.c_str(), max_digits + 4, printf(" %-*s %*.3f us\n", max_label_size, g.first.c_str(), max_digits + 4,
g.second); g.second);
@ -114,7 +114,8 @@ void Benchmark(const Options& options, const BenchmarkFn& fn, Stats* stats) {
const int64 max_us = (options.max_micros <= 0 && options.max_iters <= 0) const int64 max_us = (options.max_micros <= 0 && options.max_iters <= 0)
? Options::kDefaultMicros ? Options::kDefaultMicros
: options.max_micros; : options.max_micros;
printf("Running benchmark for %lld us\n", max_us); // NOLINTNEXTLINE
printf("Running benchmark for %lld us\n", static_cast<long long>(max_us));
const int64 start_us = NowMicros(); const int64 start_us = NowMicros();
int64 iters = 0; int64 iters = 0;
while (true) { while (true) {

View File

@ -423,8 +423,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
GenNameToIndexCode(config.fetch(), opts.gen_name_to_index); GenNameToIndexCode(config.fetch(), opts.gen_name_to_index);
const string include_xla_data_proto = const string include_xla_data_proto =
opts.gen_program_shape opts.gen_program_shape
? ? R"(#include "tensorflow/compiler/xla/xla_data.pb.h")"
R"(#include "tensorflow/compiler/xla/xla_data.pb.h")"
: ""; : "";
const string include_hlo_profile_printer_data_proto = const string include_hlo_profile_printer_data_proto =

View File

@ -20,6 +20,9 @@ limitations under the License.
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "absl/base/call_once.h"
#include "llvm-c/Target.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/flags.h" #include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h" #include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
@ -90,7 +93,7 @@ Status CompileXla(xla::CompileOnlyClient* client,
} // namespace } // namespace
Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config, Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
const MainFlags& flags, CompileResult* compile_result) { const MainFlags& flags, CompileResult* compile_result) {
// Converts the graph into an XLA computation, and compiles the // Converts the graph into an XLA computation, and compiles the
// computation. // computation.
@ -108,8 +111,8 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
if (!flags.mlir_components.empty()) { if (!flags.mlir_components.empty()) {
return errors::Unknown("Unknown mlir_components ", flags.mlir_components); return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
} }
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(ConvertGraphDefToXla(std::move(graph_def), config,
ConvertGraphDefToXla(graph_def, config, client, &computation)); client, &computation));
} }
if (!flags.out_session_module.empty()) { if (!flags.out_session_module.empty()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module, TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
@ -132,5 +135,96 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
return CompileXla(client, computation, aot_opts, compile_result); return CompileXla(client, computation, aot_opts, compile_result);
} }
static Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
if (absl::EndsWith(fname, ".pbtxt")) {
return ReadTextProto(Env::Default(), fname, proto);
} else {
return ReadBinaryProto(Env::Default(), fname, proto);
}
}
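// Editor's note: a hypothetical caller (path name illustrative) showing the
// dispatch -- a ".pbtxt" suffix parses as a text proto, anything else as a
// binary proto.
static Status ReadConfigExample(tf2xla::Config* config) {
  return ReadProtoFile("config.pbtxt", config);  // text form, by extension
}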
static absl::once_flag targets_init;
static void InitializeTargets() {
// Initialize all LLVM targets so we can cross compile.
#if TF_LLVM_AARCH64_AVAILABLE
LLVMInitializeAArch64Target();
LLVMInitializeAArch64TargetInfo();
LLVMInitializeAArch64TargetMC();
LLVMInitializeAArch64AsmPrinter();
#endif
LLVMInitializeARMTarget();
LLVMInitializeARMTargetInfo();
LLVMInitializeARMTargetMC();
LLVMInitializeARMAsmPrinter();
LLVMInitializePowerPCTarget();
LLVMInitializePowerPCTargetInfo();
LLVMInitializePowerPCTargetMC();
LLVMInitializePowerPCAsmPrinter();
LLVMInitializeX86Target();
LLVMInitializeX86TargetInfo();
LLVMInitializeX86TargetMC();
LLVMInitializeX86AsmPrinter();
}
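// Editor's note: funneling LLVM target registration through absl::call_once
// keeps it idempotent and thread-safe now that Main() below can be invoked
// repeatedly by library callers, not just once from the tfcompile binary.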
Status Main(const MainFlags& flags) {
absl::call_once(targets_init, &InitializeTargets);
// Process config.
tf2xla::Config config;
if (flags.config.empty()) {
return errors::InvalidArgument("Must specify --config");
}
TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
TF_RETURN_IF_ERROR(ValidateConfig(config));
if (flags.dump_fetch_nodes) {
std::set<string> nodes;
for (const tf2xla::Fetch& fetch : config.fetch()) {
nodes.insert(fetch.id().node_name());
}
std::cout << absl::StrJoin(nodes, ",");
return Status::OK();
}
// Read and initialize the graph.
if (flags.graph.empty()) {
return errors::InvalidArgument("Must specify --graph");
}
GraphDef graph_def;
TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
CompileResult compile_result;
TF_RETURN_IF_ERROR(
CompileGraph(std::move(graph_def), config, flags, &compile_result));
// Write output files.
Env* env = Env::Default();
const std::vector<char>& obj = compile_result.aot->object_file_data();
TF_RETURN_IF_ERROR(
WriteStringToFile(env, flags.out_function_object,
absl::string_view(obj.data(), obj.size())));
CodegenOpts codegen_opts;
codegen_opts.gen_name_to_index = flags.gen_name_to_index;
codegen_opts.gen_program_shape = flags.gen_program_shape;
codegen_opts.target_triple = flags.target_triple;
if (flags.cpp_class.empty()) {
return errors::InvalidArgument("Must specify --cpp_class");
}
codegen_opts.gen_hlo_profile_printer_data =
xla::GetDebugOptionsFromFlags().xla_hlo_profile();
TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
&codegen_opts.namespaces));
MetadataResult metadata_result;
TF_RETURN_IF_ERROR(
GenerateMetadata(codegen_opts, compile_result, &metadata_result));
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_metadata_object,
metadata_result.object_file_data));
string header;
TF_RETURN_IF_ERROR(GenerateHeader(codegen_opts, config, compile_result,
metadata_result, &header));
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header));
return Status::OK();
}
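// Editor's note: this Main() body was moved here from tfcompile_main.cc (see
// the matching deletions later in this change), making the full
// compile-and-codegen pipeline reusable as a library entry point.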
} // namespace tfcompile } // namespace tfcompile
} // namespace tensorflow } // namespace tensorflow

View File

@ -42,9 +42,12 @@ struct CompileResult {
// that performs the graph operations. // that performs the graph operations.
// //
// The XLA compilation options are specified in the flags. // The XLA compilation options are specified in the flags.
Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config, Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
const MainFlags& flags, CompileResult* compile_result); const MainFlags& flags, CompileResult* compile_result);
// The full compilation method, for reuse in a library setting.
Status Main(const MainFlags& flags);
} // namespace tfcompile } // namespace tfcompile
} // namespace tensorflow } // namespace tensorflow

View File

@ -25,6 +25,7 @@ namespace tensorflow {
namespace tfcompile { namespace tfcompile {
// Flags for the tfcompile binary. See *.cc file for descriptions. // Flags for the tfcompile binary. See *.cc file for descriptions.
struct MainFlags { struct MainFlags {
string graph; string graph;
string config; string config;

View File

@ -25,6 +25,7 @@ test_suite(
":test_graph_tfmatmulandadd_test", ":test_graph_tfmatmulandadd_test",
":test_graph_tfsplits_test", ":test_graph_tfsplits_test",
":test_graph_tftop_k_test", ":test_graph_tftop_k_test",
":test_graph_tfvariable_readonly_test",
":test_graph_tfvariable_sequential_updates_test", ":test_graph_tfvariable_sequential_updates_test",
":test_graph_tfvariable_test", ":test_graph_tfvariable_test",
":tfcompile_test", ":tfcompile_test",
@ -73,6 +74,7 @@ genrule(
"test_graph_tfsplits.pb", "test_graph_tfsplits.pb",
"test_graph_tftop_k.pb", "test_graph_tftop_k.pb",
"test_graph_tfvariable.pb", "test_graph_tfvariable.pb",
"test_graph_tfvariable_readonly.pb",
"test_graph_tfvariable_sequential_updates.pb", "test_graph_tfvariable_sequential_updates.pb",
], ],
# Set CUDA_VISIBLE_DEVICES='' to prevent the code we launch from using any # Set CUDA_VISIBLE_DEVICES='' to prevent the code we launch from using any
@ -238,6 +240,17 @@ tf_library(
], ],
) )
tf_library(
name = "test_graph_tfvariable_readonly",
testonly = 1,
config = "test_graph_tfvariable_readonly.config.pbtxt",
cpp_class = "VariableReadonlyComp",
graph = "test_graph_tfvariable_readonly.pb",
tags = [
"manual",
],
)
tf_library( tf_library(
name = "test_graph_tfvariable_sequential_updates", name = "test_graph_tfvariable_sequential_updates",
testonly = 1, testonly = 1,
@ -269,6 +282,7 @@ tf_cc_test(
":test_graph_tfsplits", ":test_graph_tfsplits",
":test_graph_tftop_k", ":test_graph_tftop_k",
":test_graph_tfvariable", ":test_graph_tfvariable",
":test_graph_tfvariable_readonly",
":test_graph_tfvariable_sequential_updates", ":test_graph_tfvariable_sequential_updates",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test",
@ -323,6 +337,42 @@ tf_library(
], ],
) )
tf_library(
name = "test_graph_tfcond_mlir_bridge",
testonly = 1,
config = "test_graph_tfcond.config.pbtxt",
cpp_class = "CondComp",
graph = "test_graph_tfcond.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tfassert_eq_mlir_bridge",
testonly = 1,
config = "test_graph_tfassert_eq.config.pbtxt",
cpp_class = "AssertComp",
graph = "test_graph_tfassert_eq.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tfgather_mlir_bridge",
testonly = 1,
config = "test_graph_tfgather.config.pbtxt",
cpp_class = "GatherComp",
graph = "test_graph_tfgather.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_library( tf_library(
name = "test_graph_tfmatmul_mlir_bridge", name = "test_graph_tfmatmul_mlir_bridge",
testonly = 1, testonly = 1,
@ -361,6 +411,66 @@ tf_library(
], ],
) )
tf_library(
name = "test_graph_tfsplits_mlir_bridge",
testonly = 1,
config = "test_graph_tfsplits.config.pbtxt",
cpp_class = "SplitsComp",
graph = "test_graph_tfsplits.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tftop_k_mlir_bridge",
testonly = 1,
config = "test_graph_tftop_k.config.pbtxt",
cpp_class = "TopKComp",
graph = "test_graph_tftop_k.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tfvariable_readonly_mlir_bridge",
testonly = 1,
config = "test_graph_tfvariable_readonly.config.pbtxt",
cpp_class = "VariableReadonlyComp",
graph = "test_graph_tfvariable_readonly.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tfvariable_mlir_bridge",
testonly = 1,
config = "test_graph_tfvariable.config.pbtxt",
cpp_class = "VariableComp",
graph = "test_graph_tfvariable.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tfvariable_sequential_updates_mlir_bridge",
testonly = 1,
config = "test_graph_tfvariable_sequential_updates.config.pbtxt",
cpp_class = "VariableSequentialUpdatesComp",
graph = "test_graph_tfvariable_sequential_updates.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_cc_test( tf_cc_test(
name = "tfcompile_test_mlir_bridge", name = "tfcompile_test_mlir_bridge",
srcs = ["tfcompile_test.cc"], srcs = ["tfcompile_test.cc"],
@ -372,9 +482,17 @@ tf_cc_test(
":test_graph_tfadd_mlir_bridge", ":test_graph_tfadd_mlir_bridge",
":test_graph_tfadd_with_ckpt_mlir_bridge", ":test_graph_tfadd_with_ckpt_mlir_bridge",
":test_graph_tfadd_with_ckpt_saver_mlir_bridge", ":test_graph_tfadd_with_ckpt_saver_mlir_bridge",
":test_graph_tfassert_eq_mlir_bridge",
":test_graph_tfcond_mlir_bridge",
":test_graph_tfgather_mlir_bridge",
":test_graph_tfmatmul_mlir_bridge", ":test_graph_tfmatmul_mlir_bridge",
":test_graph_tfmatmulandadd_mlir_bridge", ":test_graph_tfmatmulandadd_mlir_bridge",
":test_graph_tfmatmulandadd_with_profiling_mlir_bridge", ":test_graph_tfmatmulandadd_with_profiling_mlir_bridge",
":test_graph_tfsplits_mlir_bridge",
":test_graph_tftop_k_mlir_bridge",
":test_graph_tfvariable_mlir_bridge",
":test_graph_tfvariable_readonly_mlir_bridge",
":test_graph_tfvariable_sequential_updates_mlir_bridge",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla:xla_data_proto_cc",

View File

@ -34,6 +34,7 @@ from tensorflow.python.framework import function
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
@ -153,11 +154,21 @@ def tftop_k(_):
array_ops.identity(output[1], name='indices') array_ops.identity(output[1], name='indices')
def tfvariable(_): def tfvariable_readonly(_):
x = variables.Variable(1000.0, name='x') x = variables.Variable(1000.0, name='x')
old_x = x.value() old_x = x.value()
with ops.control_dependencies([old_x]): with ops.control_dependencies([old_x]):
new_x = x.assign_add(42.0) new_value = math_ops.add(old_x, 42.0)
array_ops.identity(new_value, name='result')
# TODO(b/147908587): Change x and the two constants back to have a scalar shape
# when the bug is fixed.
def tfvariable(_):
x = variables.Variable([1000.0], name='x', shape=[1])
old_x = x.value()
with ops.control_dependencies([old_x]):
new_x = x.assign_add([42.0])
array_ops.stack([old_x, new_x], name='result') array_ops.stack([old_x, new_x], name='result')
@ -184,6 +195,7 @@ def write_graph(build_graph, out_dir):
def main(_): def main(_):
control_flow_util.enable_control_flow_v2()
write_graph(tfadd, FLAGS.out_dir) write_graph(tfadd, FLAGS.out_dir)
write_graph(tfadd_with_ckpt, FLAGS.out_dir) write_graph(tfadd_with_ckpt, FLAGS.out_dir)
write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir) write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir)
@ -196,6 +208,7 @@ def main(_):
write_graph(tfsplits, FLAGS.out_dir) write_graph(tfsplits, FLAGS.out_dir)
write_graph(tftop_k, FLAGS.out_dir) write_graph(tftop_k, FLAGS.out_dir)
write_graph(tfvariable, FLAGS.out_dir) write_graph(tfvariable, FLAGS.out_dir)
write_graph(tfvariable_readonly, FLAGS.out_dir)
write_graph(tfvariable_sequential_updates, FLAGS.out_dir) write_graph(tfvariable_sequential_updates, FLAGS.out_dir)

View File

@ -0,0 +1,12 @@
# Text form of tensorflow.tf2xla.Config proto.
fetch {
id { node_name: "result" }
}
variable {
node_name: "x"
shape {
}
type: DT_FLOAT
readonly: true
}

View File

@ -30,9 +30,17 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_mlir_bridge.h" #include "tensorflow/compiler/aot/tests/test_graph_tfadd_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_mlir_bridge.h" #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver_mlir_bridge.h" #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfcond_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfgather_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul_mlir_bridge.h" #include "tensorflow/compiler/aot/tests/test_graph_tfmatmul_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_mlir_bridge.h" #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling_mlir_bridge.h" #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tftop_k_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates_mlir_bridge.h"
#else #else
#include "tensorflow/compiler/aot/tests/test_graph_tfadd.h" #include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h" #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
@ -47,6 +55,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h" #include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h"
#include "tensorflow/compiler/aot/tests/test_graph_tftop_k.h" #include "tensorflow/compiler/aot/tests/test_graph_tftop_k.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable.h" #include "tensorflow/compiler/aot/tests/test_graph_tfvariable.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates.h" #include "tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates.h"
#endif #endif
@ -167,8 +176,6 @@ TEST(TFCompileTest, AddWithCkptSaver) {
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]); EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
} }
// TODO(bixia): the following tests failed with MLIR bridge.
#if !defined(ENABLE_MLIR_BRIDGE_TEST)
TEST(TFCompileTest, Cond) { TEST(TFCompileTest, Cond) {
CondComp cond; CondComp cond;
EXPECT_EQ(cond.arg0_data(), cond.arg_data(0)); EXPECT_EQ(cond.arg0_data(), cond.arg_data(0));
@ -233,7 +240,6 @@ TEST(TFCompileTest, Gather) {
EXPECT_EQ(gather_const.result0_data(), gather.results()[0]); EXPECT_EQ(gather_const.result0_data(), gather.results()[0]);
} }
} }
#endif
TEST(TFCompileTest, MatMul2) { TEST(TFCompileTest, MatMul2) {
Eigen::ThreadPool tp(2); Eigen::ThreadPool tp(2);
@ -439,6 +445,7 @@ TEST(TFCompileTest, Function) {
EXPECT_EQ(add_fn.result0_data()[0], 3); EXPECT_EQ(add_fn.result0_data()[0], 3);
EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]); EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]);
} }
#endif
TEST(TFCompileTest, Splits) { TEST(TFCompileTest, Splits) {
Eigen::ThreadPool tp(1); Eigen::ThreadPool tp(1);
@ -492,6 +499,20 @@ TEST(TFCompileTest, TopK) {
EXPECT_EQ(expected_indices[1], fn.result1(1)); EXPECT_EQ(expected_indices[1], fn.result1(1));
} }
TEST(TFCompileTest, VariableReadonly) {
Eigen::ThreadPool tp(1);
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
VariableReadonlyComp fn;
float x = 23;
fn.set_var_x_data(&x);
fn.set_thread_pool(&device);
fn.Run();
EXPECT_EQ(fn.result0(), 65);
EXPECT_EQ(fn.var_x(), 23);
}
TEST(TFCompileTest, Variable) { TEST(TFCompileTest, Variable) {
Eigen::ThreadPool tp(1); Eigen::ThreadPool tp(1);
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
@ -665,6 +686,11 @@ TEST(TFCompileTest, HloProfiling) {
/*clock_rate_ghz=*/1.0); /*clock_rate_ghz=*/1.0);
VLOG(1) << "Original HLO profile string:\n" << hlo_profile_as_string; VLOG(1) << "Original HLO profile string:\n" << hlo_profile_as_string;
// Replace Arg_n with argn when the MLIR bridge is used.
#if defined(ENABLE_MLIR_BRIDGE_TEST)
RE2::GlobalReplace(&hlo_profile_as_string, "(Arg_)([0-9].)", "arg\\2");
#endif
// Strip away identifier details from the profile string to avoid this test // Strip away identifier details from the profile string to avoid this test
// being a change detector for xla internals. Identifiers such as '%dot.0.7' // being a change detector for xla internals. Identifiers such as '%dot.0.7'
// just become '%dot'. // just become '%dot'.
@ -690,7 +716,6 @@ TEST(TFCompileTest, HloProfiling) {
IsSupersetOf({header, total_cycles_profile_line, dot_profile_line, IsSupersetOf({header, total_cycles_profile_line, dot_profile_line,
add_profile_line, tuple_profile_line})); add_profile_line, tuple_profile_line}));
} }
#endif
} // namespace } // namespace
} // namespace tfcompile } // namespace tfcompile

View File

@ -407,6 +407,7 @@ def target_llvm_triple():
"//tensorflow:android_arm64": "aarch64-none-android", "//tensorflow:android_arm64": "aarch64-none-android",
"//tensorflow:android_x86": "i686-none-android", "//tensorflow:android_x86": "i686-none-android",
"//tensorflow:ios": "arm64-none-ios", "//tensorflow:ios": "arm64-none-ios",
"//tensorflow:ios_x86_64": "x86_64-apple-ios",
"//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu", "//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu",
"//tensorflow:macos": "x86_64-none-darwin", "//tensorflow:macos": "x86_64-none-darwin",
"//conditions:default": "x86_64-pc-linux", "//conditions:default": "x86_64-pc-linux",

View File

@ -21,7 +21,6 @@ limitations under the License.
#include "absl/strings/match.h" #include "absl/strings/match.h"
#include "absl/strings/str_join.h" #include "absl/strings/str_join.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "llvm-c/Target.h"
#include "tensorflow/compiler/aot/codegen.h" #include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/compile.h" #include "tensorflow/compiler/aot/compile.h"
#include "tensorflow/compiler/aot/flags.h" #include "tensorflow/compiler/aot/flags.h"
@ -56,88 +55,6 @@ const char kUsageHeader[] =
"--cpp_class=\"mynamespace::MyComputation\"\n" "--cpp_class=\"mynamespace::MyComputation\"\n"
"\n"; "\n";
Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
if (absl::EndsWith(fname, ".pbtxt")) {
return ReadTextProto(Env::Default(), fname, proto);
} else {
return ReadBinaryProto(Env::Default(), fname, proto);
}
}
Status Main(const MainFlags& flags) {
// Initialize all LLVM targets so we can cross compile.
LLVMInitializeAArch64Target();
LLVMInitializeAArch64TargetInfo();
LLVMInitializeAArch64TargetMC();
LLVMInitializeAArch64AsmPrinter();
LLVMInitializeARMTarget();
LLVMInitializeARMTargetInfo();
LLVMInitializeARMTargetMC();
LLVMInitializeARMAsmPrinter();
LLVMInitializePowerPCTarget();
LLVMInitializePowerPCTargetInfo();
LLVMInitializePowerPCTargetMC();
LLVMInitializePowerPCAsmPrinter();
LLVMInitializeX86Target();
LLVMInitializeX86TargetInfo();
LLVMInitializeX86TargetMC();
LLVMInitializeX86AsmPrinter();
// Process config.
tf2xla::Config config;
if (flags.config.empty()) {
return errors::InvalidArgument("Must specify --config");
}
TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
TF_RETURN_IF_ERROR(ValidateConfig(config));
if (flags.dump_fetch_nodes) {
std::set<string> nodes;
for (const tf2xla::Fetch& fetch : config.fetch()) {
nodes.insert(fetch.id().node_name());
}
std::cout << absl::StrJoin(nodes, ",");
return Status::OK();
}
// Read and initialize the graph.
if (flags.graph.empty()) {
return errors::InvalidArgument("Must specify --graph");
}
GraphDef graph_def;
TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
CompileResult compile_result;
TF_RETURN_IF_ERROR(CompileGraph(graph_def, config, flags, &compile_result));
// Write output files.
Env* env = Env::Default();
const std::vector<char>& obj = compile_result.aot->object_file_data();
TF_RETURN_IF_ERROR(
WriteStringToFile(env, flags.out_function_object,
absl::string_view(obj.data(), obj.size())));
CodegenOpts codegen_opts;
codegen_opts.gen_name_to_index = flags.gen_name_to_index;
codegen_opts.gen_program_shape = flags.gen_program_shape;
codegen_opts.target_triple = flags.target_triple;
if (flags.cpp_class.empty()) {
return errors::InvalidArgument("Must specify --cpp_class");
}
codegen_opts.gen_hlo_profile_printer_data =
xla::GetDebugOptionsFromFlags().xla_hlo_profile();
TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
&codegen_opts.namespaces));
MetadataResult metadata_result;
TF_RETURN_IF_ERROR(
GenerateMetadata(codegen_opts, compile_result, &metadata_result));
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_metadata_object,
metadata_result.object_file_data));
string header;
TF_RETURN_IF_ERROR(GenerateHeader(codegen_opts, config, compile_result,
metadata_result, &header));
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header));
return Status::OK();
}
} // end namespace tfcompile } // end namespace tfcompile
} // end namespace tensorflow } // end namespace tensorflow

View File

@@ -2,14 +2,10 @@ load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "tf_cc_
load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library", "tf_jit_compilation_passes_extra_deps")
load("//tensorflow/core/platform:build_config.bzl", "tf_additional_all_protos", "tf_proto_library")
load("//tensorflow/core/platform:build_config_root.bzl", "tf_cuda_tests_tags")
package(
default_visibility = [
":internal",
# BEGIN-GOOGLE-INTERNAL
"//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__",
# END-GOOGLE-INTERNAL
],
default_visibility = [":internal"],
licenses = ["notice"],  # Apache 2.0
)
@@ -61,6 +57,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":jit_compilation_passes",
":xla_kernel_creator",  # buildcleaner: keep
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
@@ -74,6 +71,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = if_cuda_or_rocm([
":jit_compilation_passes",
":xla_kernel_creator",  # buildcleaner: keep
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
@@ -82,19 +80,6 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "xla_mlir_gpu_jit",
visibility = ["//visibility:public"],
deps = if_cuda_or_rocm([
":jit_compilation_passes",
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/xla/service:mlir_gpu_plugin",
]),
alwayslink = 1,
)
cc_library(
name = "xla_cpu_device",
srcs = ["xla_cpu_device.cc"],
@@ -120,6 +105,7 @@ cc_library(
srcs = ["xla_gpu_device.cc"],
visibility = [":friends"],
deps = [
":flags",
":jit_compilation_passes",
":xla_device",
":xla_kernel_creator",  # buildcleaner: keep
@@ -128,6 +114,7 @@ cc_library(
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:gpu_plugin",  # buildcleaner: keep
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:gpu_init",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
@@ -172,7 +159,9 @@ XLA_DEVICE_DEPS = [
":common",
":xla_launch_util",
":xla_tensor",
"@com_google_absl//absl/base",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:optional",
"//tensorflow/compiler/jit/ops:xla_ops",
@@ -265,13 +254,26 @@ cc_library(
}),
)
# Internal targets below this point.
cc_library(
name = "flags",
srcs = ["flags.cc"],
hdrs = ["flags.h"],
visibility = [":friends"],
deps = [
"//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"@com_google_absl//absl/base",
"@com_google_absl//absl/strings",
],
)
# Header-only version of "flags" library, for linking from the shared object
# without ODR violations.
cc_library(
name = "flags_headers_only",
hdrs = ["flags.h"],
visibility = [":friends"],
deps = [
"//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal",
@@ -291,6 +293,8 @@ cc_library(
visibility = [":friends"],
)
# Internal targets below this point.
cc_library(
name = "xla_launch_util",
srcs = ["xla_launch_util.cc"],
@@ -412,6 +416,7 @@ cc_library(
"xla_kernel_creator.h",
],
deps = [
":flags",
":jit_compilation_passes",
":xla_kernel_creator_util",
"//tensorflow/core:core_cpu_internal",
@@ -500,6 +505,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
@@ -639,6 +645,7 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
@@ -770,7 +777,7 @@ tf_cc_test(
],
# TODO(b/141643254) Re-enable msan after fixing use-of-uninitialized-value
# error.
tags = ["nomsan"],
tags = ["nomsan"] + tf_cuda_tests_tags(),
deps = [
":common",
":compilation_passes",

View File

@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/hash/hash.h"
@@ -1583,7 +1584,6 @@ DeadnessAnalysis::~DeadnessAnalysis() {}
absl::flat_hash_map<TensorId, string, TensorId::Hasher>
DeadnessAnalysisImpl::PredicateMapAsString() const {
absl::flat_hash_map<TensorId, string, TensorId::Hasher> result;
std::vector<TensorId> tensor_ids;
for (const auto& kv_pair : predicate_map_) {
CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second);
}

View File

@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"

View File

@@ -374,39 +374,6 @@ xla::StatusOr<NodeDef> BuildXlaHostComputeNodeDef(
return new_def;
}
TF_ATTRIBUTE_NOINLINE Status
ValidateOutsideCompilationCallNode(Node* call_node) {
// DT_INT64 as input/output for outside compilation is not supported yet:
// b/120809951.
for (const Edge* e : call_node->in_edges()) {
if (e->IsControlEdge()) {
continue;
}
DataType dtype = e->src()->output_type(e->src_output());
if (dtype == DT_INT64) {
return errors::Unimplemented(
"int64 input for outside compilation is not supported yet: "
"b/120809951. Please cast output of node ",
e->src()->DebugString(),
" to int32 before feeding it into outside compilation.");
}
}
for (const Edge* e : call_node->out_edges()) {
if (e->IsControlEdge()) {
continue;
}
DataType dtype = e->dst()->input_type(e->dst_input());
if (dtype == DT_INT64) {
return errors::Unimplemented(
"int64 output for outside compilation is not supported yet: "
"b/120809951. Please cast input of node ",
e->dst()->DebugString(),
" to int32 before returning it from outside compilation.");
}
}
return Status::OK();
}
// Replace outside compilation function call node with XlaHostCompute node.
TF_ATTRIBUTE_NOINLINE xla::StatusOr<Node*> ReplaceOutsideCompilationCallNode(
Graph* g, Node* call_node, const std::map<string, int>& host_compute_core,
@@ -2130,6 +2097,53 @@ Status ExtractOutsideCompilationForNodesWithAssociatedFunctions(
return Status::OK();
}
Status CopyOutsideCompilationConstNodes(
Graph* g, const string& outside_compilation_attr_name) {
for (Node* n : g->op_nodes()) {
if (!n->IsConstant() ||
!HasNodeAttr(n->def(), outside_compilation_attr_name)) {
continue;
}
std::vector<const Edge*> out_edges(n->out_edges().begin(),
n->out_edges().end());
bool has_non_oc_output = false;
for (const Edge* e : out_edges) {
if (!e->IsControlEdge() &&
!HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) {
has_non_oc_output = true;
break;
}
}
if (!has_non_oc_output) {
continue;
}
NodeDef copy_def = n->def();
copy_def.set_name(g->NewName(n->name()));
copy_def.mutable_attr()->erase(outside_compilation_attr_name);
Status s;
Node* copy_node = g->AddNode(copy_def, &s);
TF_RETURN_IF_ERROR(s);
for (const Edge* e : n->in_edges()) {
if (e->IsControlEdge()) {
g->AddControlEdge(e->src(), copy_node);
}
}
for (const Edge* e : out_edges) {
if (!e->IsControlEdge() &&
!HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) {
Node* dst = e->dst();
int dst_input = e->dst_input();
g->RemoveEdge(e);
g->AddEdge(copy_node, 0, dst, dst_input);
}
}
}
return Status::OK();
}
}  // namespace
Status RewriteOutsideCompilationSubgraphFn::operator()(
@@ -2279,6 +2293,10 @@ Status ExtractOutsideCompilationForFunction(
std::vector<string> outside_compilation_host_graphs;
std::vector<string> shape_inference_graphs_to_rewrite;
if (*has_outside_compilation) {
// Copy outside compilation Const nodes with non outside compilation users.
TF_RETURN_IF_ERROR(CopyOutsideCompilationConstNodes(
fbody->graph, outside_compilation_attr_name));
// Find dependencies between outside compilation clusters.
TF_ASSIGN_OR_RETURN(auto cluster_deps,
OutsideCompilationClusterDependencies(
@@ -2333,7 +2351,6 @@ Status ExtractOutsideCompilationForFunction(
}
std::map<string, Node*> host_compute_nodes;
for (Node* n : outside_compilation_nodes) {
TF_RETURN_IF_ERROR(ValidateOutsideCompilationCallNode(n));
auto host_compute_node_or = ReplaceOutsideCompilationCallNode(
graph_out.get(), n, host_compute_core, *cluster_deps);
TF_RETURN_IF_ERROR(host_compute_node_or.status());
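The CopyOutsideCompilationConstNodes helper above duplicates a Const consumed both inside and outside an outside-compilation cluster so that each side keeps its own copy. A self-contained toy sketch of the same rewiring idea, using plain STL types rather than TensorFlow's Graph API:

#include <vector>

// Toy stand-ins for graph edges; not TensorFlow types.
struct ToyEdge { int src, dst; };

// Clone node `c` and retarget every edge whose destination lies outside the
// cluster at the clone, mirroring the RemoveEdge/AddEdge loop above.
int SplitConst(int c, const std::vector<bool>& in_cluster,
               std::vector<ToyEdge>& edges, int next_id) {
  int copy = next_id;  // stands in for g->AddNode(copy_def, &s)
  for (ToyEdge& e : edges) {
    if (e.src == c && !in_cluster[e.dst]) e.src = copy;
  }
  return copy;
}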

View File

@@ -13,13 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/flags.h"
#include <mutex>  // NOLINT
#include "absl/base/call_once.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "absl/strings/strip.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
@@ -32,7 +35,7 @@ XlaOpsCommonFlags* ops_flags;
IntroduceFloatingPointJitterPassFlags* jitter_flags;
std::vector<Flag>* flag_list;
std::once_flag flags_init;
absl::once_flag flags_init;
bool SetterForXlaAutoJitFlag(const string& value) {
int32 opt_level;
@@ -155,6 +158,7 @@ void AllocateAndParseFlags() {
device_flags = new XlaDeviceFlags;
device_flags->tf_xla_compile_on_demand = false;
device_flags->tf_xla_enable_xla_devices = true;
ops_flags = new XlaOpsCommonFlags;
ops_flags->tf_xla_always_defer_compilation = false;
@@ -187,6 +191,12 @@ void AllocateAndParseFlags() {
"Switch a device into 'on-demand' mode, where instead of "
"autoclustering ops are compiled one by one just-in-time."),
Flag("tf_xla_enable_xla_devices",
&device_flags->tf_xla_enable_xla_devices,
"Generate XLA_* devices, where placing a computation on such a "
"device forces compilation by XLA. Deprecated."),
Flag("tf_xla_always_defer_compilation",
&ops_flags->tf_xla_always_defer_compilation, ""),
@@ -206,38 +216,45 @@ void AllocateAndParseFlags() {
}  // namespace
bool SetXlaAutoJitFlagFromFlagString(const string& value) {
std::call_once(flags_init, &AllocateAndParseFlags);
absl::call_once(flags_init, &AllocateAndParseFlags);
return SetterForXlaAutoJitFlag(value);
}
BuildXlaOpsPassFlags* GetBuildXlaOpsPassFlags() {
std::call_once(flags_init, &AllocateAndParseFlags);
absl::call_once(flags_init, &AllocateAndParseFlags);
return build_ops_flags;
}
MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() {
std::call_once(flags_init, &AllocateAndParseFlags);
absl::call_once(flags_init, &AllocateAndParseFlags);
return mark_for_compilation_flags;
}
XlaDeviceFlags* GetXlaDeviceFlags() {
std::call_once(flags_init, &AllocateAndParseFlags);
absl::call_once(flags_init, &AllocateAndParseFlags);
return device_flags;
}
const XlaOpsCommonFlags& GetXlaOpsCommonFlags() {
std::call_once(flags_init, &AllocateAndParseFlags);
absl::call_once(flags_init, &AllocateAndParseFlags);
return *ops_flags;
}
const IntroduceFloatingPointJitterPassFlags&
GetIntroduceFloatingPointJitterPassFlags() {
std::call_once(flags_init, &AllocateAndParseFlags);
absl::call_once(flags_init, &AllocateAndParseFlags);
return *jitter_flags;
}
void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
std::call_once(flags_init, &AllocateAndParseFlags);
absl::call_once(flags_init, &AllocateAndParseFlags);
AppendMarkForCompilationPassFlagsInternal(flag_list);
}
static bool xla_is_enabled = false;
void SetXlaIsEnabled() { xla_is_enabled = true; }
bool IsXlaEnabled() { return xla_is_enabled; }
}  // namespace tensorflow
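The edits above replace std::once_flag/std::call_once with the Abseil equivalents; the lazy-initialization shape is unchanged. A minimal sketch of the pattern, assuming only absl/base/call_once.h (Flags here is a stand-in struct, not the one in this file):

#include "absl/base/call_once.h"

struct Flags { bool enabled = false; };  // illustrative stand-in

static Flags* flags = nullptr;
static absl::once_flag flags_once;

Flags* GetFlags() {
  // The first caller runs the initializer; every later caller just reads.
  absl::call_once(flags_once, [] { flags = new Flags; });
  return flags;
}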

View File

@@ -87,6 +87,9 @@ struct XlaDeviceFlags {
// Enabling this mode by a legacy flag is a temporary mechanism. When this
// feature is battle-tested, we will switch this to be a session option.
bool tf_xla_compile_on_demand;
// Enables "XLA" devices if this flag is set.
bool tf_xla_enable_xla_devices;
};
// Flags common to the _Xla* ops and their kernels.
@@ -151,6 +154,15 @@ GetIntroduceFloatingPointJitterPassFlags();
// Has the side-effect of parsing TF_XLA_FLAGS if that hasn't happened yet.
void AppendMarkForCompilationPassFlags(
std::vector<tensorflow::Flag>* flag_list);
// Makes all future calls to `IsXlaEnabled()` return `true`.
//
// Should only be called when XLA is linked in.
void SetXlaIsEnabled();
// Returns whether XLA is enabled.
bool IsXlaEnabled();
}  // namespace tensorflow
#endif  // TENSORFLOW_COMPILER_JIT_FLAGS_H_
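A hedged usage sketch tying the two additions in this header together; both getters parse TF_XLA_FLAGS on first use, so an environment setting such as TF_XLA_FLAGS=--tf_xla_enable_xla_devices is honored here (the wrapper function is hypothetical):

#include "tensorflow/compiler/jit/flags.h"

bool ShouldRegisterXlaDevices() {
  tensorflow::XlaDeviceFlags* flags = tensorflow::GetXlaDeviceFlags();
  // True only when the deprecated XLA_* devices are requested and the XLA
  // kernel creator was actually linked into this binary.
  return flags->tf_xla_enable_xla_devices && tensorflow::IsXlaEnabled();
}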

View File

@@ -21,6 +21,7 @@ limitations under the License.
#include <unordered_map>
#include <unordered_set>
#include "absl/base/call_once.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
@@ -1616,8 +1617,8 @@ StatusOr<bool> MarkForCompilationPassImpl::ShouldCompileClusterImpl(
if (!should_compile && global_jit_level_ != OptimizerOptions::OFF &&
device_type.type_string() == DEVICE_CPU) {
static std::once_flag once;
static absl::once_flag once;
std::call_once(once, [] {
absl::call_once(once, [] {
LOG(WARNING)
<< "(One-time warning): Not using XLA:CPU for cluster because envvar "
"TF_XLA_FLAGS=--tf_xla_cpu_global_jit was not set. If you want "
@@ -1776,9 +1777,9 @@ absl::flat_hash_map<string, std::vector<string>>* GetWhitelistTable() {
"Lgamma", "Digamma",
// Binary
"Add", "AddV2", "Sub", "Mul", "Div", "Atan2", "Complex", "DivNoNan",
"MulNoNan", "FloorDiv", "Xlogy", "Xdivy", "FloorMod", "BitwiseAnd",
"BitwiseOr", "BitwiseXor", "LeftShift", "RightShift", "LogicalAnd",
"LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
"MulNoNan", "FloorDiv", "Xlogy", "Xlog1py", "Xdivy", "FloorMod",
"BitwiseAnd", "BitwiseOr", "BitwiseXor", "LeftShift", "RightShift",
"LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
"ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "TruncateDiv",
"TruncateMod", "Equal", "NotEqual", "Greater", "GreaterEqual",
"Less", "LessEqual", "SigmoidGrad", "SoftplusGrad", "SoftsignGrad",
@@ -1872,6 +1873,8 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"Einsum",
"EmptyTensorList",
"ExtractImagePatches",
"Igamma",
"Igammac",
"FFT",
"FFT2D",
"FFT3D",

View File

@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/node_matchers.h"
#include <utility>
#include "absl/algorithm/container.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
@@ -24,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/graph_node_util.h"
namespace tensorflow {
namespace testing {

View File

@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/version.h"

View File

@@ -17,7 +17,10 @@ limitations under the License.
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/util/dump_graph.h"
@@ -39,7 +42,7 @@ Status ShapeHandleToTensorShape(shape_inference::InferenceContext* context,
return PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape);
}
Status PropagateShapes(const Graph& graph,
Status PropagateShapes(Graph* graph,
const std::map<int, InferredShape>& arg_shapes,
const std::vector<BackEdgeHelper::BackEdge>& back_edges,
ShapeRefiner* shape_refiner) {
@@ -54,7 +57,7 @@ Status PropagateShapes(const Graph& graph,
// shapes.
// TODO(phawkins): handle cyclic graphs.
std::vector<Node*> order;
GetReversePostOrder(graph, &order);
GetReversePostOrder(*graph, &order);
for (Node* n : order) {
// Ignore the status returned by the shape_refiner. We want the best effort
@@ -99,6 +102,67 @@ Status PropagateShapes(const Graph& graph,
}
}
// Sometimes we have VariableShape nodes in while loop (after Enter nodes).
// They won't be constant-folded because TensorFlow constant folding does
// not handle Enter nodes (and thus does not handle any nodes after Enter
// nodes). We try to replace such VariableShape nodes with Const nodes here.
if (n->type_string() == "VariableShape") {
shape_inference::InferenceContext* context = shape_refiner->GetContext(n);
auto handle_shapes_and_types = context->input_handle_shapes_and_types(0);
if (handle_shapes_and_types && !handle_shapes_and_types->empty()) {
shape_inference::ShapeHandle handle =
handle_shapes_and_types->at(0).shape;
TensorShapeProto shape_proto;
context->ShapeHandleToProto(handle, &shape_proto);
if (!shape_proto.unknown_rank()) {
NodeDef const_def;
const_def.set_op("Const");
Node* var_node;
TF_RETURN_IF_ERROR(n->input_node(0, &var_node));
const_def.set_name(
graph->NewName(absl::StrCat("var_shape_", var_node->name())));
DataType dtype = n->output_type(0);
AddNodeAttr("dtype", dtype, &const_def);
TensorProto value;
value.set_dtype(dtype);
value.mutable_tensor_shape()->add_dim()->set_size(
shape_proto.dim_size());
for (const auto& dim : shape_proto.dim()) {
if (dtype == DT_INT32) {
value.add_int_val(dim.size());
} else {
value.add_int64_val(dim.size());
}
}
AddNodeAttr("value", value, &const_def);
for (auto const& attr : n->attrs()) {
if (*attr.first.begin() == '_') {
AddNodeAttr(attr.first, attr.second, &const_def);
}
}
Status s;
Node* const_node = graph->AddNode(const_def, &s);
TF_RETURN_IF_ERROR(s);
graph->AddControlEdge(var_node, const_node);
std::vector<const Edge*> out_edges(n->out_edges().begin(),
n->out_edges().end());
for (const Edge* e : out_edges) {
if (e->IsControlEdge()) {
graph->AddControlEdge(const_node, e->dst());
graph->RemoveEdge(e);
} else {
Node* dst = e->dst();
int dst_input = e->dst_input();
graph->RemoveEdge(e);
graph->AddEdge(const_node, 0, dst, dst_input);
}
}
}
}
}
// Merge node causes a loop so we remove NextIteration->Merge edge before
// performing shape inference. But removing those edges also prevents us
// from inferring output shape for Merge node (we need shapes for all its
@@ -196,7 +260,7 @@ Status InferShapes(Graph* graph, const std::map<int, InferredShape>& arg_shapes,
// the shape inference is complete.
BackEdgeHelper back_edge;
TF_RETURN_IF_ERROR(back_edge.Remove(graph));
TF_RETURN_IF_ERROR(PropagateShapes(*graph, arg_shapes,
TF_RETURN_IF_ERROR(PropagateShapes(graph, arg_shapes,
back_edge.RemovedEdges(), &shape_refiner));
TF_RETURN_IF_ERROR(back_edge.Replace());
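To make the VariableShape rewrite above concrete: for a resource variable known to have shape [2, 3], the pass replaces VariableShape(var) with a Const holding the vector [2, 3]. A sketch of the TensorProto it builds for the DT_INT32 case (the concrete shape is illustrative):

tensorflow::TensorProto value;
value.set_dtype(tensorflow::DT_INT32);
// One dimension whose size is the variable's rank...
value.mutable_tensor_shape()->add_dim()->set_size(2);
// ...and one element per dimension of the variable's shape.
value.add_int_val(2);
value.add_int_val(3);
// AddNodeAttr("value", value, &const_def) then completes the Const node.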

View File

@@ -163,12 +163,11 @@ Status XlaCompilationCache::BuildExecutable(
build_options.set_device_allocator(options.device_allocator);
build_options.set_alias_passthrough_params(options.alias_passthrough_params);
auto compile_result =
client_->Compile(*result.computation, argument_layouts, build_options);
if (!compile_result.ok()) {
return compile_result.status();
}
*executable = std::move(compile_result.ValueOrDie());
TF_ASSIGN_OR_RETURN(
auto executables,
client_->Compile(*result.computation, argument_layouts, build_options));
TF_RET_CHECK(executables.size() == 1);
*executable = std::move(executables[0]);
return Status::OK();
}
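The rewrite above adapts to LocalClient::Compile returning a StatusOr-wrapped vector of executables. TF_ASSIGN_OR_RETURN is the usual unwrap-or-propagate macro; its expansion is roughly the following (MyStatusOrFn is a placeholder):

// TF_ASSIGN_OR_RETURN(auto value, MyStatusOrFn()); is roughly:
auto statusor = MyStatusOrFn();
if (!statusor.ok()) return statusor.status();
auto value = std::move(statusor.ValueOrDie());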

View File

@@ -36,8 +36,13 @@ class XlaCpuDeviceFactory : public DeviceFactory {
};
Status XlaCpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
devices->push_back(absl::StrCat("/physical_device:", DEVICE_XLA_CPU, ":0"));
XlaDeviceFlags* flags = GetXlaDeviceFlags();
if (!flags->tf_xla_enable_xla_devices) {
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
return Status::OK();
}
devices->push_back(absl::StrCat("/physical_device:", DEVICE_XLA_CPU, ":0"));
return Status::OK();
}
@@ -45,6 +50,10 @@ Status XlaCpuDeviceFactory::CreateDevices(
const SessionOptions& session_options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) {
XlaDeviceFlags* flags = GetXlaDeviceFlags();
if (!flags->tf_xla_enable_xla_devices) {
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
return Status::OK();
}
bool compile_on_demand = flags->tf_xla_compile_on_demand;
XlaOpRegistry::DeviceRegistration registration;

View File

@@ -20,7 +20,9 @@ limitations under the License.
#include <unordered_set>
#include <utility>
#include "absl/base/call_once.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
#include "tensorflow/compiler/jit/xla_device_context.h"
@@ -386,14 +388,33 @@ Status XlaDevice::TryGetDeviceContext(DeviceContext** out_context) {
return Status::OK();
}
// Warn about XLA_CPU/XLA_GPU exactly once.
static void ShowXlaDeviceDeprecationWarning(
absl::string_view compilation_device_name) {
static absl::once_flag once;
if (absl::StrContains(compilation_device_name, "CPU") ||
absl::StrContains(compilation_device_name, "GPU")) {
absl::call_once(once, [] {
LOG(WARNING)
<< "XLA_GPU and XLA_CPU devices are deprecated and will be "
"removed in subsequent releases. Instead, use either "
"@tf.function(experimental_compile=True) for must-compile "
"semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 "
"for auto-clustering best-effort compilation.";
});
}
}
void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":"
<< op_kernel->type_string();
ShowXlaDeviceDeprecationWarning(jit_device_name_.type_string());
op_kernel->Compute(context);
}
void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
AsyncOpKernel::DoneCallback done) {
ShowXlaDeviceDeprecationWarning(jit_device_name_.type_string());
VLOG(2) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":"
<< op_kernel->type_string();
op_kernel->ComputeAsync(context, done);

View File

@@ -140,7 +140,6 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
// The device tensor should always be fresh.
TF_RET_CHECK(!xla_tensor->has_shaped_buffer());
xla_tensor->set_host_tensor(*cpu_tensor);
TF_RETURN_IF_ERROR(
xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
stream_->parent()->device_ordinal()));

View File

@@ -14,17 +14,20 @@
==============================================================================*/
// Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
// operators using XLA via the XLA "CUDA" (GPU) backend.
// operators using XLA via the XLA "CUDA" or "ROCM" (GPU) backend.
#include <set>
#include "absl/memory/memory.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
@@ -61,7 +64,14 @@ class XlaGpuDeviceFactory : public DeviceFactory {
};
Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
auto platform = se::MultiPlatformManager::PlatformWithName("CUDA");
XlaDeviceFlags* flags = GetXlaDeviceFlags();
if (!flags->tf_xla_enable_xla_devices) {
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
return Status::OK();
}
auto platform =
se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
if (!platform.ok()) {
// Treat failures as non-fatal; there might not be a GPU in the machine.
VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
@@ -84,6 +94,12 @@ Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
Status XlaGpuDeviceFactory::CreateDevices(
const SessionOptions& session_options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) {
XlaDeviceFlags* flags = GetXlaDeviceFlags();
if (!flags->tf_xla_enable_xla_devices) {
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
return Status::OK();
}
XlaOpRegistry::DeviceRegistration registration;
registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
registration.autoclustering_policy =
@@ -103,7 +119,8 @@ Status XlaGpuDeviceFactory::CreateDevices(
RegisterXlaDeviceKernels(DEVICE_XLA_GPU, DEVICE_GPU_XLA_JIT);
(void)registrations;
auto platform = se::MultiPlatformManager::PlatformWithName("CUDA");
auto platform =
se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
if (!platform.ok()) {
// Treat failures as non-fatal; there might not be a GPU in the machine.
VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();

View File

@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_kernel_creator.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_kernel_creator_util.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -39,6 +40,10 @@ bool RegisterLaunchOpCreator() {
}
static bool register_me = RegisterLaunchOpCreator();
static bool register_xla = [] {
SetXlaIsEnabled();
return true;
}();
}  // end namespace
}  // namespace tensorflow
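The register_xla lambda above relies on the static-initializer registration idiom: a namespace-scope bool whose initializer runs when the object file is linked into the binary at all, flipping a process-wide latch. Stripped to its essentials:

#include <iostream>

static bool feature_enabled = false;

// Runs during static initialization of this translation unit, so linking the
// file in is enough to flip the latch; no call site is needed.
static bool register_feature = [] {
  feature_enabled = true;
  return true;
}();

int main() { std::cout << feature_enabled << "\n"; }  // prints 1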

View File

@@ -222,8 +222,9 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
OpKernelConstruction construction(
DeviceType(dev->device_type()), dev,
dev->GetAllocator(AllocatorAttributes()), &node_def,
&fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types,
fbody->ret_types, output_memory_types, flr->graph_def_version(), &s);
&fbody->fdef.signature(), flr, dev->resource_manager(), fbody->arg_types,
input_memory_types, fbody->ret_types, output_memory_types,
flr->graph_def_version(), &s);
*kernel = absl::make_unique<XlaLocalLaunchBase>(
&construction, constant_arg_indices, resource_arg_indices, function);

View File

@@ -44,8 +44,11 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"@llvm-project//llvm:support",
"@llvm-project//mlir:AffineDialectRegistration",
"@llvm-project//mlir:LoopDialectRegistration",
"@llvm-project//mlir:MlirOptLib",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOpsDialectRegistration",
"@llvm-project//mlir:Support",
"@llvm-project//mlir/test:TestTransforms",
],
@@ -63,6 +66,8 @@ cc_library(
"//tensorflow/compiler/mlir/lite:tensorflow_lite_optimize",
"//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
"//tensorflow/compiler/mlir/lite/quantization:quantization_passes",
"//tensorflow/compiler/mlir/lite/quantization/tensorflow:tf_to_quant",
"//tensorflow/compiler/mlir/lite/quantization/xla:hlo_xla_quantization_passes",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
@@ -74,15 +79,16 @@ cc_library(
"//tensorflow/compiler/mlir/xla:lhlo_fuse_linalg",
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine",
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_gpu",
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_linalg",
"//tensorflow/compiler/mlir/xla:xla_dialect_registration",
"//tensorflow/compiler/mlir/xla:xla_legalize_control_flow",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
"//tensorflow/compiler/mlir/xla:xla_legalize_to_linalg",
"//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
"//tensorflow/compiler/mlir/xla:xla_lower",
"@llvm-project//mlir:AffineDialectRegistration",
"//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts",
"//tensorflow/compiler/mlir/xla:xla_test_passes",
"@llvm-project//mlir:AffineOps",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:QuantOpsDialectRegistration",
],
)

View File

@@ -26,9 +26,11 @@ package_group(
filegroup(
name = "tensorflow_lite_ops_td_files",
srcs = [
"ir/tfl_op_interfaces.td",
"ir/tfl_ops.td",
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:include/mlir/Transforms/LoopLikeInterface.td",
],
)
@@ -55,6 +57,25 @@ gentbl(
],
)
gentbl(
name = "tensorflow_lite_op_interfaces_inc_gen",
tbl_outs = [
(
"-gen-op-interface-decls",
"ir/tfl_ops_interface.h.inc",
),
(
"-gen-op-interface-defs",
"ir/tfl_ops_interface.cc.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "ir/tfl_op_interfaces.td",
td_srcs = [
":tensorflow_lite_ops_td_files",
],
)
gentbl(
name = "tensorflow_lite_prepare_tf_inc_gen",
tbl_outs = [
@@ -177,11 +198,12 @@ cc_library(
"ir/tfl_ops.cc",
"ir/tfl_ops.cc.inc",
"ir/tfl_ops.h.inc",
"ir/tfl_ops_interface.cc.inc",
"ir/tfl_ops_interface.h.inc",
"utils/attribute_utils.cc",
],
hdrs = [
"ir/tfl_ops.h",
"ir/tfl_traits.h",
"transforms/passes.h",
"utils/attribute_utils.h",
"//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h",
@@ -190,8 +212,6 @@ cc_library(
deps = [
":tensorflow_lite_ops_inc_gen",
":validators",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/lite/schema:schema_fbs",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:Dialect",
@@ -200,6 +220,10 @@ cc_library(
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
# TODO(jpienaar): Move this out after splitting out LoopLikeOpInterface.
"@llvm-project//mlir:Transforms",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/lite/schema:schema_fbs",
],
alwayslink = 1,
)
@@ -258,6 +282,7 @@ tf_cc_test(
cc_library(
name = "tensorflow_lite_legalize_tf",
srcs = [
"transforms/dilated_conv.cc",
"transforms/extract_ophint.cc",
"transforms/generated_legalize_tf.inc",
"transforms/generated_lower_static_tensor_list.inc",
@@ -273,6 +298,7 @@ cc_library(
"transforms/unroll_batch_matmul.cc",
],
hdrs = [
"transforms/dilated_conv.h",
"transforms/passes.h",
"transforms/unroll_batch_matmul.h",
],
@@ -284,13 +310,16 @@ cc_library(
":validators",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_tensor",
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:tensor_list",
"//tensorflow/core/platform:logging",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
@@ -316,6 +345,7 @@ cc_library(
deps = [
":tensorflow_lite",
":validators",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
@@ -330,6 +360,7 @@ cc_library(
cc_library(
name = "tensorflow_lite_quantize",
srcs = [
"transforms/default_quant_params.cc",
"transforms/generated_post_quantize.inc",
"transforms/generated_quantize.inc",
"transforms/load_quantization_recipe.cc",
@@ -346,6 +377,7 @@ cc_library(
":validators",
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/lite/quantization/lite:tfl_to_std",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
@@ -370,6 +402,8 @@ genrule(
name = "op_quant_spec_getters_inc",
srcs = [
"ir/tfl_ops.td",
"ir/tfl_op_interfaces.td",
"@llvm-project//mlir:include/mlir/Transforms/LoopLikeInterface.td",
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
],
outs = [
@@ -436,8 +470,13 @@ cc_library(
deps = [
":tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:status",
"//tensorflow/lite/kernels/internal:kernel_utils",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@flatbuffers",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
@@ -501,6 +540,7 @@ cc_library(
"//tensorflow/lite:schema_fbs_version",
"//tensorflow/lite:string_util",
"//tensorflow/lite/delegates/flex:whitelisted_flex_ops_lib",
"//tensorflow/lite/kernels/internal:kernel_utils",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/tools/versioning:op_version",
"@com_google_absl//absl/base",
@@ -666,12 +706,16 @@ cc_library(
],
)
exports_files(
["transforms/passes.h"],
cc_library(
name = "empty_passes",
hdrs = ["transforms/passes.h"],
visibility = [
"//configs/devtools/hawkeye/tflite:__subpackages__",
"//learning/brain/models/app_benchmarks:__subpackages__",
"//tensorflow/compiler/mlir/lite:friends",
"//tensorflow/lite/experimental/mlir:__subpackages__",
],
deps = [
"@llvm-project//llvm:support",
],
)

View File

@@ -31,10 +31,11 @@ struct PassConfig {
: emit_builtin_tflite_ops(true),
lower_tensor_list_ops(false),
trim_functions_whitelist({}),
quant_specs(specs),
quant_specs(std::move(specs)),
skip_control_dialect(false),
form_clusters(false),
inline_functions(false) {}
inline_functions(true),
unfold_batch_matmul(true) {}
// If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
// added, which produces TF Lite ops.
@@ -57,6 +58,9 @@ struct PassConfig {
// Inline function calls within the main function in the MLIR module, prior
// to legalization to TFLite.
bool inline_functions;
// If `unfold_batch_matmul` is true, the tf.BatchMatMul is unfolded to a set
// of tfl.fully_connected ops.
bool unfold_batch_matmul;
};
}  // namespace TFL
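Given the new defaults above (function inlining and BatchMatMul unfolding both on), a caller that wants the old behavior can override the fields after construction. A sketch, assuming PassConfig and QuantizationSpecs live in mlir::TFL as elsewhere in this directory:

mlir::TFL::QuantizationSpecs quant_specs;
mlir::TFL::PassConfig pass_config(quant_specs);
pass_config.unfold_batch_matmul = false;  // keep tf.BatchMatMul intact
pass_config.inline_functions = false;     // opt out of the new default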

View File

@@ -17,6 +17,7 @@ limitations under the License.
#include <algorithm>
#include <cctype>
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
@@ -103,12 +104,26 @@ using llvm::cl::opt;
// Commandline flag to enable the control of flatbuffer import.
bool use_external_constant;
// Commandline flag to enable graph pruning.
bool experimental_prune_unreachable_nodes_unconditionally;
// NOLINTNEXTLINE
static opt<bool, true> use_external_constant_flag(
"use-external-constant",
llvm::cl::desc("Use external constant during flatbuffer import"),
llvm::cl::location(use_external_constant), llvm::cl::init(false));
// TODO(b/147111261): After the importer supports generic custom ops, we should
// change the flag to a more lightweight flag, e.g.
// "import_custom_ops_as_side_effect_free_ops", and let MLIR DCE prune
// the operations.
// NOLINTNEXTLINE
static opt<bool, true> experimental_prune_unreachable_nodes_unconditionally_flg(
"experimental-prune-unreachable-nodes-unconditionally",
llvm::cl::desc("Prune nodes that are not ancestors of the output nodes."),
llvm::cl::location(experimental_prune_unreachable_nodes_unconditionally),
llvm::cl::init(false));
namespace {
bool IsScalar(const TensorT& tensor) {
// TODO(b/138222071) We can't distinguish scalars and unranked tensors
@@ -217,7 +232,7 @@ mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b,
// min/max stats is just for comments, so ignore it.
if (!tensor.quantization || IsQuantized(tensor)) return nullptr;
// If the result isn't float and unquantizable, the min/max is ignored.
if (!res->getType()
if (!res.getType()
.cast<mlir::ShapedType>()
.getElementType()
.isa<mlir::FloatType>()) {
@@ -255,10 +270,23 @@ mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b,
}
StatusOr<std::string> OpNameForOpCode(const tflite::OperatorCodeT opcode) {
// TODO(krzysd) Support custom ops
// TODO(b/143872630): Support custom ops
if (opcode.builtin_code == tflite::BuiltinOperator_CUSTOM) {
return errors::Unimplemented("unsupported custom operation: ",
opcode.custom_code);
// Adding some custom op supported on GPU.
const absl::string_view custom_name = opcode.custom_code;
if (custom_name == "MaxPoolingWithArgmax2D") {
return std::string("tfl.max_pooling_with_argmax_2d");
}
if (custom_name == "Convolution2DTransposeBias") {
return std::string("tfl.convolution_2d_transpose_bias");
}
if (custom_name == "MaxUnpooling2D") {
return std::string("tfl.max_unpooling_2d");
}
// Use an unsupported op name instead of throwing an error here in case the
// op is pruned during the import.
return std::string(
llvm::Twine("tfl.UNSUPPORTED_custom_", opcode.custom_code).str());
}
if (opcode.builtin_code == tflite::BuiltinOperator_IF) {
return std::string("tf.If");
@@ -361,7 +389,6 @@ StatusOr<mlir::ElementsAttr> ConvertIntBuffer(
mlir::RankedTensorType shaped_type, mlir::Type elem_type,
const std::vector<uint8_t>& buffer) {
unsigned bit_width;
mlir::RankedTensorType buffer_type;
if (auto itype = elem_type.dyn_cast<mlir::IntegerType>()) {
bit_width = itype.getWidth();
} else if (auto qtype = elem_type.dyn_cast<QuantizedType>()) {
@@ -495,6 +522,13 @@ bool IsBasicLSTMOp(tflite::BuiltinOptionsUnion op_union) {
}
}
// Returns true if this is a custom op.
bool IsCustomOp(const std::string& op_name) {
return op_name == "tfl.max_pooling_with_argmax_2d" ||
op_name == "tfl.max_unpooling_2d" ||
op_name == "tfl.convolution_2d_transpose_bias";
}
// TODO(krzysd) Handle function calls
StatusOr<Operation*> ConvertOp(
const tflite::OperatorT& op, const std::vector<Value>& vals_map,
@@ -557,7 +591,15 @@ StatusOr<Operation*> ConvertOp(
}
llvm::SmallVector<mlir::NamedAttribute, 2> attrs;
mlir::BuiltinOptionsToAttributes(op.builtin_options, builder, attrs);
if (IsCustomOp(op_name)) {
auto status = mlir::CustomOptionsToAttributes(op_name, op.custom_options,
builder, loc, &attrs);
if (!status.ok()) {
return emitError(loc, status.ToString()), status;
}
} else {
mlir::BuiltinOptionsToAttributes(op.builtin_options, builder, attrs);
}
op_state.addAttributes(attrs);
// Handle the conversion from subgraph index to functions for If and While
@@ -619,6 +661,49 @@ mlir::NamedAttribute BuildTFEntryFunctionAttribute(
name, builder->getStringAttr(llvm::join(tensor_names, ",")));
}
// Given a list of output indices, traverses the subgraph and returns the set of
// ops that are ancestors of the output tensors.
StatusOr<absl::flat_hash_set<const tflite::OperatorT*>> PruneSubgraph(
const tflite::SubGraphT& subgraph, ArrayRef<int32_t> output_indices) {
// Create a map from tensor index to defining op.
absl::flat_hash_map<int32_t, const tflite::OperatorT*> defining_op;
for (const auto& op : subgraph.operators) {
for (int32_t output : op->outputs) {
defining_op[output] = op.get();
}
}
std::vector<const tflite::OperatorT*> queue;
for (int32_t output : output_indices) {
if (auto& op = defining_op[output]) {
queue.push_back(op);
} else {
return errors::InvalidArgument("Output tensor doesn't have defining op");
}
}
// Traverse the graph towards inputs.
absl::flat_hash_set<const tflite::OperatorT*> visited;
while (!queue.empty()) {
const tflite::OperatorT* op = queue.back();
queue.pop_back();
if (!visited.insert(op).second) {
// The node has already been visited.
continue;
}
for (int32_t input : op->inputs) {
// Input tensor may not have a defining op in case it is a subgraph input
// or a constant tensor.
if (auto& op = defining_op[input]) {
queue.push_back(op);
}
}
}
return visited;
}
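PruneSubgraph above is a standard reverse-reachability walk: seed a worklist with the ops defining the outputs, then repeatedly expand along input edges until a fixpoint. The same algorithm on plain STL containers (toy types, not the TFLite schema):

#include <unordered_map>
#include <unordered_set>
#include <vector>

struct ToyOp { std::vector<int> inputs, outputs; };  // tensor ids

std::unordered_set<const ToyOp*> ReachableFromOutputs(
    const std::vector<ToyOp>& ops, const std::vector<int>& output_ids) {
  // Map each tensor id to the op that produces it.
  std::unordered_map<int, const ToyOp*> defining_op;
  for (const ToyOp& op : ops)
    for (int out : op.outputs) defining_op[out] = &op;
  std::vector<const ToyOp*> queue;
  for (int id : output_ids)
    if (const ToyOp* op = defining_op[id]) queue.push_back(op);
  std::unordered_set<const ToyOp*> visited;
  while (!queue.empty()) {
    const ToyOp* op = queue.back();
    queue.pop_back();
    if (!visited.insert(op).second) continue;  // already expanded
    for (int in : op->inputs)
      if (const ToyOp* def = defining_op[in]) queue.push_back(def);
  }
  return visited;
}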
// Build a FuncOp from a tflite SubGraph
// The op_names are a mapping from indexes into the TFLite operators array to
// the operator name MLIR expects (tfl.foo_op). The buffers are directly taken
@ -635,7 +720,8 @@ StatusOr<FuncOp> ConvertSubgraph(
const std::vector<std::unique_ptr<tflite::BufferT>>& buffers, const std::vector<std::unique_ptr<tflite::BufferT>>& buffers,
Location base_loc, Builder builder, Location base_loc, Builder builder,
const std::vector<std::string>& ordered_output_arrays, bool is_entry_point, const std::vector<std::string>& ordered_output_arrays, bool is_entry_point,
bool use_external_constant) { bool use_external_constant,
bool experimental_prune_unreachable_nodes_unconditionally) {
llvm::SmallVector<mlir::Type, 2> ret_types; llvm::SmallVector<mlir::Type, 2> ret_types;
llvm::SmallVector<mlir::Type, 4> input_types; llvm::SmallVector<mlir::Type, 4> input_types;
@ -731,8 +817,19 @@ StatusOr<FuncOp> ConvertSubgraph(
func.setAttr("tf.entry_function", builder.getDictionaryAttr(attributes)); func.setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
} }
absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;
if (experimental_prune_unreachable_nodes_unconditionally) {
TF_ASSIGN_OR_RETURN(pruned_subgraph_ops,
PruneSubgraph(subgraph, func_outputs));
}
// Construct MLIR operators from TFLite operators // Construct MLIR operators from TFLite operators
for (auto& op : subgraph.operators) { for (auto& op : subgraph.operators) {
if (experimental_prune_unreachable_nodes_unconditionally &&
!pruned_subgraph_ops.contains(op)) {
continue;
}
for (auto input_num : op->inputs) { for (auto input_num : op->inputs) {
// The operators in a graph are topologically sorted // The operators in a graph are topologically sorted
// and so if no previous operation has produced a tensor // and so if no previous operation has produced a tensor
@ -822,22 +919,21 @@ StatusOr<FuncOp> ConvertSubgraph(
// represents TFLite, this entry point must be called "main" // represents TFLite, this entry point must be called "main"
// TODO(b/131175224,b/132239787) Support multiple entry points // TODO(b/131175224,b/132239787) Support multiple entry points
std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) { std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) {
if (subgraph.name.empty()) { if (index == 0) {
if (index == 0) { return "main";
return "main";
} else {
return llvm::formatv("fn_{0}", index).str();
}
} else {
return subgraph.name;
} }
if (subgraph.name.empty()) {
return llvm::formatv("fn_{0}", index).str();
}
return subgraph.name;
} }
} // namespace } // namespace
OwningModuleRef tflite::FlatBufferToMlir( OwningModuleRef tflite::FlatBufferToMlir(
absl::string_view buffer, MLIRContext* context, Location base_loc, absl::string_view buffer, MLIRContext* context, Location base_loc,
const std::vector<std::string>& ordered_output_arrays, const std::vector<std::string>& ordered_output_arrays,
bool use_external_constant) { bool use_external_constant,
bool experimental_prune_unreachable_nodes_unconditionally) {
auto model_ptr = auto model_ptr =
FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length()); FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length());
if (nullptr == model_ptr) { if (nullptr == model_ptr) {
@ -892,7 +988,8 @@ OwningModuleRef tflite::FlatBufferToMlir(
// TODO(b/131175224,b/132239787) Support multiple entry points // TODO(b/131175224,b/132239787) Support multiple entry points
builder, ordered_output_arrays, builder, ordered_output_arrays,
/*is_entry_point=*/e.index() == 0, /*is_entry_point=*/e.index() == 0,
/*use_external_constant=*/use_external_constant); /*use_external_constant=*/use_external_constant,
experimental_prune_unreachable_nodes_unconditionally);
if (!func_or_error.ok()) { if (!func_or_error.ok()) {
return emitError(base_loc, "could not translate function ") return emitError(base_loc, "could not translate function ")
<< subgraph->name, << subgraph->name,
@ -905,9 +1002,10 @@ OwningModuleRef tflite::FlatBufferToMlir(
return OwningModuleRef(module); return OwningModuleRef(module);
} }
static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr, static OwningModuleRef FlatBufferFileToMlirTrans(
MLIRContext* context, llvm::SourceMgr* source_mgr, MLIRContext* context,
bool use_external_constant) { bool use_external_constant,
bool experimental_prune_unreachable_nodes_unconditionally) {
const llvm::MemoryBuffer* input = const llvm::MemoryBuffer* input =
source_mgr->getMemoryBuffer(source_mgr->getMainFileID()); source_mgr->getMemoryBuffer(source_mgr->getMainFileID());
std::string error; std::string error;
@ -924,12 +1022,14 @@ static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr,
return tflite::FlatBufferToMlir( return tflite::FlatBufferToMlir(
absl::string_view(input->getBufferStart(), input->getBufferSize()), absl::string_view(input->getBufferStart(), input->getBufferSize()),
context, loc, outputs, use_external_constant); context, loc, outputs, use_external_constant,
experimental_prune_unreachable_nodes_unconditionally);
} }
static mlir::TranslateToMLIRRegistration FlatBufferFileToMlirTransReg( static mlir::TranslateToMLIRRegistration FlatBufferFileToMlirTransReg(
"tflite-flatbuffer-to-mlir", "tflite-flatbuffer-to-mlir",
[](llvm::SourceMgr& source_mgr, MLIRContext* context) { [](llvm::SourceMgr& source_mgr, MLIRContext* context) {
return FlatBufferFileToMlirTrans(&source_mgr, context, return FlatBufferFileToMlirTrans(
use_external_constant); &source_mgr, context, use_external_constant,
experimental_prune_unreachable_nodes_unconditionally);
}); });

View File

@ -31,11 +31,14 @@ namespace tflite {
// on failure, and more specific errors will be emitted via the context. // on failure, and more specific errors will be emitted via the context.
// If `use_external_constant` is true, it will create `tfl.external_const` // If `use_external_constant` is true, it will create `tfl.external_const`
// instead of `tfl.const`. // instead of `tfl.const`.
// If `experimental_prune_unreachable_nodes_unconditionally` is true, nodes that
// are not ancestors of the output nodes will be pruned.
mlir::OwningModuleRef FlatBufferToMlir( mlir::OwningModuleRef FlatBufferToMlir(
absl::string_view buffer, mlir::MLIRContext* context, absl::string_view buffer, mlir::MLIRContext* context,
mlir::Location base_loc, mlir::Location base_loc,
const std::vector<std::string>& ordered_output_arrays, const std::vector<std::string>& ordered_output_arrays,
bool use_external_constant = false); bool use_external_constant = false,
bool experimental_prune_unreachable_nodes_unconditionally = false);
} // namespace tflite } // namespace tflite
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_IMPORT_H_ #endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_IMPORT_H_
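A quick usage sketch for the extended entry point above (illustrative, not part of the change: the ImportPruned helper and the file-reading boilerplate are assumptions; only the FlatBufferToMlir signature comes from this header):

#include <fstream>
#include <sstream>
#include <string>

#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"

// Hypothetical caller: import a .tflite flatbuffer with unreachable ops pruned.
mlir::OwningModuleRef ImportPruned(const std::string& path,
                                   mlir::MLIRContext* context) {
  std::ifstream file(path, std::ios::binary);
  std::stringstream contents;
  contents << file.rdbuf();
  std::string bytes = contents.str();
  return tflite::FlatBufferToMlir(
      bytes, context, mlir::UnknownLoc::get(context),
      /*ordered_output_arrays=*/{},
      /*use_external_constant=*/false,
      /*experimental_prune_unreachable_nodes_unconditionally=*/true);
}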

View File

@ -17,6 +17,8 @@ limitations under the License.
#include <vector> #include <vector>
#include "absl/strings/str_cat.h"
#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/StringSwitch.h"
#include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project
@ -24,8 +26,36 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // TF:llvm-project #include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_generated.h"
namespace {
using ::tensorflow::Status;
using ::tensorflow::errors::InvalidArgument;
using ::xla::StatusOr;
StatusOr<mlir::StringAttr> GetPaddingAttr(TfLitePadding pad_params,
mlir::Builder builder,
mlir::Location loc) {
auto padding = tflite::Padding::Padding_VALID;
if (pad_params == TfLitePadding::kTfLitePaddingSame) {
padding = tflite::Padding_SAME;
} else if (pad_params == TfLitePadding::kTfLitePaddingValid) {
padding = tflite::Padding_VALID;
} else {
return InvalidArgument(
absl::StrCat("Invalid padding type", std::to_string(pad_params)));
}
const char* option_name = tflite::EnumNamePadding(padding);
return builder.getStringAttr(option_name);
}
} // namespace
// TODO(jpienaar): This is a placeholder. This should be done in more efficient // TODO(jpienaar): This is a placeholder. This should be done in more efficient
// way when part of the translation of module. // way when part of the translation of module.
static tflite::ActivationFunctionType ConvertTFL_AFAttrForOptionWriter( static tflite::ActivationFunctionType ConvertTFL_AFAttrForOptionWriter(
@ -212,5 +242,44 @@ static mlir::Attribute BuildTFL_PaddingAttr(tflite::Padding value,
return builder.getStringAttr(option_name); return builder.getStringAttr(option_name);
} }
Status mlir::CustomOptionsToAttributes(
const std::string& op_name, const std::vector<uint8_t>& custom_options,
mlir::Builder builder, mlir::Location loc,
llvm::SmallVectorImpl<mlir::NamedAttribute>* attributes) {
if (op_name == "tfl.max_pooling_with_argmax_2d" ||
op_name == "tfl.max_unpooling_2d") {
auto* pool_params =
reinterpret_cast<const TfLitePoolParams*>(custom_options.data());
TF_ASSIGN_OR_RETURN(auto padding_attribute,
GetPaddingAttr(pool_params->padding, builder, loc));
attributes->emplace_back(
builder.getNamedAttr("padding", padding_attribute));
attributes->emplace_back(builder.getNamedAttr(
"stride_h", builder.getI32IntegerAttr(pool_params->stride_height)));
attributes->emplace_back(builder.getNamedAttr(
"stride_w", builder.getI32IntegerAttr(pool_params->stride_width)));
attributes->emplace_back(builder.getNamedAttr(
"filter_h", builder.getI32IntegerAttr(pool_params->filter_height)));
attributes->emplace_back(builder.getNamedAttr(
"filter_w", builder.getI32IntegerAttr(pool_params->filter_width)));
return Status::OK();
} else if (op_name == "tfl.convolution_2d_transpose_bias") {
auto* conv_params = reinterpret_cast<const TfLiteTransposeConvParams*>(
custom_options.data());
TF_ASSIGN_OR_RETURN(auto padding_attribute,
GetPaddingAttr(conv_params->padding, builder, loc));
attributes->emplace_back(
builder.getNamedAttr("padding", padding_attribute));
attributes->emplace_back(builder.getNamedAttr(
"stride_h", builder.getI32IntegerAttr(conv_params->stride_height)));
attributes->emplace_back(builder.getNamedAttr(
"stride_w", builder.getI32IntegerAttr(conv_params->stride_width)));
return Status::OK();
}
return InvalidArgument(absl::StrCat("invalid custom op type: ", op_name));
}
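For reference, the byte payload this function decodes is just the raw POD param struct. A minimal sketch of producing such a payload, mirroring the memcpy-based BuildCustomOperator added to the exporter later in this change (PackPoolParams is an illustrative name, and the builtin_op_data.h include is an assumption about where TfLitePoolParams is defined):

#include <cstdint>
#include <cstring>
#include <vector>

#include "tensorflow/lite/c/builtin_op_data.h"  // TfLitePoolParams

// Serializes pooling params into the custom_options byte vector that
// CustomOptionsToAttributes above reinterpret_casts back into the struct.
std::vector<uint8_t> PackPoolParams(const TfLitePoolParams& params) {
  std::vector<uint8_t> bytes(sizeof(TfLitePoolParams));
  std::memcpy(bytes.data(), &params, sizeof(TfLitePoolParams));
  return bytes;
}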
// Pull in FlatBuffer writers for TFLite generated using TableGen // Pull in FlatBuffer writers for TFLite generated using TableGen
#include "tensorflow/compiler/mlir/lite/operator_converters.inc" #include "tensorflow/compiler/mlir/lite/operator_converters.inc"

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project #include "mlir/IR/Operation.h" // TF:llvm-project
#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_generated.h"
namespace mlir { namespace mlir {
@ -45,7 +46,7 @@ llvm::Optional<flatbuffers::Offset<tflite::Operator>> CreateFlatBufferOperator(
const std::vector<int32_t> &operands, const std::vector<int32_t> &results, const std::vector<int32_t> &operands, const std::vector<int32_t> &results,
flatbuffers::FlatBufferBuilder *fbb); flatbuffers::FlatBufferBuilder *fbb);
// Populate the array of mlir::NamedAttributes corresponding to the given // Populates the array of mlir::NamedAttributes corresponding to the given
// tflite::FlatbufferOptionsUnion. // tflite::FlatbufferOptionsUnion.
// We use an out parameter per LLVM convention // We use an out parameter per LLVM convention
void BuiltinOptionsToAttributes( void BuiltinOptionsToAttributes(
@ -53,6 +54,15 @@ void BuiltinOptionsToAttributes(
// NOLINTNEXTLINE // NOLINTNEXTLINE
llvm::SmallVectorImpl<mlir::NamedAttribute> &attributes); llvm::SmallVectorImpl<mlir::NamedAttribute> &attributes);
// Populates the array of mlir::NamedAttributes corresponding to the given
// custom_options.
// We use an out parameter per LLVM convention
tensorflow::Status CustomOptionsToAttributes(
const std::string &op_name, const std::vector<uint8_t> &custom_options,
mlir::Builder builder,
// NOLINTNEXTLINE
Location loc, llvm::SmallVectorImpl<mlir::NamedAttribute> *attributes);
} // namespace mlir } // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_OPERATOR_H_ #endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_OPERATOR_H_

View File

@ -71,6 +71,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h" #include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/string_util.h" #include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/tools/versioning/op_version.h" #include "tensorflow/lite/tools/versioning/op_version.h"
@ -89,6 +90,7 @@ using mlir::MLIRContext;
using mlir::ModuleOp; using mlir::ModuleOp;
using mlir::NoneType; using mlir::NoneType;
using mlir::Operation; using mlir::Operation;
using mlir::Region;
using mlir::StringAttr; using mlir::StringAttr;
using mlir::TensorType; using mlir::TensorType;
using mlir::TranslateFromMLIRRegistration; using mlir::TranslateFromMLIRRegistration;
@ -218,6 +220,13 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
auto qtype = type.cast<mlir::quant::UniformQuantizedPerAxisType>(); auto qtype = type.cast<mlir::quant::UniformQuantizedPerAxisType>();
return GetTFLiteType(qtype.getStorageType(), qtype.isSigned()); return GetTFLiteType(qtype.getStorageType(), qtype.isSigned());
} }
case mlir::TF::TensorFlowTypes::RESOURCE: {
// Treat tf.resource values as integer values in flatbuffer.
// TODO(b/146131919): Maybe need to have a detailed design for supporting
// other resource types beyond hash table resources and resource
// variables.
return tflite::TensorType_INT32;
}
default: default:
// TFLite export fills FLOAT32 for unknown data types. Returning an error // TFLite export fills FLOAT32 for unknown data types. Returning an error
// for now for safety and this could be revisited when required. // for now for safety and this could be revisited when required.
@ -233,17 +242,17 @@ static bool IsConst(Operation* op) {
template <typename T> template <typename T>
static bool HasValidTFLiteType(Value value, T& error_handler) { static bool HasValidTFLiteType(Value value, T& error_handler) {
// None type is allowed to represent unspecified operands. // None type is allowed to represent unspecified operands.
if (value->getType().isa<NoneType>()) return true; if (value.getType().isa<NoneType>()) return true;
auto type = value->getType().dyn_cast<TensorType>(); auto type = value.getType().dyn_cast<TensorType>();
if (!type) { if (!type) {
if (auto op = value->getDefiningOp()) { if (auto op = value.getDefiningOp()) {
error_handler.emitError() error_handler.emitError()
<< '\'' << op << "' should produce value of tensor type instead of " << '\'' << op << "' should produce value of tensor type instead of "
<< value->getType(); << value.getType();
return false; return false;
} }
error_handler.emitError("expected tensor type, got ") << value->getType(); error_handler.emitError("expected tensor type, got ") << value.getType();
return false; return false;
} }
@ -282,7 +291,7 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) {
for (auto arg : bb.getArguments()) { for (auto arg : bb.getArguments()) {
if (!HasValidTFLiteType(arg, fn)) if (!HasValidTFLiteType(arg, fn))
return fn.emitError("invalid TFLite type: ") << arg->getType(), false; return fn.emitError("invalid TFLite type: ") << arg.getType(), false;
} }
// Verify that all operations except the terminator have exactly one // Verify that all operations except the terminator have exactly one
@ -292,7 +301,7 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) {
for (auto result : inst.getResults()) { for (auto result : inst.getResults()) {
if (!HasValidTFLiteType(result, inst)) if (!HasValidTFLiteType(result, inst))
return fn.emitError("invalid TFLite type: ") << result->getType(), return fn.emitError("invalid TFLite type: ") << result.getType(),
false; false;
} }
} }
@ -301,7 +310,7 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) {
return true; return true;
} }
static std::unique_ptr<::tensorflow::NodeDef> getTensorFlowNodeDef( static std::unique_ptr<::tensorflow::NodeDef> GetTensorFlowNodeDef(
::mlir::Operation* inst) { ::mlir::Operation* inst) {
// We pass empty string for the original node_def name since Flex runtime // We pass empty string for the original node_def name since Flex runtime
// does not care about this being set correctly on node_def. There is no // does not care about this being set correctly on node_def. There is no
@ -317,6 +326,48 @@ static std::unique_ptr<::tensorflow::NodeDef> getTensorFlowNodeDef(
return std::move(status_or_node_def.ValueOrDie()); return std::move(status_or_node_def.ValueOrDie());
} }
// Converts a mlir padding StringRef to TfLitePadding.
// Returns llvm::None if conversion fails.
static Optional<TfLitePadding> GetTflitePadding(Operation* inst,
llvm::StringRef padding) {
const tflite::Padding padding_attr =
std::move(llvm::StringSwitch<tflite::Padding>(padding)
.Case("SAME", tflite::Padding_SAME)
.Case("VALID", tflite::Padding_VALID));
if (padding_attr == tflite::Padding_SAME) {
return kTfLitePaddingSame;
}
if (padding_attr == tflite::Padding_VALID) {
return kTfLitePaddingValid;
}
return inst->emitOpError() << "Invalid padding attribute: " << padding,
llvm::None;
}
// Extracts TfLitePoolParams from a TFL custom op.
// Template parameter, TFLOp, should be a TFL custom op containing attributes
// generated from TfLitePoolParams.
// Returns llvm::None if conversion fails.
template <typename TFLOp>
static Optional<TfLitePoolParams> GetTflitePoolParams(Operation* inst,
TFLOp op) {
TfLitePoolParams pool_params;
pool_params.stride_height = op.stride_h().getSExtValue();
pool_params.stride_width = op.stride_w().getSExtValue();
pool_params.filter_height = op.filter_h().getSExtValue();
pool_params.filter_width = op.filter_w().getSExtValue();
const auto padding = GetTflitePadding(inst, op.padding());
if (padding) {
pool_params.padding = *padding;
pool_params.activation = kTfLiteActNone;
pool_params.computed.padding = TfLitePaddingValues{0, 0, 0, 0};
return pool_params;
}
return llvm::None;
}
namespace { namespace {
// Translates an MLIR module in TFLite dialect to TFLite FlatBuffer. // Translates an MLIR module in TFLite dialect to TFLite FlatBuffer.
@ -375,9 +426,36 @@ class Translator {
mlir::TF::WhileOp op, const std::vector<int32_t>& operands, mlir::TF::WhileOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results); const std::vector<int32_t>& results);
// Build while operator where cond & body are regions.
BufferOffset<tflite::Operator> BuildWhileOperator(
mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
// Builds custom operators.
// Templated on a) data type of custom_option to be stored into flatbuffer,
// and b) TFL custom op type.
template <typename CustomOptionType, typename TFLOp>
BufferOffset<tflite::Operator> BuildCustomOperator(
const CustomOptionType& custom_option, const std::string& opcode_name,
TFLOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
BufferOffset<tflite::Operator> BuildNumericVerifyOperator( BufferOffset<tflite::Operator> BuildNumericVerifyOperator(
mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands, mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results); const std::vector<int32_t>& results);
Optional<BufferOffset<tflite::Operator>>
BuildConvolution2DTransposeBiasOperator(
Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op,
const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
Optional<BufferOffset<tflite::Operator>> BuildMaxPoolingWithArgMax2DOperator(
Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op,
const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
Optional<BufferOffset<tflite::Operator>> BuildMaxUnpooling2DOperator(
Operation* inst, mlir::TFL::MaxUnpooling2DOp op,
const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
Optional<CustomOptionsOffset> CreateFlexOpCustomOptions( Optional<CustomOptionsOffset> CreateFlexOpCustomOptions(
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc); const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);
@ -400,7 +478,10 @@ class Translator {
Operation* inst, const std::vector<int32_t>& operands, Operation* inst, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results); const std::vector<int32_t>& results);
Optional<BufferOffset<tflite::SubGraph>> BuildSubGraph(FuncOp fn); // Builds a subgraph with the given name out of the region corresponding
// either to a function's body or to a while op.
Optional<BufferOffset<tflite::SubGraph>> BuildSubGraph(
const std::string& name, Region* region);
// Builds Metadata with the given `name` and buffer `content`. // Builds Metadata with the given `name` and buffer `content`.
BufferOffset<tflite::Metadata> BuildMetadata(StringRef name, BufferOffset<tflite::Metadata> BuildMetadata(StringRef name,
@ -422,6 +503,12 @@ class Translator {
// Returns a unique name for `val`. // Returns a unique name for `val`.
std::string UniqueName(mlir::Value val); std::string UniqueName(mlir::Value val);
// Returns the names of the subgraphs corresponding to the regions of the op. The
// names are supposed to be unique as the op name is unique and the suffix is
// not a valid name.
std::string GetWhileBodyName(mlir::TFL::WhileOp while_op);
std::string GetWhileCondName(mlir::TFL::WhileOp while_op);
ModuleOp module_; ModuleOp module_;
tensorflow::OpOrArgNameMapper& name_mapper_; tensorflow::OpOrArgNameMapper& name_mapper_;
@ -451,7 +538,7 @@ class Translator {
}; };
std::string Translator::UniqueName(mlir::Value val) { std::string Translator::UniqueName(mlir::Value val) {
return name_mapper_.GetUniqueName(val); return std::string(name_mapper_.GetUniqueName(val));
} }
Optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer( Optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
@ -504,7 +591,7 @@ Optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor( Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
Value value, const std::string& name, unsigned buffer_idx) { Value value, const std::string& name, unsigned buffer_idx) {
auto type = value->getType().cast<TensorType>(); auto type = value.getType().cast<TensorType>();
// TFLite requires tensor shape only for the inputs and constants. // TFLite requires tensor shape only for the inputs and constants.
// However, we output all known shapes for better round-tripping // However, we output all known shapes for better round-tripping
@ -516,19 +603,20 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
if (std::any_of(shape_ref.begin(), shape_ref.end(), is_out_of_range)) if (std::any_of(shape_ref.begin(), shape_ref.end(), is_out_of_range))
return mlir::emitError( return mlir::emitError(
value->getLoc(), value.getLoc(),
"result shape dimensions out of 32 bit int type range"); "result shape dimensions out of 32 bit int type range");
return mlir::success(); return mlir::success();
}; };
std::vector<int32_t> shape; std::vector<int32_t> shape;
std::vector<int32_t> shape_signature;
if (type.hasStaticShape()) { if (type.hasStaticShape()) {
llvm::ArrayRef<int64_t> shape_ref = type.getShape(); llvm::ArrayRef<int64_t> shape_ref = type.getShape();
if (mlir::failed(check_shape(shape_ref))) return llvm::None; if (mlir::failed(check_shape(shape_ref))) return llvm::None;
shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end()); shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
} else if (auto* inst = value->getDefiningOp()) { } else if (auto* inst = value.getDefiningOp()) {
if (IsConst(inst)) { if (IsConst(inst)) {
// Const op can have a result of dynamic shaped type (e.g. due to constant // Const op can have a result of dynamic shaped type (e.g. due to constant
// folding), but we can still derive the shape of a constant tensor for // folding), but we can still derive the shape of a constant tensor for
@ -540,7 +628,17 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end()); shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
} }
} else if (type.hasRank()) {
llvm::ArrayRef<int64_t> shape_ref = type.getShape();
if (mlir::failed(check_shape(shape_ref))) return llvm::None;
shape.reserve(shape_ref.size());
for (auto& dim : shape_ref) {
shape.push_back(dim == -1 ? 1 : dim);
}
shape_signature = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
} }
Type element_type = type.getElementType(); Type element_type = type.getElementType();
tflite::TensorType tflite_element_type = tflite::TensorType tflite_element_type =
GetTFLiteType(type.getElementType()).ValueOrDie(); GetTFLiteType(type.getElementType()).ValueOrDie();
@ -571,16 +669,25 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
// marked as a stateful. If so, set the tensor's is_variable as true // marked as a stateful. If so, set the tensor's is_variable as true
// This is v1 ref variable semantics in the TFLite runtime. // This is v1 ref variable semantics in the TFLite runtime.
bool is_variable = false; bool is_variable = false;
for (auto& use : value->getUses()) { for (auto& use : value.getUses()) {
is_variable = IsStatefulOperand(use.getOwner(), use.getOperandNumber()); is_variable = IsStatefulOperand(use.getOwner(), use.getOperandNumber());
if (is_variable) { if (is_variable) {
break; break;
} }
} }
return tflite::CreateTensor(
builder_, builder_.CreateVector(shape), tflite_element_type, if (shape_signature.empty()) {
(is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params, return tflite::CreateTensor(
/*is_variable=*/is_variable); builder_, builder_.CreateVector(shape), tflite_element_type,
(is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
/*is_variable=*/is_variable);
} else {
return tflite::CreateTensor(
builder_, builder_.CreateVector(shape), tflite_element_type,
(is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
/*is_variable=*/is_variable, /*sparsity=*/0,
/*shape_signature=*/builder_.CreateVector(shape_signature));
}
} }
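Worked example of the new encoding: a ranked dynamic tensor of type tensor<?x4xf32> is written out with shape = {1, 4} (each dynamic dimension clamped to 1) and shape_signature = {-1, 4}; fully static tensors take the first branch and keep the original encoding with no shape_signature field.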
BufferOffset<tflite::Operator> Translator::BuildIfOperator( BufferOffset<tflite::Operator> Translator::BuildIfOperator(
@ -615,19 +722,96 @@ BufferOffset<tflite::Operator> Translator::BuildWhileOperator(
builtin_options); builtin_options);
} }
std::string Translator::GetWhileBodyName(mlir::TFL::WhileOp while_op) {
return (name_mapper_.GetUniqueName(while_op.getOperation()) + "$body").str();
}
std::string Translator::GetWhileCondName(mlir::TFL::WhileOp while_op) {
return (name_mapper_.GetUniqueName(while_op.getOperation()) + "$cond").str();
}
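For example (the name is mapper-dependent, shown here only for illustration), a while op uniquified as "while0" yields subgraph names "while0$cond" and "while0$body"; BuildWhileOperator below looks both up in subgraph_index_map_ to fill in WhileOptions.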
BufferOffset<tflite::Operator> Translator::BuildWhileOperator(
mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results) {
auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE);
int body_subgraph_index = subgraph_index_map_.at(GetWhileBodyName(op));
int cond_subgraph_index = subgraph_index_map_.at(GetWhileCondName(op));
auto builtin_options = tflite::CreateWhileOptions(
builder_, cond_subgraph_index, body_subgraph_index)
.Union();
auto inputs = builder_.CreateVector(operands);
auto outputs = builder_.CreateVector(results);
return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
tflite::BuiltinOptions_WhileOptions,
builtin_options);
}
template <typename CustomOptionType, typename TFLOp>
BufferOffset<tflite::Operator> Translator::BuildCustomOperator(
const CustomOptionType& custom_option, const std::string& opcode_name,
TFLOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results) {
std::vector<uint8_t> custom_option_vector(sizeof(CustomOptionType));
memcpy(custom_option_vector.data(), &custom_option, sizeof(CustomOptionType));
auto opcode_index =
GetOpcodeIndex(opcode_name, tflite::BuiltinOperator_CUSTOM);
return tflite::CreateOperator(
builder_, opcode_index, builder_.CreateVector(operands),
builder_.CreateVector(results), tflite::BuiltinOptions_NONE,
/*builtin_options=*/0,
builder_.CreateVector<uint8_t>(custom_option_vector),
tflite::CustomOptionsFormat_FLEXBUFFERS);
}
BufferOffset<tflite::Operator> Translator::BuildNumericVerifyOperator( BufferOffset<tflite::Operator> Translator::BuildNumericVerifyOperator(
mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands, mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results) { const std::vector<int32_t>& results) {
float tolerance = op.tolerance().convertToFloat(); float tolerance = op.tolerance().convertToFloat();
std::vector<uint8_t> custom_options(sizeof(float)); return BuildCustomOperator(tolerance, "NumericVerify", op, operands, results);
memcpy(custom_options.data(), &tolerance, sizeof(float)); }
auto opcode_index =
GetOpcodeIndex("NumericVerify", tflite::BuiltinOperator_CUSTOM); Optional<BufferOffset<tflite::Operator>>
return tflite::CreateOperator( Translator::BuildConvolution2DTransposeBiasOperator(
builder_, opcode_index, builder_.CreateVector(operands), Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op,
builder_.CreateVector(results), tflite::BuiltinOptions_NONE, const std::vector<int32_t>& operands, const std::vector<int32_t>& results) {
/*builtin_options=*/0, builder_.CreateVector<uint8_t>(custom_options), TfLiteTransposeConvParams conv_params;
tflite::CustomOptionsFormat_FLEXBUFFERS); conv_params.stride_height = op.stride_h().getSExtValue();
conv_params.stride_width = op.stride_w().getSExtValue();
const auto padding = GetTflitePadding(inst, op.padding());
if (padding) {
conv_params.padding = *padding;
return BuildCustomOperator(conv_params, "Convolution2DTransposeBias", op,
operands, results);
}
return llvm::None;
}
Optional<BufferOffset<tflite::Operator>>
Translator::BuildMaxPoolingWithArgMax2DOperator(
Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op,
const std::vector<int32_t>& operands, const std::vector<int32_t>& results) {
const auto pool_params = GetTflitePoolParams(inst, op);
if (pool_params) {
return BuildCustomOperator(*pool_params, "MaxPoolingWithArgmax2D", op,
operands, results);
}
return llvm::None;
}
Optional<BufferOffset<tflite::Operator>>
Translator::BuildMaxUnpooling2DOperator(Operation* inst,
mlir::TFL::MaxUnpooling2DOp op,
const std::vector<int32_t>& operands,
const std::vector<int32_t>& results) {
const auto pool_params = GetTflitePoolParams(inst, op);
if (pool_params) {
return BuildCustomOperator(*pool_params, "MaxUnpooling2D", op, operands,
results);
}
return llvm::None;
} }
Optional<CustomOptionsOffset> Translator::CreateFlexOpCustomOptions( Optional<CustomOptionsOffset> Translator::CreateFlexOpCustomOptions(
@ -769,6 +953,24 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
if (auto verify_op = dyn_cast<mlir::TFL::NumericVerifyOp>(inst)) { if (auto verify_op = dyn_cast<mlir::TFL::NumericVerifyOp>(inst)) {
return BuildNumericVerifyOperator(verify_op, operands, results); return BuildNumericVerifyOperator(verify_op, operands, results);
} }
if (auto conv_transpose_bias_op =
dyn_cast<mlir::TFL::Convolution2DTransposeBiasOp>(inst)) {
return BuildConvolution2DTransposeBiasOperator(
inst, conv_transpose_bias_op, operands, results);
}
if (auto max_pooling_with_arg_max_op =
dyn_cast<mlir::TFL::MaxPoolingWithArgMax2DOp>(inst)) {
return BuildMaxPoolingWithArgMax2DOperator(
inst, max_pooling_with_arg_max_op, operands, results);
}
if (auto max_unpooling_op = dyn_cast<mlir::TFL::MaxUnpooling2DOp>(inst)) {
return BuildMaxUnpooling2DOperator(inst, max_unpooling_op, operands,
results);
}
if (auto whileOp = dyn_cast<mlir::TFL::WhileOp>(inst)) {
return BuildWhileOperator(whileOp, operands, results);
}
inst->emitOpError("is not a supported TFLite op"); inst->emitOpError("is not a supported TFLite op");
return llvm::None; return llvm::None;
} }
@ -805,7 +1007,7 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
// we emit op as flex. // we emit op as flex.
// if custom is enabled // if custom is enabled
// we emit the op as custom. // we emit the op as custom.
auto node_def = getTensorFlowNodeDef(inst); auto node_def = GetTensorFlowNodeDef(inst);
if (!node_def) { if (!node_def) {
return llvm::None; return llvm::None;
} }
@ -904,18 +1106,16 @@ void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) {
bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) { bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) {
std::vector<int> operand_indices; std::vector<int> operand_indices;
// TODO(b/138254427): When the bug is addressed, we'll be able to inspect
// for the presence of a specific OpTrait using mlir::Operation, without
// having to cast it to specific ops like below.
// Until then, when a new RNN/LSTM op is added to TFLite and has stateful
// tensors as operands, they will need to be added here as well.
if (!mlir::TFL::IsStatefulOp(op, &operand_indices)) return false; if (!mlir::TFL::IsStatefulOp(op, &operand_indices)) return false;
return absl::c_find(operand_indices, operand_index) != operand_indices.end(); return absl::c_find(operand_indices, operand_index) != operand_indices.end();
} }
Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) { Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
const std::string& name, Region* region) {
bool has_input_attr = false; bool has_input_attr = false;
InitializeNamesFromAttribute(fn, &has_input_attr); if (auto fn = dyn_cast<FuncOp>(region->getParentOp())) {
InitializeNamesFromAttribute(fn, &has_input_attr);
}
std::vector<BufferOffset<tflite::Tensor>> tensors; std::vector<BufferOffset<tflite::Tensor>> tensors;
llvm::DenseMap<Value, int> tensor_index_map; llvm::DenseMap<Value, int> tensor_index_map;
@ -923,7 +1123,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
// on failure. // on failure.
auto build_tensor_and_buffer = [&](Value value, const std::string& name) { auto build_tensor_and_buffer = [&](Value value, const std::string& name) {
// NoneType represents optional and may be skipped here. // NoneType represents optional and may be skipped here.
if (value->getType().isa<NoneType>()) { if (value.getType().isa<NoneType>()) {
return true; return true;
} }
@ -936,7 +1136,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
// make the Buffer empty apart from setting the buffer_idx=0 in the Tensor. // make the Buffer empty apart from setting the buffer_idx=0 in the Tensor.
// This does not seem to affect runtime behavior for RNN/LSTM, but would be // This does not seem to affect runtime behavior for RNN/LSTM, but would be
// good for reducing memory footprint. // good for reducing memory footprint.
if (auto* inst = value->getDefiningOp()) { if (auto* inst = value.getDefiningOp()) {
auto buffer_or = BuildBuffer(inst); auto buffer_or = BuildBuffer(inst);
if (!buffer_or) return false; if (!buffer_or) return false;
buffers_.push_back(*buffer_or); buffers_.push_back(*buffer_or);
@ -947,7 +1147,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
}; };
std::vector<BufferOffset<tflite::Operator>> operators; std::vector<BufferOffset<tflite::Operator>> operators;
auto& bb = fn.getBlocks().front(); auto& bb = region->front();
// Main function's arguments are first passed to `input` op so they don't // Main function's arguments are first passed to `input` op so they don't
// have associated tensor and buffer. Build FlatBuffer tensor and buffer for // have associated tensor and buffer. Build FlatBuffer tensor and buffer for
@ -955,7 +1155,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
for (unsigned i = 0, e = bb.getNumArguments(); i < e; ++i) { for (unsigned i = 0, e = bb.getNumArguments(); i < e; ++i) {
mlir::BlockArgument arg = bb.getArgument(i); mlir::BlockArgument arg = bb.getArgument(i);
std::string name; std::string name;
if (has_input_attr) name = name_mapper_.GetUniqueName(arg); if (has_input_attr) name = std::string(name_mapper_.GetUniqueName(arg));
if (name.empty()) name = absl::StrCat("arg", i); if (name.empty()) name = absl::StrCat("arg", i);
if (!build_tensor_and_buffer(arg, name)) return llvm::None; if (!build_tensor_and_buffer(arg, name)) return llvm::None;
} }
@ -976,7 +1176,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
std::vector<int32_t> operands; std::vector<int32_t> operands;
operands.reserve(inst.getNumOperands()); operands.reserve(inst.getNumOperands());
for (auto operand : inst.getOperands()) { for (auto operand : inst.getOperands()) {
if (operand->getType().isa<NoneType>()) if (operand.getType().isa<NoneType>())
operands.push_back(kTfLiteOptionalTensor); operands.push_back(kTfLiteOptionalTensor);
else else
operands.push_back(tensor_index_map.lookup(operand)); operands.push_back(tensor_index_map.lookup(operand));
@ -1007,7 +1207,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
return tflite::CreateSubGraph( return tflite::CreateSubGraph(
builder_, builder_.CreateVector(tensors), builder_.CreateVector(inputs), builder_, builder_.CreateVector(tensors), builder_.CreateVector(inputs),
builder_.CreateVector(outputs), builder_.CreateVector(operators), builder_.CreateVector(outputs), builder_.CreateVector(operators),
/*name=*/builder_.CreateString(fn.getName().str())); /*name=*/builder_.CreateString(name));
} }
BufferOffset<tflite::Metadata> Translator::BuildMetadata(StringRef name, BufferOffset<tflite::Metadata> Translator::BuildMetadata(StringRef name,
@ -1050,35 +1250,45 @@ Optional<std::string> Translator::Translate(
} }
Optional<std::string> Translator::TranslateInternal() { Optional<std::string> Translator::TranslateInternal() {
// Create a list of functions in the module with main function being the // A list of named regions in the module with main function being the first in
// first function in the list. This is required as the first subgraph in the // the list. The main function must come first, as the first subgraph in the model
// model is entry point for the model. // is the entry point for the model.
std::vector<FuncOp> functions; std::vector<std::pair<std::string, Region*>> named_regions;
functions.reserve(std::distance(module_.begin(), module_.end())); named_regions.reserve(std::distance(module_.begin(), module_.end()));
int subgraph_idx = 0; int subgraph_idx = 0;
FuncOp main_fn = module_.lookupSymbol<FuncOp>("main"); FuncOp main_fn = module_.lookupSymbol<FuncOp>("main");
subgraph_index_map_[main_fn.getName().str()] = subgraph_idx++; subgraph_index_map_[main_fn.getName().str()] = subgraph_idx++;
functions.push_back(main_fn); named_regions.emplace_back("main", &main_fn.getBody());
for (auto fn : module_.getOps<FuncOp>()) { // Walk over the module collection ops with functions and while ops.
if (fn == main_fn) continue; module_.walk([&](Operation* op) {
if (auto fn = dyn_cast<FuncOp>(op)) {
if (fn != main_fn) {
subgraph_index_map_[fn.getName().str()] = subgraph_idx++;
named_regions.emplace_back(fn.getName().str(), &fn.getBody());
}
} else if (auto wo = dyn_cast<mlir::TFL::WhileOp>(op)) {
std::string name = GetWhileCondName(wo);
subgraph_index_map_[name] = subgraph_idx++;
named_regions.emplace_back(GetWhileCondName(wo), &wo.cond());
name = GetWhileBodyName(wo);
subgraph_index_map_[name] = subgraph_idx++;
named_regions.emplace_back(name, &wo.body());
}
});
subgraph_index_map_[fn.getName().str()] = subgraph_idx++; // Build subgraph for each of the named regions.
functions.push_back(fn);
}
// Build subgraph for each of the functions.
std::vector<BufferOffset<tflite::SubGraph>> subgraphs; std::vector<BufferOffset<tflite::SubGraph>> subgraphs;
subgraphs.reserve(functions.size()); subgraphs.reserve(named_regions.size());
int first_failed_func = -1; int first_failed_func = -1;
for (int i = 0; i < functions.size(); ++i) { for (auto it : llvm::enumerate(named_regions)) {
auto subgraph_or = BuildSubGraph(functions[i]); auto subgraph_or = BuildSubGraph(it.value().first, it.value().second);
if (!subgraph_or) { if (!subgraph_or) {
if (first_failed_func == -1) if (first_failed_func == -1)
// Record the index of the first function that cannot be converted. // Record the index of the first region that cannot be converted.
// Keep looping through all subgraphs in the module to make sure that // Keep looping through all subgraphs in the module to make sure that
// we collect the list of missing ops from the entire module. // we collect the list of missing ops from the entire module.
first_failed_func = i; first_failed_func = it.index();
} else { } else {
subgraphs.push_back(*subgraph_or); subgraphs.push_back(*subgraph_or);
} }
@ -1099,9 +1309,10 @@ Optional<std::string> Translator::TranslateInternal() {
"-emit-custom-ops flag): " + "-emit-custom-ops flag): " +
failed_custom_ops_list; failed_custom_ops_list;
return functions[first_failed_func].emitError("failed while converting: '") auto& failed_region = named_regions[first_failed_func];
<< functions[first_failed_func].getName() << "\'\n" return failed_region.second->getParentOp()->emitError()
<< err, << "failed while converting: '" << failed_region.first
<< "': " << err,
llvm::None; llvm::None;
} }

View File

@ -0,0 +1,58 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the operation interface definition file for TensorFlow Lite.
#ifndef TFL_OP_INTERFACES
#define TFL_OP_INTERFACES
include "mlir/IR/OpBase.td"
//===----------------------------------------------------------------------===//
// TFL op interface for stateful operands.
def TFL_StatefulOp : OpInterface<"StatefulOpInterface"> {
let description = [{
Interface for ops that are stateful and need to identify stateful operands.
Stateful operands correspond to TF's variables semantics. An op that has 1
or more stateful operands is a stateful op.
}];
let methods = [
InterfaceMethod<
[{Returns the indices of stateful operands.}],
"std::vector<int>", "GetStatefulOperands", (ins)
>,
];
}
//===----------------------------------------------------------------------===//
// TFL op interface for output channel index.
def TFL_ChannelDimIndexInterface : OpInterface<"ChannelDimIndexInterface"> {
let description = [{
Interface for defining the dimension index of the output channels.
}];
let methods = [
InterfaceMethod<
[{Returns the dimension index of the output channels.}],
"int", "GetChannelDimIndex", (ins)
>,
];
}
#endif // TFL_OP_INTERFACES
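A hedged sketch of consuming the first interface from C++, assuming ODS generates a StatefulOpInterface class in the mlir::TFL namespace as the declarations above suggest (the HasStatefulOperand helper is illustrative, not part of this change):

#include "llvm/Support/Casting.h"
#include "mlir/IR/Operation.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

// Returns true if `operand_index` is stateful on ops implementing the
// TFL_StatefulOp interface; returns false for all other ops.
bool HasStatefulOperand(mlir::Operation* op, int operand_index) {
  if (auto stateful = llvm::dyn_cast<mlir::TFL::StatefulOpInterface>(op)) {
    for (int idx : stateful.GetStatefulOperands())
      if (idx == operand_index) return true;
  }
  return false;
}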

View File

@ -304,11 +304,11 @@ Attribute ConstFoldUnaryOp(Type result_type, Attribute operand,
void buildComparisonBinOp(Builder *builder, OperationState &result, Value lhs, void buildComparisonBinOp(Builder *builder, OperationState &result, Value lhs,
Value rhs) { Value rhs) {
auto result_type = auto result_type =
OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType()); OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
if (!result_type) if (!result_type)
emitError(result.location) emitError(result.location)
<< "non-broadcastable operands: " << lhs->getType() << " and " << "non-broadcastable operands: " << lhs.getType() << " and "
<< rhs->getType(); << rhs.getType();
result.addOperands({lhs, rhs}); result.addOperands({lhs, rhs});
// Comparison binary ops always return i1 tensor. // Comparison binary ops always return i1 tensor.
if (auto shaped_type = result_type.dyn_cast<RankedTensorType>()) { if (auto shaped_type = result_type.dyn_cast<RankedTensorType>()) {
@ -324,12 +324,12 @@ void buildFusedBroadcastableBinOp(Builder *builder, OperationState &result,
Value lhs, Value rhs, Value lhs, Value rhs,
StringAttr fused_activation_function) { StringAttr fused_activation_function) {
auto result_type = auto result_type =
OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType()); OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
if (!result_type) if (!result_type)
emitError(result.location) emitError(result.location)
<< "non-broadcastable operands: " << lhs->getType() << " and " << "non-broadcastable operands: " << lhs.getType() << " and "
<< rhs->getType(); << rhs.getType();
result.addOperands({lhs, rhs}); result.addOperands({lhs, rhs});
result.addAttribute("fused_activation_function", fused_activation_function); result.addAttribute("fused_activation_function", fused_activation_function);
@ -358,7 +358,7 @@ OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
namespace { namespace {
int64_t GetConcatenationOpAxis(ConcatenationOp op) { int64_t GetConcatenationOpAxis(ConcatenationOp op) {
auto output_type = op.output()->getType().cast<RankedTensorType>(); auto output_type = op.output().getType().cast<RankedTensorType>();
int64_t axis = op.axis().getSExtValue(); int64_t axis = op.axis().getSExtValue();
if (axis < 0) axis += output_type.getRank(); if (axis < 0) axis += output_type.getRank();
return axis; return axis;
@ -452,7 +452,7 @@ LogicalResult VerifyConcatenationOpTypes(Operation *op,
} }
LogicalResult Verify(ConcatenationOp op) { LogicalResult Verify(ConcatenationOp op) {
auto output_type = op.output()->getType().dyn_cast<RankedTensorType>(); auto output_type = op.output().getType().dyn_cast<RankedTensorType>();
// If the output type is unranked, there is nothing else to be verified. // If the output type is unranked, there is nothing else to be verified.
if (!output_type) return success(); if (!output_type) return success();
@ -463,7 +463,7 @@ LogicalResult Verify(ConcatenationOp op) {
SmallVector<TensorType, 4> operand_types; SmallVector<TensorType, 4> operand_types;
for (Value operand : op.values()) for (Value operand : op.values())
operand_types.push_back(operand->getType().cast<TensorType>()); operand_types.push_back(operand.getType().cast<TensorType>());
return VerifyConcatenationOpTypes(op.getOperation(), output_type, return VerifyConcatenationOpTypes(op.getOperation(), output_type,
operand_types, axis); operand_types, axis);
@ -520,7 +520,7 @@ DenseElementsAttr ConstFoldConcatenateOpDense(ArrayRef<Attribute> operands,
OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) { OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
if (fused_activation_function() == "NONE") { if (fused_activation_function() == "NONE") {
if (auto output_type = output()->getType().dyn_cast<RankedTensorType>()) { if (auto output_type = output().getType().dyn_cast<RankedTensorType>()) {
const int64_t axis = GetConcatenationOpAxis(*this); const int64_t axis = GetConcatenationOpAxis(*this);
if (IsConcatenationOpConstFoldable(*this, operands, output_type, axis)) if (IsConcatenationOpConstFoldable(*this, operands, output_type, axis))
return ConstFoldConcatenateOpDense(operands, output_type, axis); return ConstFoldConcatenateOpDense(operands, output_type, axis);
@ -530,7 +530,7 @@ OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
// Remove all empty values. // Remove all empty values.
SmallVector<Value, 4> non_empty_values; SmallVector<Value, 4> non_empty_values;
for (Value value : this->values()) { for (Value value : this->values()) {
const auto shaped_type = value->getType().cast<ShapedType>(); const auto shaped_type = value.getType().cast<ShapedType>();
if (shaped_type.hasStaticShape() && shaped_type.getNumElements() == 0) { if (shaped_type.hasStaticShape() && shaped_type.getNumElements() == 0) {
continue; continue;
} }
@ -559,8 +559,8 @@ OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
LogicalResult Verify(FullyConnectedOp op) { LogicalResult Verify(FullyConnectedOp op) {
ShapedType input_type = op.input()->getType().cast<ShapedType>(); ShapedType input_type = op.input().getType().cast<ShapedType>();
ShapedType filter_type = op.filter()->getType().cast<ShapedType>(); ShapedType filter_type = op.filter().getType().cast<ShapedType>();
if (filter_type.hasRank() && filter_type.getRank() != 2) { if (filter_type.hasRank() && filter_type.getRank() != 2) {
return op.emitOpError("expect 2d filter, got ") << filter_type; return op.emitOpError("expect 2d filter, got ") << filter_type;
} }
@ -582,7 +582,7 @@ LogicalResult Verify(FullyConnectedOp op) {
// format. // format.
if (op.weights_format() == "DEFAULT") { if (op.weights_format() == "DEFAULT") {
ShapedType output_type = ShapedType output_type =
(*op.output().begin())->getType().cast<ShapedType>(); (*op.output().begin()).getType().cast<ShapedType>();
if (!output_type.hasStaticShape()) { if (!output_type.hasStaticShape()) {
return mlir::success(); return mlir::success();
} }
@ -610,8 +610,8 @@ LogicalResult Verify(FullyConnectedOp op) {
static void BuildGatherOp(Builder *builder, OperationState &result, static void BuildGatherOp(Builder *builder, OperationState &result,
Value params, Value indices, IntegerAttr axis) { Value params, Value indices, IntegerAttr axis) {
auto params_type = params->getType().cast<TensorType>(); auto params_type = params.getType().cast<TensorType>();
auto indices_type = indices->getType().cast<TensorType>(); auto indices_type = indices.getType().cast<TensorType>();
// If params/indices is unranked, then output is unranked. // If params/indices is unranked, then output is unranked.
if (!params_type.hasRank() || !indices_type.hasRank()) if (!params_type.hasRank() || !indices_type.hasRank())
@ -705,7 +705,7 @@ static LogicalResult Verify(PackOp op) {
return op.emitOpError("input count should match 'values_count' attribute"); return op.emitOpError("input count should match 'values_count' attribute");
Value operand0 = op.getOperand(0); Value operand0 = op.getOperand(0);
auto input_type = operand0->getType().cast<ShapedType>(); auto input_type = operand0.getType().cast<ShapedType>();
// Check axis bounds. // Check axis bounds.
if (input_type.hasRank()) { if (input_type.hasRank()) {
@ -718,7 +718,7 @@ static LogicalResult Verify(PackOp op) {
// Make sure all inputs have the same shape and element type. // Make sure all inputs have the same shape and element type.
// TODO(rahulsp): Simplify once b/135032064 is fixed. // TODO(rahulsp): Simplify once b/135032064 is fixed.
for (Value operand : op.getOperands()) { for (Value operand : op.getOperands()) {
auto other_type = operand->getType().cast<ShapedType>(); auto other_type = operand.getType().cast<ShapedType>();
if (input_type != other_type) if (input_type != other_type)
return op.emitOpError("operands should be of the same type. got ") return op.emitOpError("operands should be of the same type. got ")
<< input_type << ", " << other_type; << input_type << ", " << other_type;
@ -732,9 +732,9 @@ static LogicalResult Verify(PackOp op) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static LogicalResult Verify(PReluOp op) { static LogicalResult Verify(PReluOp op) {
auto input_type = op.input()->getType().cast<ShapedType>(); auto input_type = op.input().getType().cast<ShapedType>();
auto alpha_type = op.alpha()->getType().cast<ShapedType>(); auto alpha_type = op.alpha().getType().cast<ShapedType>();
auto output_type = op.output()->getType().cast<ShapedType>(); auto output_type = op.output().getType().cast<ShapedType>();
if (input_type.hasStaticShape() && alpha_type.hasStaticShape()) { if (input_type.hasStaticShape() && alpha_type.hasStaticShape()) {
if (input_type.getRank() != alpha_type.getRank() + 1) { if (input_type.getRank() != alpha_type.getRank() + 1) {
@@ -783,13 +783,13 @@ struct RemoveAdjacentReshape : public RewritePattern {
   PatternMatchResult match(Operation *op) const override {
     auto thisOp = cast<ReshapeOp>(op);
-    auto prevOp = thisOp.getOperand(0)->getDefiningOp();
+    auto prevOp = thisOp.getOperand(0).getDefiningOp();
     return isa_and_nonnull<ReshapeOp>(prevOp) ? matchSuccess() : matchFailure();
   }

   void rewrite(Operation *op, PatternRewriter &rewriter) const override {
     auto thisOp = cast<ReshapeOp>(op);
-    auto prevOp = cast<ReshapeOp>(thisOp.getOperand(0)->getDefiningOp());
+    auto prevOp = cast<ReshapeOp>(thisOp.getOperand(0).getDefiningOp());
     // Replace
     //   %1 = "tfl.reshape"(%0, %shape0)
@@ -797,8 +797,7 @@ struct RemoveAdjacentReshape : public RewritePattern {
     // With
     //   %2 = "tfl.reshape"(%0, %shape1)
     rewriter.replaceOpWithNewOp<ReshapeOp>(
-        {prevOp.getResult()}, op, thisOp.getType(), prevOp.getOperand(0),
-        thisOp.getOperand(1));
+        op, thisOp.getType(), prevOp.getOperand(0), thisOp.getOperand(1));
   }
 };

@@ -807,7 +806,7 @@ struct RemoveAdjacentReshape : public RewritePattern {
 OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
   // Remove identity reshape with both static result and input shape.
   auto result_type = getType().cast<ShapedType>();
-  auto input_type = getOperand(0)->getType().cast<ShapedType>();
+  auto input_type = getOperand(0).getType().cast<ShapedType>();
   if (result_type.hasStaticShape() && result_type == input_type) {
     return getOperand(0);
   }

@@ -865,7 +864,7 @@ struct RemoveRedundantUnpackPack : public RewritePattern {
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const override {
     TFL::PackOp pack_op = cast<TFL::PackOp>(op);
-    Operation *first_input = pack_op.getOperand(0)->getDefiningOp();
+    Operation *first_input = pack_op.getOperand(0).getDefiningOp();
     if (!first_input) return matchFailure();
     auto input_unpack_op = dyn_cast_or_null<TFL::UnpackOp>(first_input);
     if (!input_unpack_op) return matchFailure();

@@ -905,9 +904,9 @@ void PackOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
 //===----------------------------------------------------------------------===//

 static LogicalResult Verify(SliceOp op) {
-  auto input_type = op.input()->getType().cast<ShapedType>();
-  auto begin_type = op.begin()->getType().cast<ShapedType>();
-  auto size_type = op.size()->getType().cast<ShapedType>();
+  auto input_type = op.input().getType().cast<ShapedType>();
+  auto begin_type = op.begin().getType().cast<ShapedType>();
+  auto size_type = op.size().getType().cast<ShapedType>();
   if (input_type.hasStaticShape() && begin_type.hasStaticShape() &&
       size_type.hasStaticShape()) {
     if (input_type.getRank() != begin_type.getNumElements()) {

@@ -995,7 +994,7 @@ static void BuildTopKOp(Builder *builder, OperationState &result, Value input,
     // TODO(jpienaar): This should use a helper function.
     const_k = cst.getValue<IntegerAttr>({}).getValue().getSExtValue();

-  auto val_type = input->getType().cast<TensorType>();
+  auto val_type = input.getType().cast<TensorType>();
   // If value is unranked, then so are the results.
   if (!val_type.hasRank())
     return TFL::TopKV2Op::build(

@@ -1035,7 +1034,7 @@ struct DropFakeQuant : public RewritePattern {
     // If all the users of this op have valid "minmax" attributes, it is matched
     // and can be removed.
     auto fakeQuantOp = cast<FakeQuantOp>(op);
-    for (auto *operand : fakeQuantOp.getResult()->getUsers())
+    for (auto *operand : fakeQuantOp.getResult().getUsers())
       if (!HasValidMinMaxAttribute(operand)) return matchFailure();

     return matchSuccess();

@@ -1102,7 +1101,7 @@ static LogicalResult VerifySplitOpOutputTypes(
   for (int64_t i = 0; i < num_splits; ++i) {
     auto expected_output_type = get_expected_output_type(i);
     Value output = op->getResult(i);
-    auto output_type = output->getType().dyn_cast<RankedTensorType>();
+    auto output_type = output.getType().dyn_cast<RankedTensorType>();
     if (!output_type || output_type != expected_output_type)
       return op->emitOpError()
              << "output #" << i << " should be " << expected_output_type;

@@ -1121,7 +1120,7 @@ static LogicalResult Verify(SplitOp op) {
   if (!split_dim_opt) return success();

   // If 'input' is not a ranked tensor, there are no other checks.
-  auto input_type = op.value()->getType().dyn_cast<RankedTensorType>();
+  auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
   if (!input_type) return success();

   int64_t split_dim = split_dim_opt.getValue();

@@ -1157,7 +1156,7 @@ static LogicalResult Verify(SplitVOp op) {
   if (!split_dim_opt) return success();

   // If 'input' is not a ranked tensor, there are no other checks.
-  auto input_type = op.value()->getType().dyn_cast<RankedTensorType>();
+  auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
   if (!input_type) return success();

   int64_t split_dim = split_dim_opt.getValue();

@@ -1177,8 +1176,7 @@ static LogicalResult Verify(SplitVOp op) {
     return success();

   if (size_splits_attr.getNumElements() != num_splits) {
-    auto size_splits_type =
-        op.size_splits()->getType().cast<RankedTensorType>();
+    auto size_splits_type = op.size_splits().getType().cast<RankedTensorType>();
     RankedTensorType expected_size_splits_type =
         RankedTensorType::get({num_splits}, size_splits_type.getElementType());
     return op.emitOpError("'size_splits' should be ")

@@ -1303,6 +1301,19 @@ OpFoldResult AbsOp::fold(ArrayRef<Attribute> operands) {
   return ConstFoldUnaryOp(result_type, operands[0], compute);
 }

+//===----------------------------------------------------------------------===//
+// NegOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult NegOp::fold(ArrayRef<Attribute> operands) {
+  Type result_type = getType();
+  // Only constant fold for tensor of f32 is implemented.
+  if (!IsF32ShapedType(result_type)) return nullptr;
+
+  auto compute = [](APFloat value) -> APFloat { return llvm::neg(value); };
+  return ConstFoldUnaryOp(result_type, operands[0], compute);
+}
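The new `NegOp` folder reuses the existing `ConstFoldUnaryOp` helper and differs from `AbsOp` only in the per-element callback. A standalone sketch of what that callback computes, using only `llvm::APFloat` (not TensorFlow code; `main` is illustrative):

    #include <cassert>
    #include "llvm/ADT/APFloat.h"

    int main() {
      llvm::APFloat value(2.5f);
      // llvm::neg flips the sign bit, matching the lambda in NegOp::fold.
      llvm::APFloat negated = llvm::neg(value);
      assert(negated.convertToFloat() == -2.5f);
      return 0;
    }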
 //===----------------------------------------------------------------------===//
 // SinOp
 //===----------------------------------------------------------------------===//

@@ -1414,7 +1425,7 @@ OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
   }

   // Also fold if `input` has a known rank.
-  auto input_type = input()->getType().cast<ShapedType>();
+  auto input_type = input().getType().cast<ShapedType>();
   // Do not fold if rank is zero because the TFLite converter doesn't
   // distinguish between unranked input and scalar input due to b/138865275.
   // TODO(b/138865275): Remove `input_type.getRank() != 0` in the following

@@ -1445,18 +1456,18 @@ OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
 static void BuildSelectV2Op(Builder *builder, OperationState &result,
                             Value cond, Value x, Value y) {
   auto operand_type =
-      OpTrait::util::getBroadcastedType(x->getType(), y->getType());
+      OpTrait::util::getBroadcastedType(x.getType(), y.getType());

   if (!operand_type)
-    emitError(result.location) << "non-broadcastable operands: " << x->getType()
-                               << " and " << y->getType();
+    emitError(result.location) << "non-broadcastable operands: " << x.getType()
+                               << " and " << y.getType();

   bool has_static_cond_shape = false;
   bool has_static_operand_shape = false;
   ArrayRef<int64_t> cond_shape;
   ArrayRef<int64_t> operand_shape;

-  if (auto shaped_type = cond->getType().dyn_cast<ShapedType>()) {
+  if (auto shaped_type = cond.getType().dyn_cast<ShapedType>()) {
     if (shaped_type.hasStaticShape()) {
       has_static_cond_shape = true;
       cond_shape = shaped_type.getShape();

@@ -1474,12 +1485,12 @@ static void BuildSelectV2Op(Builder *builder, OperationState &result,
       !OpTrait::util::getBroadcastedShape(cond_shape, operand_shape,
                                           broadcastedShape)) {
     emitError(result.location) << "non-broadcastable operands: " << operand_type
-                               << " and " << cond->getType();
+                               << " and " << cond.getType();
   }

   result.addOperands({cond, x, y});

-  auto elementType = x->getType().dyn_cast<ShapedType>().getElementType();
+  auto elementType = x.getType().dyn_cast<ShapedType>().getElementType();
   if (has_static_cond_shape && has_static_operand_shape) {
     result.types.push_back(
         RankedTensorType::get(broadcastedShape, elementType));

@@ -1571,9 +1582,8 @@ OpFoldResult RangeOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//

 static LogicalResult Verify(TransposeConvOp op) {
-  ShapedType output_type = op.output()->getType().cast<ShapedType>();
-  ShapedType output_shape_type =
-      op.output_shape()->getType().cast<ShapedType>();
+  ShapedType output_type = op.output().getType().cast<ShapedType>();
+  ShapedType output_shape_type = op.output_shape().getType().cast<ShapedType>();
   if (output_type.hasRank() && output_shape_type.hasStaticShape()) {
     if (output_type.getRank() != output_shape_type.getDimSize(0)) {
       return op.emitOpError(llvm::formatv(

@@ -1679,9 +1689,9 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
 }

 static LogicalResult Verify(TransposeOp op) {
-  auto input_type = op.x()->getType().cast<ShapedType>();
-  auto perm_type = op.perm()->getType().cast<ShapedType>();
-  auto output_type = op.y()->getType().cast<ShapedType>();
+  auto input_type = op.x().getType().cast<ShapedType>();
+  auto perm_type = op.perm().getType().cast<ShapedType>();
+  auto output_type = op.y().getType().cast<ShapedType>();
   if (input_type.hasStaticShape() && perm_type.hasStaticShape()) {
     if (perm_type.getNumElements() != input_type.getRank()) {
       return op.emitOpError(

@@ -1726,10 +1736,25 @@ static LogicalResult Verify(TransposeOp op) {
   return success();
 }

+Region &WhileOp::getLoopBody() { return body(); }
+
+bool WhileOp::isDefinedOutsideOfLoop(Value value) {
+  // TODO(jpienaar): This is overly conservative and disables anything other
+  // than constant hoisting initially.
+  return false;
+}
+
+LogicalResult WhileOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *>) {
+  // TODO(jpienaar): Fail any hoisting until there is a test case and
+  // isDefinedOutsideOfLoop is refined.
+  return failure();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//

+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.cc.inc"
 #define GET_OP_CLASSES
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"


@@ -27,7 +27,7 @@ limitations under the License.
 #include "mlir/IR/StandardTypes.h"  // TF:llvm-project
 #include "mlir/Support/Functional.h"  // TF:llvm-project
 #include "mlir/Support/LLVM.h"  // TF:llvm-project
-#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
+#include "mlir/Transforms/LoopLikeInterface.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
 #include "tensorflow/lite/schema/schema_generated.h"

@@ -44,6 +44,7 @@ class TensorFlowLiteDialect : public Dialect {
                                  Location loc) override;
 };

+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.h.inc"
 #define GET_OP_CLASSES
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc"


@@ -19,6 +19,8 @@ limitations under the License.
 #define TFL_OPS

 include "mlir/IR/OpBase.td"
+include "mlir/Transforms/LoopLikeInterface.td"
+include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td"
 include "tensorflow/compiler/mlir/lite/quantization/quantization.td"

 def TFL_Dialect : Dialect {

@@ -135,7 +137,7 @@ def TFL_FpOrI32OrI64Tensor : TensorOf<[AnyFloat, TFL_Int32Or64]>;
 //===----------------------------------------------------------------------===//

 class TFL_OperandIsUnrankedPred<int n> :
-  CPred<"$_op.getOperand(" # n # ")->getType().isa<UnrankedTensorType>()">;
+  CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">;

 // TODO: Some of these could be generalized and/or moved to more general
 // location.

@@ -144,38 +146,38 @@ class TFL_OperandHasRank<int n, int m> :
   PredOpTrait<"operand " # n # " is " # m # "-D",
     Or<[TFL_OperandIsUnrankedPred<n>,
       CPred<"$_op.getOperand(" # n #
-        ")->getType().cast<ShapedType>().getRank() == " # m>]>>;
+        ").getType().cast<ShapedType>().getRank() == " # m>]>>;

 // Returns true if the n-th operand is ranked and has rank dim.
 class TFL_OperandHasKnownRank<int n, int dim> : And<[
-  CPred<"$_op.getOperand(" # n # ")->getType().isa<RankedTensorType>()">,
-  CPred<"$_op.getOperand(" # n # ")->getType().cast<ShapedType>().getRank() == "
+  CPred<"$_op.getOperand(" # n # ").getType().isa<RankedTensorType>()">,
+  CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>().getRank() == "
     # dim>]>;

 // True if operand n is ranked and has a rank > dim.
 class TFL_OperandIsRankedAndHasDimPred<int n, int dim> : And<[
-  CPred<"$_op.getOperand(" # n # ")->getType().isa<RankedTensorType>()">,
-  CPred<"$_op.getOperand(" # n # ")->getType().cast<ShapedType>().getRank() > "
+  CPred<"$_op.getOperand(" # n # ").getType().isa<RankedTensorType>()">,
+  CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>().getRank() > "
     # dim>]>;

 class TFL_OperandDimEquals<int n, int dim, int size> : And<[
   TFL_OperandIsRankedAndHasDimPred<n, dim>,
-  CPred<"$_op.getOperand(" # n # ")->getType().cast<ShapedType>()"
+  CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>()"
     ".getShape()[" # dim # " ] == " # size>]>;

 // Returns true if the n-th operand has unknown rank or at least rank m.
 class TFL_OperandHasAtleastRank<int n, int m> :
   PredOpTrait<"operand " # n # " is " # m # "-D",
-    Or<[CPred<"$_op.getOperand(" # n # ")->getType().isa<UnrankedTensorType>()">,
+    Or<[CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">,
       CPred<"$_op.getOperand(" # n #
-        ")->getType().cast<ShapedType>().getRank() >= " # m>]>>;
+        ").getType().cast<ShapedType>().getRank() >= " # m>]>>;

 class TFL_OperandRankEquals1DimOfOperand<int x, int y> :
   PredOpTrait<"operand " # x # "'s rank equals operand " # y # "'s size",
     CPred<"$_op.getOperand(" # x #
-      ")->getType().cast<ShapedType>().getRank() == "
+      ").getType().cast<ShapedType>().getRank() == "
       "$_op.getOperand(" # y #
-      ")->getType().cast<ShapedType>().getShape()[0]">>;
+      ").getType().cast<ShapedType>().getShape()[0]">>;

 class TFL_Operand0DOr1ElementTensor<int x> :
   PredOpTrait<"operand #" # x # " is an 0-d tensor or 1-d tensor w/ 1 element",

@@ -195,7 +197,7 @@ class TFL_OperandHasRankLessThan<int n, int m> :
   PredOpTrait<"operand " # n # " is maximum " # m # "-D",
     Or<[TFL_OperandIsUnrankedPred<n>,
       CPred<"$_op.getOperand(" # n #
-        ")->getType().cast<ShapedType>().getRank() <= " # m>]>>;
+        ").getType().cast<ShapedType>().getRank() <= " # m>]>>;

 // This is a quantization-aware version of TCresVTEtIsSameAsOp
 class TFL_TCresVTEtIsSameAsOp<int i, int j> : And<[

@@ -227,7 +229,7 @@ def TFL_BroadcastableBinaryBuilder : OpBuilder<
   "Builder *builder, OperationState &result, Value lhs, Value rhs",
   [{
     auto resultType =
-      OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType());
+      OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
     if (!resultType)
       mlir::emitError(result.location, "non-broadcastable operands");
     result.addOperands({lhs, rhs});

@@ -248,16 +250,6 @@ def TFL_ComparisonBinaryBuilder : OpBuilder<
     buildComparisonBinOp(builder, result, lhs, rhs);
   }]>;

-//===----------------------------------------------------------------------===//
-// TFL native op trait for stateful operands and channel indices.
-
-class StatefulOperands<list<int> operands>
-  : ParamNativeOpTrait<"TFL::StatefulOperands", StrJoinInt<operands>.result>;
-
-class ChannelDimIndex<int index>
-  : ParamNativeOpTrait<"TFL::ChannelDimIndex", !cast<string>(index)>;
-
 //===----------------------------------------------------------------------===//
 // TFL op base class.
 //===----------------------------------------------------------------------===//

@@ -285,7 +277,7 @@ class TFL_Op<string mnemonic, list<OpTrait> traits = []> :
 class TFL_ConvOp<string mnemonic, string opSummary, int index> :
     TFL_Op<mnemonic, [NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
-                      ChannelDimIndex<index>, AffineOpCoefficient<index, 1>]> {
+                      TFL_ChannelDimIndexInterface, AffineOpCoefficient<index, 1>]> {
   let summary = opSummary # " operator";

   let description = [{

@@ -335,7 +327,7 @@ an output element, this operation computes \\(y = |x|\\).
   let hasFolder = 1;
 }

-def TFL_AddOp : TFL_Op<"add", [Broadcastable, NoSideEffect, Commutative]> {
+def TFL_AddOp : TFL_Op<"add", [ResultsBroadcastableShape, NoSideEffect, Commutative]> {
   let summary = "Addition operator";

   let description = [{

@@ -427,6 +419,33 @@ def TFL_TransposeConvOp:
   let verifier = [{ return Verify(*this); }];
 }

+def TFL_Convolution2DTransposeBiasOp :
+    Op<TFL_Dialect, "convolution_2d_transpose_bias", [NoSideEffect]> {
+  let summary = "Transpose convolution with bias operator";
+
+  let description = [{
+    Performs transpose convolution operation on inputs, with the option of
+    adding a bias. Note this is a custom op that is not supported in the
+    standard runtime.
+
+    Inputs:
+      `inputs[0]`: required: the input activation tensor
+      `inputs[1]`: required: the filter weight tensor
+      `inputs[2]`: optional: the bias tensor
+  }];
+
+  let arguments = (
+    ins AnyTensor:$input,
+    AnyTensor:$filter,
+    TFL_TensorOfOrNone<[AnyType]>:$bias,
+    TFL_PaddingAttr:$padding,
+    I32Attr:$stride_h,
+    I32Attr:$stride_w
+  );
+
+  let results = (outs AnyTensor:$output);
+}
+
 def TFL_AveragePool2DOp:
     TFL_Op<"average_pool_2d", [NoSideEffect, SameOperandsAndResultsScale]> {
   let summary = "Average_pool_2d operator";

@@ -459,8 +478,7 @@ def TFL_ArgMaxOp : TFL_Op<"arg_max", [NoSideEffect]> {
   }];

   let arguments = (
-    // TODO: Add support for uint8.
-    ins TensorOf<[F32, I32, I8]>:$input,
+    ins TensorOf<[F32, I32, I8, TFL_Uint8, QI8, QUI8]>:$input,
     TFL_I32OrI64Tensor:$dim
   );

@@ -471,7 +489,7 @@ def TFL_ArgMaxOp : TFL_Op<"arg_max", [NoSideEffect]> {
   let hasOptions = 1;

   DerivedTFLiteTypeAttr output_type = DerivedTFLiteTypeAttr<[{
-    return getResult()->getType().cast<TensorType>().getElementType().
+    return getResult().getType().cast<TensorType>().getElementType().
         cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
             tflite::TensorType_INT32;
   }]>;

@@ -488,8 +506,7 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> {
   }];

   let arguments = (
-    // TODO(pkanwar): Add support for uint8.
-    ins TensorOf<[F32, I32, I8]>:$input,
+    ins TensorOf<[F32, I32, I8, TFL_Uint8, QI8, QUI8]>:$input,
     TFL_I32OrI64Tensor:$dim
   );

@@ -500,7 +517,7 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> {
   let hasOptions = 1;

   DerivedTFLiteTypeAttr output_type = DerivedTFLiteTypeAttr<[{
-    return getResult()->getType().cast<TensorType>().getElementType().
+    return getResult().getType().cast<TensorType>().getElementType().
         cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
             tflite::TensorType_INT32;
   }]>;

@@ -590,7 +607,12 @@ def TFL_ExternalConstOp : Op<TFL_Dialect, "external_const", [NoSideEffect]> {
   let results = (outs AnyTensor:$output);
 }

-def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0>;
+def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0> {
+  let extraClassDeclaration = [{
+    // ChannelDimIndexInterface:
+    int GetChannelDimIndex() { return 0; }
+  }];
+}

 def TFL_CosOp: TFL_Op<"cos", [
     NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {

@@ -610,6 +632,11 @@ def TFL_CosOp: TFL_Op<"cos", [
 def TFL_DepthwiseConv2DOp :
     TFL_ConvOp<"depthwise_conv_2d", "Depthwise-separable convolution", 3> {
   let arguments = !con(TFL_Conv2DOp.arguments, (ins I32Attr:$depth_multiplier));
+
+  let extraClassDeclaration = [{
+    // ChannelDimIndexInterface:
+    int GetChannelDimIndex() { return 3; }
+  }];
 }

 def TFL_FCWO_Default : StrEnumAttrCase<"DEFAULT">;

@@ -623,7 +650,8 @@ def TFL_FullyConnectedOptionsWeightFormatAttr :
 // TODO(jpienaar): Update post discussion on semantics of FC OP.
 def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
-    NoSideEffect, AccumulatorUniformScale<2, 0, 1>, ChannelDimIndex<0>,
+    NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
+    TFL_ChannelDimIndexInterface,
     AffineOpCoefficient<-1, 1>]> {

   let summary = "Fully connected op";

@@ -645,6 +673,11 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
   let verifier = [{ return Verify(*this); }];

   let hasOptions = 1;
+
+  let extraClassDeclaration = [{
+    // ChannelDimIndexInterface:
+    int GetChannelDimIndex() { return 0; }
+  }];
 }
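The trait-to-interface migration above means passes no longer read the channel index from a template parameter; they can query any op implementing the interface. A minimal sketch, assuming the C++ class generated from `tfl_op_interfaces.td` is named `ChannelDimIndexInterface` in the `mlir::TFL` namespace (the helper itself is illustrative, not from this commit):

    // Sketch: querying the channel dimension of an affine op generically.
    #include "llvm/Support/Casting.h"

    int GetChannelDimOrDefault(mlir::Operation *op, int fallback) {
      if (auto iface = llvm::dyn_cast<mlir::TFL::ChannelDimIndexInterface>(op))
        return iface.GetChannelDimIndex();  // 0 for conv_2d, 3 for depthwise
      return fallback;
    }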
 def TFL_GatherOp : TFL_Op<"gather", [

@@ -652,7 +685,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
     SameOperandsAndResultsScale,
     TFL_OperandHasAtleastRank<0, 1>,
     PredOpTrait<"params and output must have same element type",
-                TCresVTEtIsSameAsOp<0, 0>>
+                TFL_TCresVTEtIsSameAsOp<0, 0>>
   ]> {
   let summary = "Gather operator";

@@ -661,7 +694,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
   }];

   let arguments = (ins
-    TensorOf<[F32, I1, I8, I32, I64, TFL_Str, QI8, QUI8]>:$params,
+    TensorOf<[F32, I1, I8, I32, I64, TFL_Str, QI8, QUI8, QI16]>:$params,
     TensorOf<[I32, I64]>:$indices,
     I32Attr:$axis
   );

@@ -674,7 +707,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
   ];

   let results = (outs
-    TensorOf<[F32, I1, I8, I32, I64, TFL_Str, QI8, QUI8]>:$output
+    TensorOf<[F32, I1, I8, I32, I64, TFL_Str, QI8, QUI8, QI16]>:$output
   );

   let hasOptions = 1;

@@ -697,9 +730,9 @@ def TFL_GatherNdOp : TFL_Op<"gather_nd", [NoSideEffect]> {
   );
 }

-// Same type check of lhs and rhs is handled by the Broadcastable trait.
+// Same type check of lhs and rhs is handled by the ResultsBroadcastableShape trait.
 def TFL_LessEqualOp : TFL_Op<"less_equal", [
-    Broadcastable, NoSideEffect, NoQuantizableResult]> {
+    ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
   let summary = "Less_equal operator";

   let description = [{

@@ -755,7 +788,7 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag
 }

 def TFL_GreaterEqualOp : TFL_Op<"greater_equal", [
-    Broadcastable, NoSideEffect, NoQuantizableResult]> {
+    ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
   let summary = "Greater_equal operator";

   let description = [{

@@ -916,7 +949,7 @@ larger than 0.
 }

 def TFL_NotEqualOp : TFL_Op<"not_equal", [
-    Broadcastable, Commutative, NoSideEffect, NoQuantizableResult]> {
+    ResultsBroadcastableShape, Commutative, NoSideEffect, NoQuantizableResult]> {
   let summary = "Not_equal operator";

   let description = [{

@@ -943,7 +976,7 @@ def TFL_NotEqualOp : TFL_Op<"not_equal", [
   let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
 }

-def TFL_DivOp : TFL_Op<"div", [Broadcastable, NoSideEffect]> {
+def TFL_DivOp : TFL_Op<"div", [ResultsBroadcastableShape, NoSideEffect]> {
   let summary = "Division operator";

   let description = [{

@@ -1002,7 +1035,7 @@ def TFL_EmbeddingLookupOp: TFL_Op<"embedding_lookup",
   let results = (outs TensorOf<[F32, I8, TFL_Uint8]>:$output);
 }

-def TFL_EqualOp: TFL_Op<"equal", [Commutative, Broadcastable,
+def TFL_EqualOp: TFL_Op<"equal", [Commutative, ResultsBroadcastableShape,
     NoQuantizableResult,
     PredOpTrait<"Operands have same value type", TCopVTEtIsSameAs<0, 1>>]> {
   let summary = "Equal operator";

@@ -1036,7 +1069,8 @@ def TFL_ExpOp: TFL_Op<"exp", [NoSideEffect, SameOperandsAndResultType]> {
   let hasOptions = 0b1;
 }

-def TFL_ExpandDimsOp: TFL_Op<"expand_dims", [NoSideEffect]> {
+def TFL_ExpandDimsOp: TFL_Op<"expand_dims", [
+    NoSideEffect, SameOperandsAndResultsScale]> {
   let summary = "Inserts a dimension of 1 into a tensor's shape.";

   let description = [{

@@ -1146,7 +1180,7 @@ def TFL_FloorOp: TFL_Op<"floor", [NoSideEffect, SameOperandsAndResultType]> {
 }

 def TFL_FloorDivOp : TFL_Op<"floor_div", [
-    Broadcastable, NoSideEffect, BinaryOpSameElementTypeConstraint]> {
+    ResultsBroadcastableShape, NoSideEffect, BinaryOpSameElementTypeConstraint]> {
   let summary = "Floor div operator";

   let description = [{

@@ -1165,7 +1199,7 @@ def TFL_FloorDivOp : TFL_Op<"floor_div", [
   let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
 }

-def TFL_FloorModOp : TFL_Op<"floor_mod", [Broadcastable, NoSideEffect]> {
+def TFL_FloorModOp : TFL_Op<"floor_mod", [ResultsBroadcastableShape, NoSideEffect]> {
   let summary = "Division remainder";

   let description = [{

@@ -1181,7 +1215,8 @@ def TFL_FloorModOp : TFL_Op<"floor_mod", [ResultsBroadcastableShape, NoSideEffect]> {
   let builders = [TFL_BroadcastableBinaryBuilder];
 }

-def TFL_GreaterOp : TFL_Op<"greater", [NoSideEffect, NoQuantizableResult]> {
+def TFL_GreaterOp : TFL_Op<"greater", [
+    ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
   let summary = "Greater operator";

   let description = [{

@@ -1194,6 +1229,8 @@ def TFL_GreaterOp : TFL_Op<"greater", [
   let results = (outs AnyTensor:$output);

+  let builders = [TFL_ComparisonBinaryBuilder];
+
   let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }];
   let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];

@@ -1260,7 +1297,8 @@ def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [NoSideEffect, SameOperandsAndResultTy
   let hasOptions = 0b1;
 }

-def TFL_LessOp : TFL_Op<"less", [NoSideEffect, NoQuantizableResult]> {
+def TFL_LessOp : TFL_Op<"less", [
+    ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
   let summary = "Less operator";

   let description = [{

@@ -1427,8 +1465,65 @@ def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [
   let customOption = "Pool2DOptions";
 }

+def TFL_MaxPoolingWithArgMax2DOp :
+    Op<TFL_Dialect, "max_pooling_with_argmax_2d", [NoSideEffect]> {
+  let summary = "Max Pool 2D with argmax op";
+
+  let description = [{
+    Performs max pooling on the input and outputs both max values and indices.
+    Each index is a flattened index in a sub-array of "filter_w" x "filter_h" size.
+    Note this is a custom op that is not supported in the standard runtime.
+
+    Inputs:
+      `inputs[0]`: required: the input activation tensor
+  }];
+
+  let arguments = (
+    ins AnyTensor:$input,
+    TFL_PaddingAttr:$padding,
+    I32Attr:$stride_w,
+    I32Attr:$stride_h,
+    I32Attr:$filter_w,
+    I32Attr:$filter_h
+  );
+
+  let results = (outs
+    AnyTensor:$value,
+    AnyTensor:$indices
+  );
+}
+
+def TFL_MaxUnpooling2DOp :
+    Op<TFL_Dialect, "max_unpooling_2d", [NoSideEffect]> {
+  let summary = "Max Unpool 2D";
+
+  let description = [{
+    Performs max unpooling. To some extent this is the reverse of max pooling:
+    the elements in the input activation tensor are stored into the positions
+    specified by the input indices.
+    Note this is a custom op that is not supported in the standard runtime.
+
+    Inputs:
+      `inputs[0]`: required: the input activation tensor
+      `inputs[1]`: required: the input indices
+  }];
+
+  let arguments = (
+    ins AnyTensor:$input,
+    AnyTensor:$indices,
+    TFL_PaddingAttr:$padding,
+    I32Attr:$stride_w,
+    I32Attr:$stride_h,
+    I32Attr:$filter_w,
+    I32Attr:$filter_h
+  );
+
+  let results = (outs AnyTensor:$outputs);
+}
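As a rough mental model of the new custom op's scatter semantics, here is a sketch over flattened 1-D buffers (plain C++, not runtime code; names and the flattening are illustrative):

    #include <vector>

    // Values are written back to the positions recorded by the paired
    // max_pooling_with_argmax_2d op; untouched positions stay zero.
    std::vector<float> MaxUnpool(const std::vector<float>& values,
                                 const std::vector<int>& indices,
                                 int output_size) {
      std::vector<float> output(output_size, 0.0f);
      for (size_t i = 0; i < values.size(); ++i) output[indices[i]] = values[i];
      return output;
    }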
 def TFL_MaximumOp : TFL_Op<"maximum", [
-    Broadcastable, NoSideEffect, Commutative, SameOperandsAndResultsScale,
+    ResultsBroadcastableShape, NoSideEffect, Commutative, SameOperandsAndResultsScale,
     TFL_OperandHasRankLessThan<0, 4>, TFL_OperandHasRankLessThan<1, 4>]> {
   let summary = "Max operator";
   let description = [{

@@ -1567,7 +1662,8 @@ def TFL_SumOp: TFL_Op<"sum", [NoSideEffect]> {
   let customOption = "ReducerOptions";
 }

-def TFL_ReduceMinOp: TFL_Op<"reduce_min", [NoSideEffect]> {
+def TFL_ReduceMinOp: TFL_Op<"reduce_min", [
+    NoSideEffect, SameOperandsAndResultsScale]> {
   let summary = "Min-reduction operator";

   let description = [{

@@ -1586,7 +1682,8 @@ def TFL_ReduceMinOp: TFL_Op<"reduce_min", [
   let customOption = "ReducerOptions";
 }

-def TFL_ReduceMaxOp: TFL_Op<"reduce_max", [NoSideEffect]> {
+def TFL_ReduceMaxOp: TFL_Op<"reduce_max", [
+    NoSideEffect, SameOperandsAndResultsScale]> {
   let summary = "Max-reduction operator";

   let description = [{

@@ -1625,7 +1722,7 @@ def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [NoSideEffect]> {
 }

 def TFL_MinimumOp : TFL_Op<"minimum", [
-    Broadcastable, NoSideEffect, Commutative, SameOperandsAndResultsScale,
+    ResultsBroadcastableShape, NoSideEffect, Commutative, SameOperandsAndResultsScale,
     TFL_OperandHasRankLessThan<0, 4>, TFL_OperandHasRankLessThan<1, 4>]> {
   let summary = "Min operator";
   let description = [{

@@ -1646,7 +1743,7 @@ def TFL_MinimumOp : TFL_Op<"minimum", [
   let hasOptions = 0;
 }

-def TFL_MulOp : TFL_Op<"mul", [Broadcastable, NoSideEffect, Commutative]> {
+def TFL_MulOp : TFL_Op<"mul", [ResultsBroadcastableShape, NoSideEffect, Commutative]> {
   let summary = "Multiplication operator";

   let description = [{

@@ -1683,6 +1780,8 @@ def TFL_NegOp: TFL_Op<"neg", [NoSideEffect, SameOperandsAndResultType]> {
   let results = (outs AnyTensor:$y);

   let hasOptions = 0b1;
+
+  let hasFolder = 1;
 }

 def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> {

@@ -1716,14 +1815,14 @@ def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> {
   }];

   let arguments = (ins
-    Variadic<TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8]>>:$values,
+    Variadic<TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8, QI16]>>:$values,
     I32Attr:$values_count,
     I32Attr:$axis
   );

   let results = (outs
-    TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8]>:$output
+    TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8, QI16]>:$output
   );

   let verifier = [{ return Verify(*this); }];

@@ -1821,7 +1920,7 @@ def TFL_PadV2Op : TFL_Op<"padv2", [
   let hasOptions = 1;
 }

-def TFL_PowOp : TFL_Op<"pow", [Broadcastable, NoSideEffect, NoQuantizableResult]> {
+def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
   let summary = "Power operator";

   let description = [{

@@ -1996,7 +2095,7 @@ def TFL_ShapeOp: TFL_Op<"shape", [NoSideEffect]> {
   let results = (outs AnyTensor:$output);

   DerivedTypeAttr out_type = DerivedTypeAttr<[{
-    return getResult()->getType().cast<TensorType>().getElementType();
+    return getResult().getType().cast<TensorType>().getElementType();
   }]>;

   let hasOptions = 1;

@@ -2039,7 +2138,7 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2",
   Args:
     tensor: A Tensor. Must be one of the following types:
-      int16, int32, int64, float32 Up to 8-D.
+      uint8, int16, int32, int64, float32, bool. Up to 8-D.
     axis: A Tensor. Must be one of the following types: int32, int64.
       with only 1 element which is the axis index.

@@ -2048,12 +2147,12 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2",
   let arguments = (
     ins
-    TensorOf<[F32, I16, I32, I64]>:$input,
+    TensorOf<[F32, I16, I32, I64, TFL_Uint8, I1]>:$input,
     TensorOf<[I32, I64]>:$axis
   );

   let results = (outs
-    TensorOf<[F32, I16, I32, I64, I8]>:$output
+    TensorOf<[F32, I16, I32, I64, TFL_Uint8, I1]>:$output
   );
 }

@@ -2083,7 +2182,7 @@ def TFL_SelectOp : TFL_Op<"select", [NoSideEffect,
   let builders = [OpBuilder<"Builder *builder, OperationState &result, "
                   "Value condition, Value x, Value y",
   [{
-    auto resultType = x->getType();
+    auto resultType = x.getType();
     result.addOperands({condition, x, y});
     result.types.push_back(resultType);
   }]>];

@@ -2190,7 +2289,7 @@ def TFL_SquareOp: TFL_Op<"square", [
   let hasFolder = 1;
 }

-def TFL_SubOp : TFL_Op<"sub", [Broadcastable, NoSideEffect]> {
+def TFL_SubOp : TFL_Op<"sub", [ResultsBroadcastableShape, NoSideEffect]> {
   let summary = "Subtraction operator";

   let description = [{

@@ -2218,7 +2317,7 @@ def TFL_SubOp : TFL_Op<"sub", [ResultsBroadcastableShape, NoSideEffect]> {
 // TODO(jpienaar): Expand the kernel implementation to support all types besides
 // I32 and F32.
 def TFL_SquaredDifferenceOp : TFL_Op<"squared_difference", [
-    Broadcastable, NoSideEffect, NoQuantizableResult]> {
+    ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
   let summary = "Squared difference operator";

   let description = [{

@@ -2257,9 +2356,9 @@ def TFL_TanhOp: TFL_Op<"tanh", [
   let results = (outs TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$y);
 }

-def TFL_TileOp: TFL_Op<"tile", [NoSideEffect,
+def TFL_TileOp: TFL_Op<"tile", [NoSideEffect, SameOperandsAndResultsScale,
   PredOpTrait<"resultant element type needs to match first operand type",
-    TCresVTEtIsSameAsOp<0,0>>]> {
+    TFL_TCresVTEtIsSameAsOp<0,0>>]> {
   let summary = "Tile operator.";
   let description = [{
     Constructs a tensor by tiling a given tensor.

@@ -2272,10 +2371,11 @@ def TFL_TileOp: TFL_Op<"tile", [NoSideEffect, SameOperandsAndResultsScale,
   }];

   let arguments = (ins
-    TensorOf<[F32, I1, I32, I64, TFL_Uint8]>:$input,
+    TensorOf<[F32, I1, I32, I64, TFL_Uint8, QUI8]>:$input,
     TFL_I32OrI64Tensor:$multiples);

-  let results = (outs TensorOf<[F32, I1, I32, I64, TFL_Uint8]>:$output);
+  let results = (outs
+    TensorOf<[F32, I1, I32, I64, TFL_Uint8, QUI8]>:$output);

   let hasOptions = 0;
 }

@@ -2285,7 +2385,7 @@ def TFL_TileOp: TFL_Op<"tile", [NoSideEffect, SameOperandsAndResultsScale,
 // TODO(jpienaar): Check that k is less or equal the internal dimension
 def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>,
   PredOpTrait<"result and input element type match",
-    TCresVTEtIsSameAsOp<0,0>>]> {
+    TCresVTEtIsSameAsOp<0,0>>, SameOperandsAndResultsScale]> {
   let summary = "TopK operator";

   let description = [{

@@ -2295,11 +2395,11 @@ def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>,
   }];

   let arguments = (ins
-    TensorOf<[F32, I8, I32, I64, TFL_Uint8]>:$input,
+    TensorOf<[F32, I8, I32, I64, TFL_Uint8, QI8, QUI8]>:$input,
     I32Tensor:$k);

   let results = (outs
-    AnyTensor:$values,
+    TensorOf<[F32, I8, I32, I64, TFL_Uint8, QI8, QUI8]>:$values,
     I32Tensor:$indices);

   let builders = [OpBuilder<"Builder *builder, OperationState &result, "

@@ -2338,7 +2438,7 @@ def TFL_TransposeOp : TFL_Op<"transpose",
   let hasFolder = 1;
 }

-def TFL_UnpackOp : TFL_Op<"unpack", [NoSideEffect]> {
+def TFL_UnpackOp : TFL_Op<"unpack", [NoSideEffect, SameOperandsAndResultsScale]> {
   let summary = "Unpacks a tensor along a dimension into multiple tensors";

   let description = [{

@@ -2554,7 +2654,9 @@ def TFL_ResizeBilinearOp: TFL_Op<"resize_bilinear", [
     // TODO(ycling): Support quantized types.
     TensorOf<[F32, I32, QI8, QUI8]>:$input,
     TensorOf<[I32]>:$size,
-    BoolAttr:$align_corners);
+    BoolAttr:$align_corners,
+    DefaultValuedAttr<BoolAttr, "false">:$half_pixel_centers
+  );

   let results = (outs
     TensorOf<[F32, QI8, QUI8]>:$output

@@ -2663,12 +2765,11 @@ def TFL_CastOp : TFL_Op<"cast", [
     Casts input from input type to output type.
   }];

-  // TODO(b/135538711): Add complex types here.
   let arguments = (ins
-    TensorOf<[F32, I1, I32, I64, TFL_Quint8, TFL_Uint8]>:$input
+    TensorOf<[F32, I1, I32, I64, TFL_Quint8, TFL_Uint8, Complex<F<32>>]>:$input
   );

-  let results = (outs TensorOf<[F32, I1, I32, I64]>:$output);
+  let results = (outs TensorOf<[F32, I1, I32, I64, Complex<F<32>>]>:$output);

   // TFLite's cast op does not utilize CastOptions, instead derives types
   // from the TfLiteTensors.

@@ -2733,7 +2834,7 @@ in the unique output `y`. In other words:
   );

   DerivedTFLiteTypeAttr idx_out_type = DerivedTFLiteTypeAttr<[{
-    return getResult(1)->getType().cast<TensorType>().getElementType().
+    return getResult(1).getType().cast<TensorType>().getElementType().
        cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
            tflite::TensorType_INT32;
   }]>;

@@ -2768,7 +2869,9 @@ def TFL_FakeQuantOp : TFL_Op<"fake_quant", [NoSideEffect]> {
   let arguments = (
     ins AnyTensor:$input,
     // The expected [min, max] range of values.
-    MinMaxAttr:$minmax,
+    F32Attr:$min,
+    F32Attr:$max,
     // The bitwidth of the quantization; between 2 and 16, inclusive.
     I32Attr:$num_bits,
     // Quantization range starts from 0 or 1; starts from 1 if true.

@@ -2777,6 +2880,8 @@ def TFL_FakeQuantOp : TFL_Op<"fake_quant", [NoSideEffect]> {
   let results = (outs AnyTensor:$output);

   let hasCanonicalizer = 0b1;
+
+  let hasOptions = 1;
 }

 def TFL_QConstOp : Op<TFL_Dialect, "pseudo_qconst", [

@@ -2823,6 +2928,20 @@ def TFL_QuantizeOp: TFL_Op<"quantize", [
   let results = (outs AnyTensor:$output);
 }

+def TFL_DensifyOp: TFL_Op<"densify", [NoSideEffect,
+                                      SameOperandsAndResultType,
+                                      NoQuantizableResult]> {
+  let summary = "Densify operator";
+
+  let description = [{
+    Converts sparse tensor to dense format.
+  }];
+
+  let arguments = (ins AnyTensor:$input);
+
+  let results = (outs AnyTensor:$output);
+}
+
 //===----------------------------------------------------------------------===//
 // LSTM Ops
 //===----------------------------------------------------------------------===//

@@ -2912,7 +3031,7 @@ def TFL_LSTMOp :
      LstmOptionalPeepholeWeightConstraint,
      LstmProjectionWeightBiasConstraint,
      LstmResultConstraint,
-     StatefulOperands<[18, 19]>]> {
+     TFL_StatefulOp]> {
   let summary = "The full lstm operator";

   let description = [{

@@ -2996,6 +3115,11 @@ Ba et al. "Layer Normalization"
   let hasOptions = 1;

   let verifier = [{ return Verify(*this); }];
+
+  let extraClassDeclaration = [{
+    // StatefulOpInterface:
+    std::vector<int> GetStatefulOperands() { return {18, 19}; }
+  }];
 }

 // UnidirectionalSequenceLstm op.

@@ -3007,7 +3131,7 @@ def TFL_UnidirectionalSequenceLSTMOp :
      LstmOptionalPeepholeWeightConstraint,
      LstmProjectionWeightBiasConstraint,
      LstmResultConstraint,
-     StatefulOperands<[18, 19]>]> {
+     TFL_StatefulOp]> {
   let summary = "Unidirectional sequence lstm operator";

   let description = [{

@@ -3076,6 +3200,11 @@ def TFL_UnidirectionalSequenceLSTMOp :
   let hasOptions = 1;

   let verifier = [{ return Verify(*this); }];
+
+  let extraClassDeclaration = [{
+    // StatefulOpInterface:
+    std::vector<int> GetStatefulOperands() { return {18, 19}; }
+  }];
 }

 def RnnResultConstraint : PredOpTrait<

@@ -3085,7 +3214,7 @@ def RnnResultConstraint : PredOpTrait<
 // UnidirectionalSequenceRNN op.
 def TFL_UnidirectionalSequenceRNNOp :
   TFL_Op<"unidirectional_sequence_rnn",
-         [RnnResultConstraint, StatefulOperands<[4]>]> {
+         [RnnResultConstraint, TFL_StatefulOp]> {

   let summary = "Unidirectional sequence rnn operator";

@@ -3129,6 +3258,11 @@ def TFL_UnidirectionalSequenceRNNOp :
   let customOption = "SequenceRNNOptions";

   let verifier = [{ return Verify(*this); }];
+
+  let extraClassDeclaration = [{
+    // StatefulOpInterface:
+    std::vector<int> GetStatefulOperands() { return {4}; }
+  }];
 }

 def TFL_WhereOp : TFL_Op<"where", [NoSideEffect]> {

@@ -3180,7 +3314,7 @@ def SVDFResultConstraint: PredOpTrait<
 // SVDF op.
 def TFL_SVDFOp :
   TFL_Op<"svdf",
-         [SVDFResultConstraint, StatefulOperands<[4]>]> {
+         [SVDFResultConstraint, TFL_StatefulOp]> {

   let summary = "Single value decomposition filter operator";

@@ -3216,6 +3350,67 @@ def TFL_SVDFOp :
   let hasOptions = 1;

   let verifier = [{ return Verify(*this); }];
+
+  let extraClassDeclaration = [{
+    // StatefulOpInterface:
+    std::vector<int> GetStatefulOperands() { return {4}; }
+  }];
+}
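With the `StatefulOperands` trait gone, stateful operand indices are exposed through an op interface instead. A sketch of a generic consumer, assuming the generated class is `mlir::TFL::StatefulOpInterface` (the helper and loop are illustrative):

    // Sketch: collecting the stateful (variable-backed) operands of any op.
    #include <vector>
    #include "llvm/Support/Casting.h"

    std::vector<mlir::Value> GetStatefulValues(mlir::Operation *op) {
      std::vector<mlir::Value> values;
      if (auto stateful = llvm::dyn_cast<mlir::TFL::StatefulOpInterface>(op))
        for (int idx : stateful.GetStatefulOperands())
          values.push_back(op->getOperand(idx));
      return values;
    }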
+
+def TFL_SegmentSumOp: TFL_Op<"segment_sum", [NoSideEffect]> {
+  let summary = "SegmentSum operator";
+
+  let description = [{
+    Computes the sum along segments of a tensor.
+  }];
+
+  let arguments = (ins
+    TensorOf<[F32, I32]>:$data,
+    I32Tensor:$segment_ids
+  );
+
+  let results = (outs TensorOf<[F32, I32]>:$output);
+}
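For intuition, segment_sum adds together the `data` entries that share a `segment_ids` value; a sketch over 1-D inputs (plain C++, not kernel code):

    #include <vector>

    // data = {1, 2, 3, 4} with segment_ids = {0, 0, 1, 1} yields {3, 7}.
    std::vector<float> SegmentSum(const std::vector<float>& data,
                                  const std::vector<int>& segment_ids,
                                  int num_segments) {
      std::vector<float> output(num_segments, 0.0f);
      for (size_t i = 0; i < data.size(); ++i)
        output[segment_ids[i]] += data[i];
      return output;
    }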
+
+def TFL_YieldOp : Op<TFL_Dialect, "yield", [Terminator]> {
+  let summary = "Yield operation";
+  let description = [{
+    The "yield" operation represents a return operation within the conditional
+    and body of structured control flow (e.g., while). The operation takes a
+    variable number of operands and produces no results. The operand number
+    and types must match the signature of the region that contains the
+    operation.
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$operands);
+}
+
+def TFL_WhileOp : Op<TFL_Dialect, "while", [
+    DeclareOpInterfaceMethods<LoopLikeOpInterface>,
+    SingleBlockImplicitTerminator<"YieldOp">,
+    // Make isolated from above to force values through operands to simplify
+    // exporting to subgraphs.
+    IsolatedFromAbove]> {
+  let summary = [{While loop}];
+
+  let description = [{
+    output = input; while (cond(output)) { output = body(output) }
+
+    input: A list of input tensors whose types are T.
+    output: A list of output tensors whose types are T.
+    cond: A region that takes 'input' and returns a boolean scalar tensor.
+    body: A region that takes a list of tensors and returns another
+      list of tensors. Both lists have the same types.
+  }];
+
+  let arguments = (ins
+    Variadic<AnyTensor>:$input,
+
+    // Used to map StatelessWhile and While op defined in TensorFlow to a
+    // common op.
+    DefaultValuedAttr<BoolAttr, "false">:$is_stateless
+  );
+
+  let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
+
+  let results = (outs Variadic<AnyTensor>:$output);
+}
 #endif // TFL_OPS
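A plain-C++ restatement of the loop contract the new `tfl.while` documents (illustrative only; the tensor list is modeled as a vector of ints):

    #include <functional>
    #include <utility>
    #include <vector>

    using TensorList = std::vector<int>;

    // output = input; while (cond(output)) { output = body(output) }
    TensorList RunWhile(TensorList input,
                        const std::function<bool(const TensorList&)>& cond,
                        const std::function<TensorList(TensorList)>& body) {
      TensorList output = std::move(input);
      while (cond(output)) output = body(output);
      return output;
    }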


@@ -1,67 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// This file defines the op traits used in the MLIR TensorFlow Lite dialect.
-
-#ifndef TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
-#define TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
-
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/Support/LLVM.h"  // TF:llvm-project
-
-namespace mlir {
-namespace OpTrait {
-namespace TFL {
-
-// The trait to specify that the specified operands of the TFL op are stateful.
-// This is used as a trait like this:
-//
-//   class LSTMOp
-//       : public Op<LSTMOp, OpTrait::TFL::StatefulOperands<18, 19>::Impl> {
-//
-template <int... Operands>
-class StatefulOperands {
- public:
-  template <typename ConcreteType>
-  class Impl
-      : public TraitBase<ConcreteType, StatefulOperands<Operands...>::Impl> {
-   public:
-    static std::vector<int> GetStatefulOperands() {
-      return std::vector<int>({Operands...});
-    }
-  };
-};
-
-// The trait to specify the channel dimension index of the input (first operand)
-// of an affine TFL op (Conv2D, DepthwiseConv2D, FullyConnected).
-//
-//   class Conv2DOp
-//       : public Op<Conv2DOp, OpTrait::TFL::ChannelDimIndex<0>::Impl> {
-//
-template <int Index>
-class ChannelDimIndex {
- public:
-  template <typename ConcreteType>
-  class Impl : public TraitBase<ConcreteType, ChannelDimIndex<Index>::Impl> {
-   public:
-    static int GetChannelDimIndex() { return Index; }
-  };
-};
-
-}  // namespace TFL
-}  // namespace OpTrait
-}  // namespace mlir
-
-#endif  // TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_


@@ -122,7 +122,7 @@ static void EmitOptionBuilders(const RecordKeeper &record_keeper,
       os << formatv(
           "    auto {0} = Convert{1}ForOptionWriter(op.{0}(), fbb);\n",
           val.getName(), record->getClasses()[0]->getName());
-      options.push_back(val.getName());
+      options.push_back(std::string(val.getName()));
     }
   }
 }


@@ -32,6 +32,6 @@ cc_library(
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:Support",
-        "@llvm-project//mlir:ViewOpGraph",
+        "@llvm-project//mlir:Transforms",
     ],
 )


@@ -107,9 +107,6 @@ void WarningUnusedFlags(const toco::ModelFlags& model_flags,
   if (toco_flags.output_format()) {
     LOG(WARNING) << "Ignored output_format.";
   }
-  if (toco_flags.default_ranges_min() || toco_flags.default_ranges_max()) {
-    LOG(WARNING) << "Ignored default_ranges_stats.";
-  }
   if (toco_flags.drop_control_dependency()) {
     LOG(WARNING) << "Ignored drop_control_dependency.";
   }

@@ -242,6 +239,13 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
       tensorflow::ParseOutputArrayInfo(output_arrays, &specs.outputs));

   // Other flags.
+  if (toco_flags.has_default_ranges_min()) {
+    quant_specs.default_ranges.first = toco_flags.default_ranges_min();
+  }
+  if (toco_flags.has_default_ranges_max()) {
+    quant_specs.default_ranges.second = toco_flags.default_ranges_max();
+  }
+
   bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
   bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
   bool emit_custom_ops = toco_flags.allow_custom_ops();


@@ -71,18 +71,17 @@ cc_library(
         "quantization_utils.cc",
     ],
     hdrs = [
-        "quantization_traits.h",
         "quantization_utils.h",
     ],
     deps = [
-        "//tensorflow/core:lib_proto_parsing",
         "@com_google_absl//absl/memory",
         "@llvm-project//llvm:support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:QuantOps",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Support",
+        # TODO(fengliuai): remove this dependence.
+        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
+        "//tensorflow/core:lib_proto_parsing",
     ],
 )


@ -78,8 +78,8 @@ class ImportQuantStatsPass : public FunctionPass<ImportQuantStatsPass> {
bool IsQuantizableResult(Operation *op, int index) { bool IsQuantizableResult(Operation *op, int index) {
if (index < 0 || index >= op->getNumResults()) return false; if (index < 0 || index >= op->getNumResults()) return false;
Value res = op->getResult(index); Value res = op->getResult(index);
return res->getType().isa<ShapedType>() && return res.getType().isa<ShapedType>() &&
res->getType().cast<ShapedType>().getElementType().isa<FloatType>(); res.getType().cast<ShapedType>().getElementType().isa<FloatType>();
} }
// A method to retrieve the name for the given op. // A method to retrieve the name for the given op.
@@ -123,7 +123,7 @@ void ImportQuantStatsPass::InsertStatsOpAtResult(OpBuilder b, Value res,
                                                  IntegerAttr axis) {
   auto stats_op = b.create<quant::StatisticsOp>(b.getUnknownLoc(), res,
                                                 layer_stats, axis_stats, axis);
-  res->replaceAllUsesWith(stats_op);
+  res.replaceAllUsesWith(stats_op);
   stats_op.getOperation()->replaceUsesOfWith(stats_op, res);
 }
@@ -206,10 +206,17 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
 std::unique_ptr<OpPassBase<FuncOp>>
 CreateImportQuantStatsPassForTFControlDialect(const std::string &stats_str) {
   auto get_name_func = [](Operation *op) {
-    if (auto name = op->getAttrOfType<StringAttr>("name"))
-      return name.getValue();
-    else
-      return llvm::StringRef("");
+    Location loc = op->getLoc();
+    if (auto name = loc.dyn_cast<NameLoc>()) {
+      return name.getName().strref();
+    } else if (auto fused_name = loc.dyn_cast<FusedLoc>()) {
+      for (auto sub_loc : fused_name.getLocations()) {
+        if (auto named_sub_loc = sub_loc.dyn_cast<NameLoc>()) {
+          return named_sub_loc.getName().strref();
+        }
+      }
+    }
+    return llvm::StringRef("");
   };
   return CreateImportQuantStatsPass(get_name_func, stats_str);
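The lambda no longer reads a `name` StringAttr; after this change the op's name travels on its `Location`, either directly as a `NameLoc` or nested inside a `FusedLoc`, which the new code walks. For context, a hedged sketch of the producing side, i.e. how an importer of this era would attach such a location (the node name "conv1" is made up):

    #include "mlir/IR/Builders.h"
    #include "mlir/IR/Location.h"
    #include "mlir/IR/MLIRContext.h"

    // Build a location carrying a node name; ops created at this location can
    // later be identified by name through NameLoc, as the pass above does.
    mlir::Location MakeNamedLoc(mlir::MLIRContext* ctx) {
      mlir::OpBuilder b(ctx);
      return mlir::NameLoc::get(b.getIdentifier("conv1"), b.getUnknownLoc());
    }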


@@ -12,6 +12,7 @@ package_group(
     includes = ["//third_party/mlir:subpackages"],
     packages = [
         "//learning/brain/experimental/mlir/...",
+        "//tensorflow/compiler/mlir/lite/...",
         "//tensorflow/lite/...",
     ],
 )
@@ -23,7 +24,6 @@ cc_library(
     ],
     hdrs = [
         "quantize_model.h",
-        "//tensorflow/compiler/mlir/lite:transforms/passes.h",
     ],
     deps = [
         "//tensorflow/compiler/mlir/lite:common",
@@ -42,6 +42,24 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "tfl_to_std",
+    srcs = [
+        "tfl_to_std.cc",
+    ],
+    hdrs = [
+        "tfl_to_std.h",
+    ],
+    deps = [
+        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
+        "@com_google_absl//absl/strings",
+        "@llvm-project//llvm:support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:QuantOps",
+    ],
+)
+
 # Binary to apply quantization on the annotated files.
 tf_cc_binary(
     name = "tfl_quantizer",


@@ -73,19 +73,19 @@ TfLiteStatus QuantizeModel(
   // Apply quantization passes
   PassManager pm(module->getContext());
-  TFL::QuantizationSpecs pass_config;
-  pass_config.inference_type = tensorflow::DT_QINT8;
-  pass_config.post_training_quantization = true;
+  TFL::QuantizationSpecs quant_specs;
+  quant_specs.inference_type = tensorflow::DT_QINT8;
+  quant_specs.post_training_quantization = true;
 
   bool emit_adaptor = false;
   auto input_tf_type = tflite::TflTypeToTfType(input_type);
   if (input_tf_type == tensorflow::DT_FLOAT) {
     emit_adaptor = true;
   } else if (input_tf_type == tensorflow::DT_UINT8) {
-    pass_config.inference_type = tensorflow::DT_QUINT8;
+    quant_specs.inference_type = tensorflow::DT_QUINT8;
   }
 
-  pm.addPass(TFL::CreatePrepareQuantizePass(pass_config));
+  pm.addPass(TFL::CreatePrepareQuantizePass(quant_specs));
   pm.addPass(TFL::CreateQuantizePass());
   pm.addPass(TFL::CreatePostQuantizePass(emit_adaptor));
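The rename from `pass_config` to `quant_specs` is purely cosmetic, matching the `TFL::QuantizationSpecs` parameter of `CreatePrepareQuantizePass`; no behavior changes. For completeness, a sketch of how the assembled pipeline is typically driven, assuming `module` is an `mlir::OwningModuleRef` and the enclosing function returns `TfLiteStatus`:

    // Run the three passes added above; PassManager::run returns a
    // LogicalResult, which is translated into a TFLite status code here.
    if (mlir::failed(pm.run(module.get()))) {
      return kTfLiteError;
    }
    return kTfLiteOk;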

Some files were not shown because too many files have changed in this diff.