Merge branch 'master' into identity_in_constant_value

This commit is contained in:
Harry Slatyer 2020-02-03 11:04:24 +11:00
commit 8ca5bed715
2860 changed files with 129738 additions and 45504 deletions

View File

@ -69,6 +69,7 @@
# rbe_linux_py3: Linux Python 3 RBE config
#
# rbe_win_py37: Windows Python 3.7 RBE config
# rbe_win_py38: Windows Python 3.8 RBE config
#
# tensorflow_testing_rbe_linux: RBE options to use RBE with tensorflow-testing project on linux
# tensorflow_testing_rbe_win: RBE options to use RBE with tensorflow-testing project on windows
@ -279,7 +280,6 @@ build:windows --host_linkopt=/OPT:REF
build:windows --linkopt=/OPT:ICF
build:windows --host_linkopt=/OPT:ICF
build:windows --experimental_strict_action_env=true
build:windows --incompatible_windows_native_test_wrapper
# Verbose failure logs when something goes wrong
build:windows --verbose_failures
@ -344,6 +344,7 @@ build:rbe_linux --config=avx_linux
build:rbe_linux --config=short_logs
# TODO(gunan): Check why we need this specified in rbe, but not in other builds.
build:rbe_linux --linkopt=-lrt
build:rbe_linux --linkopt=-lm
build:rbe_cpu_linux --config=rbe_linux
build:rbe_cpu_linux --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain"
@ -392,6 +393,7 @@ build:rbe_win --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe
# TODO(gunan): Remove once we use MSVC 2019 with latest patches.
build:rbe_win --define=override_eigen_strong_inline=true
build:rbe_win --jobs=500
build:rbe_win_py37 --config=rbe
build:rbe_win_py37 --repo_env=PYTHON_BIN_PATH=C:\\Python37\\python.exe
@ -399,6 +401,12 @@ build:rbe_win_py37 --repo_env=PYTHON_LIB_PATH=C:\\Python37\\lib\\site-packages
build:rbe_win_py37 --repo_env=TF_PYTHON_CONFIG_REPO=@org_tensorflow//third_party/toolchains/preconfig/win_1803/py37
build:rbe_win_py37 --python_path=C:\\Python37\\python.exe
build:rbe_win_py38 --config=rbe
build:rbe_win_py38 --repo_env=PYTHON_BIN_PATH=C:\\Python38\\python.exe
build:rbe_win_py38 --repo_env=PYTHON_LIB_PATH=C:\\Python38\\lib\\site-packages
build:rbe_win_py38 --repo_env=TF_PYTHON_CONFIG_REPO=@org_tensorflow//third_party/toolchains/preconfig/win_1803/py38
build:rbe_win_py38 --python_path=C:\\Python38\\python.exe
# These you may need to change for your own GCP project.
build:tensorflow_testing_rbe --project_id=tensorflow-testing
common:tensorflow_testing_rbe_linux --remote_instance_name=projects/tensorflow-testing/instances/default_instance

View File

@ -1 +1 @@
1.1.0
1.2.1

View File

@ -29,20 +29,6 @@ to
[announce@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/announce).
See all the [mailing lists](https://www.tensorflow.org/community/forums).
## Feature Prioritization Survey
The TensorFlow team is working on building/improving features, and understands
that it is very important to prioritize these efforts based on what TF users
need.
The goal of this short, < 5 minute
[survey](https://google.qualtrics.com/jfe/form/SV_d5nqhCEbkDkQ7ad), is to help
the TensorFlow team better understand what features to prioritize based on your
feedback. Participation is of course optional.
Take the survey
[HERE](https://google.qualtrics.com/jfe/form/SV_d5nqhCEbkDkQ7ad).
## Install
See the [TensorFlow install guide](https://www.tensorflow.org/install) for the
@ -164,4 +150,3 @@ Learn more about the
## License
[Apache License 2.0](LICENSE)

File diff suppressed because one or more lines are too long

View File

@ -245,4 +245,4 @@ v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc=
### Known Vulnerabilities
For a list of known vulnerabilities and security advisories for TensorFlow,
[click here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/index.md).
[click here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/README.md).

View File

@ -1,11 +1,13 @@
workspace(name = "org_tensorflow")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_file")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("//third_party:repo.bzl", "tf_http_archive")
http_archive(
tf_http_archive(
name = "io_bazel_rules_closure",
sha256 = "5b00383d08dd71f28503736db0500b6fb4dda47489ff5fc6bed42557c07c6ba9",
strip_prefix = "rules_closure-308b05b2419edb5c8ee0471b67a40403df940149",
patch_file = "@org_tensorflow//third_party:rules_closure.patch",
urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz",
"https://github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz", # 2019-06-13
@ -48,38 +50,6 @@ load("//third_party/toolchains/preconfig/generate:workspace.bzl",
remote_config_workspace()
# Apple and Swift rules.
http_archive(
name = "build_bazel_rules_apple",
sha256 = "a045a436b642c70fb0c10ca84ff0fd2dcbd59cc89100d597a61e8374afafb366",
urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.18.0/rules_apple.0.18.0.tar.gz"],
) # https://github.com/bazelbuild/rules_apple/releases
http_archive(
name = "build_bazel_rules_swift",
sha256 = "18cd4df4e410b0439a4935f9ca035bd979993d42372ba79e7f2d4fafe9596ef0",
urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.12.1/rules_swift.0.12.1.tar.gz"],
) # https://github.com/bazelbuild/rules_swift/releases
http_archive(
name = "build_bazel_apple_support",
sha256 = "122ebf7fe7d1c8e938af6aeaee0efe788a3a2449ece5a8d6a428cb18d6f88033",
urls = ["https://github.com/bazelbuild/apple_support/releases/download/0.7.1/apple_support.0.7.1.tar.gz"],
) # https://github.com/bazelbuild/apple_support/releases
http_archive(
name = "bazel_skylib",
sha256 = "1dde365491125a3db70731e25658dfdd3bc5dbdfd11b840b3e987ecf043c7ca0",
urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/0.9.0/bazel-skylib.0.9.0.tar.gz"],
) # https://github.com/bazelbuild/bazel-skylib/releases
http_archive(
name = "com_github_apple_swift_swift_protobuf",
type = "zip",
strip_prefix = "swift-protobuf-1.6.0/",
urls = ["https://github.com/apple/swift-protobuf/archive/1.6.0.zip"],
) # https://github.com/apple/swift-protobuf/releases
http_file(
name = "xctestrunner",
executable = 1,
urls = ["https://github.com/google/xctestrunner/releases/download/0.2.9/ios_test_runner.par"],
) # https://github.com/google/xctestrunner/releases
# Use `swift_rules_dependencies` to fetch the toolchains. With the
# `git_repository` rules above, the following call will skip redefining them.
load("@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies")

View File

@ -49,8 +49,8 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
_TF_WORKSPACE_ROOT = ''
_TF_BAZELRC = ''
_TF_CURRENT_BAZEL_VERSION = None
_TF_MIN_BAZEL_VERSION = '1.0.0'
_TF_MAX_BAZEL_VERSION = '1.1.0'
_TF_MIN_BAZEL_VERSION = '1.2.1'
_TF_MAX_BAZEL_VERSION = '1.2.1'
NCCL_LIB_PATHS = [
'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''
@ -1221,7 +1221,7 @@ def is_reduced_optimize_huge_functions_available(environ_cp):
only, as of 2019-11-19). TensorFlow needs this flag to massively reduce
compile times, but until 16.4 is officially released, we can't depend on it.
See also https://groups.google.com/a/tensorflow.org/g/build/c/SsW98Eo7l3o
See also https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
Because it's very annoying to check this manually (to check the MSVC installed
versions, you need to use the registry, and it's not clear if Bazel will be

View File

@ -2,6 +2,7 @@
# TensorFlow is a computational framework, primarily for use in machine
# learning applications.
load("@bazel_skylib//lib:selects.bzl", "selects")
load("//tensorflow:tensorflow.bzl", "VERSION", "tf_cc_shared_object", "tf_custom_op_library_additional_deps_impl", "tf_native_cc_binary")
load(
"//tensorflow/core/platform:build_config.bzl",
@ -478,6 +479,7 @@ bzl_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core/platform:build_config_root_bzl",
"//tensorflow/core/platform:rules_cc_bzl",
"//tensorflow/core/platform/default:cuda_build_defs_bzl",
"//third_party/mkl:build_defs_bzl",
"//third_party/mkl_dnn:build_defs_bzl",

View File

@ -23,10 +23,6 @@ from __future__ import print_function
# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
from tensorflow.python.util.lazy_loader import LazyLoader
contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')
del LazyLoader
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
from tensorflow.python.platform import app # pylint: disable=g-import-not-at-top
app.flags = flags

View File

@ -54,9 +54,10 @@ filegroup(
)
filegroup(
name = "pywrap_eager_hdrs",
name = "pywrap_required_hdrs",
srcs = [
"c_api_internal.h",
"python_api.h",
"tf_status_helper.h",
"tf_status_internal.h",
"tf_tensor_internal.h",
@ -98,6 +99,17 @@ tf_cuda_library(
],
)
filegroup(
name = "pywrap_tf_session_hdrs",
srcs = [
"python_api.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)
cc_library(
name = "tf_attrtype",
hdrs = ["tf_attrtype.h"],
@ -302,6 +314,7 @@ tf_cuda_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:attr_builder",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/platform",
"@com_google_absl//absl/strings",
@ -639,7 +652,7 @@ tf_cuda_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/kernels:ops_testutil",
"//third_party/eigen3",
"@com_google_absl//absl/container:inlined_vector",
],
)

View File

@ -458,7 +458,7 @@ static void TF_Run_Helper(
EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
continue;
}
c_outputs[i] = TF_TensorFromTensor(src, status);
c_outputs[i] = TF_TensorFromTensor(src, &status->status);
if (!status->status.ok()) return;
}
}
@ -1493,7 +1493,7 @@ void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
Tensor t;
status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
if (!status->status.ok()) return;
*value = TF_TensorFromTensor(t, status);
*value = TF_TensorFromTensor(t, &status->status);
}
void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
@ -1504,7 +1504,7 @@ void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
if (!status->status.ok()) return;
const auto len = std::min(max_values, static_cast<int>(ts.size()));
for (int i = 0; i < len; ++i) {
values[i] = TF_TensorFromTensor(ts[i], status);
values[i] = TF_TensorFromTensor(ts[i], &status->status);
}
}
@ -2398,7 +2398,7 @@ unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output,
graph->graph.versions().producer(), &evaluated, &result_tensor);
if (evaluated) {
DCHECK(status->status.ok());
*result = TF_TensorFromTensor(result_tensor, status);
*result = TF_TensorFromTensor(result_tensor, &status->status);
if (!status->status.ok()) evaluated = false;
}
return evaluated;

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
@ -549,7 +550,7 @@ TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(TFE_Op* op,
TF_Status* status) {
TFE_ExecuteOpNotification* n = new TFE_ExecuteOpNotification;
n->thread.reset(op->operation.EagerContext()->TFEnv()->StartThread(
n->thread.reset(op->operation.EagerContext().TFEnv()->StartThread(
tensorflow::ThreadOptions(), "ExecuteOpThread",
[op, retvals, num_retvals, n]() {
TFE_Execute(op, retvals, num_retvals, n->status.get());
@ -634,7 +635,7 @@ TF_Tensor* TF_CheckpointReaderGetTensor(TF_CheckpointReader* reader,
std::unique_ptr<tensorflow::Tensor> tensor;
reader->GetTensor(name, &tensor, status);
if (!status->status.ok()) return nullptr;
return tensorflow::TF_TensorFromTensor(*tensor, status);
return tensorflow::TF_TensorFromTensor(*tensor, &status->status);
}
void TF_CheckpointReaderGetVariableShape(TF_CheckpointReader* reader,
@ -767,8 +768,9 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
} while (0);
// New server created for new server_def. Unused if updating server_def.
tensorflow::EagerContext* context = ctx->context;
tensorflow::GrpcServer* grpc_server =
dynamic_cast<tensorflow::GrpcServer*>(ctx->context->GetServer());
dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
if (grpc_server == nullptr) {
std::unique_ptr<tensorflow::ServerInterface> new_server;
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
@ -779,12 +781,12 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
}
LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
LOG_AND_RETURN_IF_ERROR(ctx->context->StoreCollectiveOpsServer(
LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
std::move(new_server), grpc_server->worker_env()->device_mgr,
grpc_server->worker_env()->collective_executor_mgr));
} else {
LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
LOG_AND_RETURN_IF_ERROR(ctx->context->StoreCollectiveOpsServer(
LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
/*new_server=*/nullptr, grpc_server->worker_env()->device_mgr,
grpc_server->worker_env()->collective_executor_mgr));
}

View File

@ -1260,11 +1260,10 @@ TEST_F(CApiFunctionTest, GraphToFunctionDefWithPlaceholderAttr) {
NodeWithPlaceholderAttrHelper(func_graph.get(), s.get(), "node3", "v2",
&node3);
TF_Output inputs[] = {};
TF_Output outputs[] = {{node1, 0}, {node2, 0}, {node3, 0}};
func_ = TF_GraphToFunction(
func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
/*opers=*/nullptr, 0, inputs, 3, outputs,
/*opers=*/nullptr, 0, nullptr, 3, outputs,
/*output_names=*/nullptr,
/*opts=*/nullptr, /*description=*/nullptr, s.get());
ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
@ -1300,10 +1299,9 @@ TEST_F(CApiFunctionTest, GraphToFunctionDefWithArgAttr) {
&node);
TF_Output inputs[] = {{node, 0}};
TF_Output outputs[] = {};
func_ = TF_GraphToFunction(
func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
/*opers=*/nullptr, 1, inputs, 0, outputs,
/*opers=*/nullptr, 1, inputs, 0, nullptr,
/*output_names=*/nullptr,
/*opts=*/nullptr, /*description=*/nullptr, s.get());
ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
@ -1603,11 +1601,10 @@ void DefineStatefulFunction(const char* name, TF_Function** func) {
TF_Operation* random =
RandomUniform(shape, TF_FLOAT, func_graph.get(), s.get());
TF_Output inputs[] = {};
TF_Output outputs[] = {{random, 0}};
*func = TF_GraphToFunction(func_graph.get(), name,
/*append_hash_to_fn_name=*/false, -1,
/*opers=*/nullptr, 0, inputs, 1, outputs,
/*opers=*/nullptr, 0, nullptr, 1, outputs,
/*output_names=*/nullptr,
/*opts=*/nullptr, "", s.get());
ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());

View File

@ -188,7 +188,7 @@ namespace tensorflow {
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
TF_Buffer* out);

View File

@ -51,7 +51,7 @@ limitations under the License.
#include "tensorflow/core/util/equal_graph_def.h"
namespace tensorflow {
TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
namespace {
@ -227,7 +227,7 @@ TEST(CAPI, LibraryLoadFunctions) {
void TestEncodeDecode(int line, const std::vector<string>& data) {
const tensorflow::int64 n = data.size();
TF_Status* status = TF_NewStatus();
Status status;
for (const std::vector<tensorflow::int64>& dims :
std::vector<std::vector<tensorflow::int64>>{
{n}, {1, n}, {n, 1}, {n / 2, 2}}) {
@ -236,8 +236,8 @@ void TestEncodeDecode(int line, const std::vector<string>& data) {
for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) {
src.flat<tstring>()(i) = data[i];
}
TF_Tensor* dst = TF_TensorFromTensor(src, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_Tensor* dst = TF_TensorFromTensor(src, &status);
ASSERT_TRUE(status.ok()) << status.error_message();
// Convert back to a C++ Tensor and ensure we get expected output.
Tensor output;
@ -249,7 +249,6 @@ void TestEncodeDecode(int line, const std::vector<string>& data) {
TF_DeleteTensor(dst);
}
TF_DeleteStatus(status);
}
TEST(CAPI, TensorEncodeDecodeStrings) {
@ -1394,8 +1393,9 @@ TEST(CAPI, SavedModel) {
TF_Operation* input_op =
TF_GraphOperationByName(graph, input_op_name.c_str());
ASSERT_TRUE(input_op != nullptr);
csession.SetInputs({{input_op, TF_TensorFromTensor(input, s)}});
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
Status status;
csession.SetInputs({{input_op, TF_TensorFromTensor(input, &status)}});
ASSERT_TRUE(status.ok()) << status.error_message();
const tensorflow::string output_op_name(
tensorflow::ParseTensorName(output_name).first);
@ -2522,12 +2522,11 @@ TEST(CAPI, TestTensorIsNotAligned) {
// Take an unaligned slice.
Tensor y = x.Slice(1, 13);
TF_Status* status = TF_NewStatus();
TF_Tensor* a = TF_TensorFromTensor(y, status);
Status status;
TF_Tensor* a = TF_TensorFromTensor(y, &status);
if (EIGEN_MAX_ALIGN_BYTES > 0) {
EXPECT_FALSE(TF_TensorIsAligned(a));
}
TF_DeleteStatus(status);
TF_DeleteTensor(a);
}

View File

@ -17,7 +17,7 @@ limitations under the License.
#include <memory.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/time.h>
#include <time.h>
#include <unistd.h>
#include "tensorflow/c/c_api.h"
@ -58,12 +58,8 @@ int main(int argc, char** argv) {
}
char file_name[100];
struct timeval t;
if (gettimeofday(&t, NULL)) {
perror("gettimeofday failed");
return 1;
}
snprintf(file_name, sizeof(file_name), "test-%d-%ld.txt", getpid(), t.tv_sec);
time_t t = time(NULL);
snprintf(file_name, sizeof(file_name), "test-%d-%ld.txt", getpid(), t);
size_t length = 2 + strlen(path) + strlen(file_name);
char* full_path = malloc(length);

View File

@ -26,8 +26,8 @@ tf_cuda_library(
"c_api.cc",
"c_api_debug.cc",
"c_api_experimental.h",
"c_api_internal.cc",
"c_api_internal.h",
"tensor_handle_interface.h",
],
hdrs = ["c_api.h"],
copts = tf_copts() + tfe_xla_copts(),
@ -89,10 +89,11 @@ tf_cuda_library(
)
filegroup(
name = "pywrap_eager_hdrs",
name = "pywrap_required_hdrs",
srcs = [
"c_api_experimental.h",
"c_api_internal.h",
"tensor_handle_interface.h",
],
visibility = [
"//tensorflow/core:__pkg__",
@ -102,7 +103,10 @@ filegroup(
tf_cuda_library(
name = "c_api_internal",
srcs = ["c_api_experimental.h"],
srcs = [
"c_api_experimental.h",
"tensor_handle_interface.h",
],
hdrs = ["c_api_internal.h"],
visibility = [
"//learning/deepmind/courier:__subpackages__",
@ -125,18 +129,6 @@ tf_cuda_library(
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/distributed_runtime:remote_device",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime/eager:eager_client",
"//tensorflow/core/distributed_runtime/eager:remote_tensor_handle",
"//tensorflow/core/distributed_runtime/rpc:grpc_channel",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
"//tensorflow/core/profiler/lib:profiler_lib",
"//tensorflow/core/profiler/lib:profiler_session",
],
)

View File

@ -31,6 +31,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
@ -43,6 +44,7 @@ limitations under the License.
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/platform.h" // NOLINT
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/device_filters.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@ -81,6 +83,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
@ -93,10 +96,8 @@ using tensorflow::string;
namespace {
const tensorflow::OpDef* GetOpDef(TFE_Op* op, TF_Status* status) {
if (op->inference_ctx) {
return op->inference_ctx->op_def;
}
const tensorflow::OpDef* op_def;
const tensorflow::OpDef* op_def = op->operation.OpDef();
if (op_def) return op_def;
status->status =
tensorflow::OpDefForOp(op->operation.Name().c_str(), &op_def);
return op_def;
@ -265,9 +266,9 @@ tensorflow::Status GetReplacedFromExistingWorkers(
}
tensorflow::Status CreateRemoteContexts(
const std::vector<string>& remote_workers, tensorflow::uint64 context_id,
tensorflow::uint64 context_view_id, int keep_alive_secs,
const tensorflow::ServerDef& server_def,
TFE_Context* ctx, const std::vector<string>& remote_workers,
tensorflow::uint64 context_id, tensorflow::uint64 context_view_id,
int keep_alive_secs, const tensorflow::ServerDef& server_def,
tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
const bool lazy_copy_remote_function_inputs,
const tensorflow::eager::CreateContextRequest& base_request) {
@ -296,7 +297,7 @@ tensorflow::Status CreateRemoteContexts(
continue;
}
tensorflow::eager::CreateContextRequest request(base_request);
tensorflow::eager::CreateContextRequest request;
tensorflow::eager::CreateContextResponse* response =
new tensorflow::eager::CreateContextResponse();
request.set_context_id(context_id);
@ -304,6 +305,21 @@ tensorflow::Status CreateRemoteContexts(
*request.mutable_server_def() = server_def;
request.mutable_server_def()->set_job_name(parsed_name.job);
request.mutable_server_def()->set_task_index(parsed_name.task);
request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
server_def.default_session_config());
std::vector<bool> filtered_device_mask;
ctx->context->FilterDevicesForRemoteWorkers(
remote_worker, base_request.cluster_device_attributes(),
&filtered_device_mask);
DCHECK_EQ(filtered_device_mask.size(),
base_request.cluster_device_attributes_size());
for (int i = 0; i < filtered_device_mask.size(); i++) {
if (filtered_device_mask[i]) {
const auto& da = base_request.cluster_device_attributes(i);
*request.add_cluster_device_attributes() = da;
}
}
request.set_async(async);
request.set_keep_alive_secs(keep_alive_secs);
request.set_lazy_copy_remote_function_inputs(
@ -325,13 +341,34 @@ tensorflow::Status CreateRemoteContexts(
}
tensorflow::Status UpdateRemoteContexts(
const std::vector<string>& remote_workers, tensorflow::uint64 context_id,
TFE_Context* ctx, const std::vector<string>& remote_workers,
const std::vector<string>& added_workers,
const std::vector<string>& removed_workers, tensorflow::uint64 context_id,
tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
tensorflow::eager::EagerClientCache* remote_eager_workers,
const tensorflow::eager::CreateContextRequest& base_request) {
int num_remote_workers = remote_workers.size();
tensorflow::BlockingCounter counter(num_remote_workers);
std::vector<tensorflow::Status> statuses(num_remote_workers);
int cluster_device_count = base_request.cluster_device_attributes_size();
std::unordered_set<string> added_or_removed(added_workers.begin(),
added_workers.end());
std::copy(removed_workers.begin(), removed_workers.end(),
std::inserter(added_or_removed, added_or_removed.end()));
// Whether each device is in the updated (added or removed) workers
std::vector<bool> device_added_or_removed(cluster_device_count);
for (int i = 0; i < base_request.cluster_device_attributes_size(); i++) {
const auto& da = base_request.cluster_device_attributes().at(i);
tensorflow::DeviceNameUtils::ParsedName pn;
tensorflow::DeviceNameUtils::ParseFullName(da.name(), &pn);
string task_name;
tensorflow::DeviceNameUtils::GetTaskName(pn, &task_name);
if (added_or_removed.find(task_name) != added_or_removed.end()) {
device_added_or_removed[i] = true;
}
}
for (int i = 0; i < num_remote_workers; i++) {
const string& remote_worker = remote_workers[i];
tensorflow::DeviceNameUtils::ParsedName parsed_name;
@ -354,17 +391,42 @@ tensorflow::Status UpdateRemoteContexts(
continue;
}
std::vector<bool> filtered_device_mask;
ctx->context->FilterDevicesForRemoteWorkers(
remote_worker, base_request.cluster_device_attributes(),
&filtered_device_mask);
DCHECK_EQ(filtered_device_mask.size(), cluster_device_count);
// If any of the devices that match the device filters are in the set of
// added or removed workers, we must send a complete UpdateContextRequest.
// Otherwise, only send a simple request to increment context view ID.
std::vector<bool> added_or_removed_filtered_devices(cluster_device_count);
std::transform(device_added_or_removed.begin(),
device_added_or_removed.end(), filtered_device_mask.begin(),
added_or_removed_filtered_devices.begin(),
std::logical_and<bool>());
const bool full_update_request =
std::accumulate(added_or_removed_filtered_devices.begin(),
added_or_removed_filtered_devices.end(), false,
std::logical_or<bool>());
tensorflow::eager::UpdateContextRequest request;
auto* response = new tensorflow::eager::UpdateContextResponse();
*request.mutable_server_def() = server_def;
request.mutable_server_def()->set_job_name(parsed_name.job);
request.mutable_server_def()->set_task_index(parsed_name.task);
for (const auto& da : base_request.cluster_device_attributes()) {
*request.add_cluster_device_attributes() = da;
}
request.set_context_id(context_id);
request.set_context_view_id(context_view_id);
if (full_update_request) {
*request.mutable_server_def() = server_def;
request.mutable_server_def()->set_job_name(parsed_name.job);
request.mutable_server_def()->set_task_index(parsed_name.task);
request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
server_def.default_session_config());
for (int i = 0; i < cluster_device_count; i++) {
if (filtered_device_mask[i]) {
const auto& da = base_request.cluster_device_attributes(i);
*request.add_cluster_device_attributes() = da;
}
}
}
eager_client->UpdateContextAsync(
&request, response,
@ -409,6 +471,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
// New server created for new server_def. Unused if updating server_def.
std::unique_ptr<tensorflow::ServerInterface> new_server;
tensorflow::EagerContext* context = ctx->context;
tensorflow::GrpcServer* grpc_server;
if (reset_context) {
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
@ -416,26 +479,25 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
LOG_AND_RETURN_IF_ERROR(
ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
} else {
LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(
ctx->context->GetServer(), worker_name, &curr_remote_workers));
LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name,
&curr_remote_workers));
// No need to check the cast here, since `ListRemoteWorkers` already checks
// if the server is a GRPC server or not.
grpc_server =
dynamic_cast<tensorflow::GrpcServer*>(ctx->context->GetServer());
grpc_server = dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
LOG_AND_RETURN_IF_ERROR(
ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
}
tensorflow::uint64 context_id = ctx->context->GetContextId();
tensorflow::uint64 context_view_id = ctx->context->GetContextViewId();
tensorflow::uint64 context_id = context->GetContextId();
tensorflow::uint64 context_view_id = context->GetContextViewId();
if (reset_context) {
context_id = tensorflow::EagerContext::NewContextId();
context_view_id = 0;
// Make master eager context accessible by local eager service, which might
// receive send tensor requests from remote workers.
LOG_AND_RETURN_IF_ERROR(grpc_server->AddMasterEagerContextToEagerService(
context_id, ctx->context));
LOG_AND_RETURN_IF_ERROR(
grpc_server->AddMasterEagerContextToEagerService(context_id, context));
}
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
@ -464,11 +526,11 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
&new_remote_device_mgr));
remote_device_mgr = new_remote_device_mgr.get();
} else {
ctx->context->ClearCaches();
context->ClearCachesAndDefaultExecutor();
// TODO(b/143914772): Potential memory leak if rendezvous has pending
// tensors for removed / replaced workers.
remote_device_mgr = ctx->context->GetOwnedRemoteDeviceMgr();
remote_device_mgr = context->GetOwnedRemoteDeviceMgr();
if (remote_device_mgr == nullptr) {
LOG_AND_RETURN_IF_ERROR(tensorflow::errors::InvalidArgument(
"Updating context with an invalid set of remote devices."));
@ -479,8 +541,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
&added_workers, &removed_workers,
&existing_workers);
LOG_AND_RETURN_IF_ERROR(GetReplacedFromExistingWorkers(
&existing_workers, context_id, ctx->context->GetContextViewId(),
server_def, remote_eager_workers.get(), &replaced_workers));
&existing_workers, context_id, context->GetContextViewId(), server_def,
remote_eager_workers.get(), &replaced_workers));
if (VLOG_IS_ON(1)) {
VLOG(1) << "Updating cluster with following changes";
for (const string& w : added_workers) VLOG(1) << " Added worker " << w;
@ -516,7 +578,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
grpc_server->worker_env()->device_mgr->ListDeviceAttributes(
&local_device_attributes);
// This request make sure that we can create Rendevzous properly between
// This request make sure that we can create Rendezvous properly between
// Local and Remote context.
tensorflow::eager::CreateContextRequest base_request;
for (const auto& da : cluster_device_attributes) {
@ -525,18 +587,14 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
for (const auto& da : local_device_attributes) {
*base_request.add_cluster_device_attributes() = da;
}
base_request.mutable_server_def()
->mutable_default_session_config()
->MergeFrom(server_def.default_session_config());
// Initialize remote eager workers.
// TODO(b/138847548) Create remote eager contexts in async mode by default.
if (reset_context) {
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
remote_workers, context_id, context_view_id, keep_alive_secs,
server_def, remote_eager_workers.get(),
ctx->context->Executor().Async(),
ctx->context->LazyCopyFunctionRemoteInputs(), base_request));
ctx, remote_workers, context_id, context_view_id, keep_alive_secs,
server_def, remote_eager_workers.get(), context->Executor().Async(),
context->LazyCopyFunctionRemoteInputs(), base_request));
} else {
// The master's context_view_id will be incremented by one
// the UpdateRemoteMaster call later. We want all new workers and
@ -544,10 +602,9 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
// we must set their context_view_id to the existing master's
// context_view_id + 1.
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
added_workers, context_id, context_view_id + 1, keep_alive_secs,
server_def, remote_eager_workers.get(),
ctx->context->Executor().Async(),
ctx->context->LazyCopyFunctionRemoteInputs(), base_request));
ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs,
server_def, remote_eager_workers.get(), context->Executor().Async(),
context->LazyCopyFunctionRemoteInputs(), base_request));
if (!existing_workers.empty()) {
if (VLOG_IS_ON(1)) {
for (const string& w : existing_workers) {
@ -555,8 +612,9 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
}
}
LOG_AND_RETURN_IF_ERROR(UpdateRemoteContexts(
existing_workers, context_id, context_view_id + 1, server_def,
remote_eager_workers.get(), base_request));
ctx, existing_workers, added_workers, removed_workers, context_id,
context_view_id + 1, server_def, remote_eager_workers.get(),
base_request));
}
}
@ -578,12 +636,12 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
tensorflow::eager::CreateClusterFLR(context_id, ctx->context,
tensorflow::eager::CreateClusterFLR(context_id, context,
worker_session.get());
auto remote_mgr = absl::make_unique<tensorflow::eager::RemoteMgr>(
/*is_master=*/true, ctx->context);
/*is_master=*/true, context);
LOG_AND_RETURN_IF_ERROR(ctx->context->InitializeRemoteMaster(
LOG_AND_RETURN_IF_ERROR(context->InitializeRemoteMaster(
std::move(new_server), grpc_server->worker_env(), worker_session,
std::move(remote_eager_workers), std::move(new_remote_device_mgr),
remote_workers, context_id, r, device_mgr, keep_alive_secs, cluster_flr,
@ -601,9 +659,9 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
session_name, &worker_session));
tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
tensorflow::eager::CreateClusterFLR(context_id, ctx->context,
tensorflow::eager::CreateClusterFLR(context_id, context,
worker_session.get());
LOG_AND_RETURN_IF_ERROR(ctx->context->UpdateRemoteMaster(
LOG_AND_RETURN_IF_ERROR(context->UpdateRemoteMaster(
grpc_server->worker_env(), std::move(remote_eager_workers),
added_workers, removed_workers, context_id, r, device_mgr,
keep_alive_secs, cluster_flr));
@ -614,77 +672,6 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
}
#endif // !IS_MOBILE_PLATFORM
tensorflow::Status OpInferSingleInputAttrs(TFE_Op* op,
TFE_TensorHandle* input) {
TFE_OpInferenceContext* ictx = op->inference_ctx.get();
const auto& input_def = ictx->op_def->input_arg(ictx->input_arg_idx++);
if (!input_def.number_attr().empty() || !input_def.type_list_attr().empty()) {
// Some clients that are still setting their input attributes manually are
// adding input list to their op by calling `TFE_OpAddInput` for each of
// its elements instead of calling `TFE_OpAddInputList`. When this happens,
// we cannot detect the end of such list, thus lose track of the input
// arguments in the op definition. To guarantee backward compatibility with
// those clients, disable automatic inference in this case.
op->inference_ctx.reset(nullptr);
return tensorflow::Status::OK();
}
const std::string& type_attr = input_def.type_attr();
if (!type_attr.empty() && ictx->attrs.find(type_attr) == ictx->attrs.end()) {
op->operation.MutableAttrs()->Set(type_attr, input->handle->dtype);
ictx->attrs.insert(type_attr);
}
return tensorflow::Status::OK();
}
void OpInferSingleTypeInputListAttrs(TFE_Op* op,
const tensorflow::OpDef::ArgDef& input_def,
TFE_TensorHandle** inputs,
int num_inputs) {
TFE_OpInferenceContext* ictx = op->inference_ctx.get();
if (ictx->attrs.find(input_def.number_attr()) == ictx->attrs.end()) {
op->operation.MutableAttrs()->Set(input_def.number_attr(), num_inputs);
ictx->attrs.insert(input_def.number_attr());
}
if (ictx->attrs.find(input_def.type_attr()) == ictx->attrs.end()) {
op->operation.MutableAttrs()->Set(input_def.type_attr(),
inputs[0]->handle->dtype);
ictx->attrs.insert(input_def.type_attr());
}
}
void OpInferMixedTypeInputListAttrs(TFE_Op* op,
const tensorflow::OpDef::ArgDef& input_def,
TFE_TensorHandle** inputs, int num_inputs) {
TFE_OpInferenceContext* ictx = op->inference_ctx.get();
if (ictx->attrs.find(input_def.type_list_attr()) == ictx->attrs.end()) {
std::unique_ptr<tensorflow::DataType[]> dtypes(
new tensorflow::DataType[num_inputs]);
for (int i = 0; i < num_inputs; ++i) {
dtypes[i] = inputs[i]->handle->dtype;
}
op->operation.MutableAttrs()->Set(
input_def.type_list_attr(),
tensorflow::gtl::ArraySlice<const tensorflow::DataType>(dtypes.get(),
num_inputs));
ictx->attrs.insert(input_def.type_list_attr());
}
}
tensorflow::Status OpInferInputListAttrs(TFE_Op* op, TFE_TensorHandle** inputs,
int num_inputs) {
TFE_OpInferenceContext* ictx = op->inference_ctx.get();
const auto& input_def = ictx->op_def->input_arg(ictx->input_arg_idx++);
if (!input_def.type_list_attr().empty()) {
OpInferMixedTypeInputListAttrs(op, input_def, inputs, num_inputs);
} else if (!input_def.type_attr().empty() &&
!input_def.number_attr().empty()) {
OpInferSingleTypeInputListAttrs(op, input_def, inputs, num_inputs);
} else {
return tensorflow::errors::InvalidArgument("Invalid input list definition");
}
return tensorflow::Status::OK();
}
} // namespace
extern "C" {
@ -720,12 +707,14 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr.get());
return new TFE_Context(opts->session_options.options,
opts->device_placement_policy, opts->mirroring_policy,
opts->async, opts->lazy_remote_inputs_copy,
device_mgr.release(),
/*device_mgr_owned*/ true, r,
tensorflow::GetDefaultCustomKernelCreator());
return new TFE_Context{new tensorflow::EagerContext(
opts->session_options.options,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
opts->device_placement_policy),
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
/*device_mgr_owned*/ true, r,
tensorflow::GetDefaultCustomKernelCreator())};
}
TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
@ -736,25 +725,33 @@ TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr);
return new TFE_Context(opts->session_options.options,
opts->device_placement_policy, opts->mirroring_policy,
opts->async, opts->lazy_remote_inputs_copy, device_mgr,
/*device_mgr_owned*/ false, r,
tensorflow::GetDefaultCustomKernelCreator());
return new TFE_Context{new tensorflow::EagerContext(
opts->session_options.options,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
opts->device_placement_policy),
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
opts->async, opts->lazy_remote_inputs_copy, device_mgr,
/*device_mgr_owned*/ false, r,
tensorflow::GetDefaultCustomKernelCreator())};
}
void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; }
void TFE_DeleteContext(TFE_Context* ctx) {
// context->RefCountIsOne() should be true here.
// TODO(iga): Remove EagerContext refcounting.
ctx->context->Unref();
delete ctx;
}
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
TF_DeviceList* list = new TF_DeviceList;
ctx->context->local_device_mgr()->ListDeviceAttributes(&list->response);
if (ctx->context->remote_device_mgr()) {
ctx->context->remote_device_mgr()->ListDeviceAttributes(&list->response);
}
return list;
TF_DeviceList* l = new TF_DeviceList;
ctx->context->ListDevices(&l->response);
return l;
}
void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context->ClearCaches(); }
void TFE_ContextClearCaches(TFE_Context* ctx) {
ctx->context->ClearCachesAndThreadExecutors();
}
// Set server_def on the context, possibly updating it.
TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
@ -772,6 +769,22 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
"Invalid tensorflow.ServerDef protocol buffer");
return;
}
if (server_def.has_cluster_device_filters()) {
const auto& cdf = server_def.cluster_device_filters();
for (const auto& jdf : cdf.jobs()) {
const string& remote_prefix = "/job:" + jdf.name() + "/task:";
for (const auto& tdf : jdf.tasks()) {
const int32_t task_index = tdf.first;
std::vector<string> device_filters(tdf.second.device_filters_size());
for (int i = 0; i < tdf.second.device_filters_size(); i++) {
device_filters[i] = tdf.second.device_filters(i);
}
const string remote_worker = remote_prefix + std::to_string(task_index);
status->status =
ctx->context->SetRemoteDeviceFilters(remote_worker, device_filters);
}
}
}
status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
ctx, /*reset_context=*/true);
#endif // !IS_MOBILE_PLATFORM
@ -796,6 +809,11 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
status->status = tensorflow::errors::InvalidArgument(
"Trying to update a context with invalid context id.");
}
if (server_def.has_cluster_device_filters()) {
LOG(WARNING) << "Device filters can only be specified when initializing "
"the cluster. Any changes in device filters are ignored "
"when updating the server def.";
}
// TODO(haoyuzhang): Check server_def compatibility before the update
status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
ctx, /*reset_context=*/false);
@ -810,8 +828,9 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
"TFE_ContextSetServerDef not supported on mobile");
return false;
#else // !defined(IS_MOBILE_PLATFORM)
tensorflow::EagerContext* context = ctx->context;
tensorflow::GrpcServer* grpc_server =
static_cast<tensorflow::GrpcServer*>(ctx->context->GetServer());
static_cast<tensorflow::GrpcServer*>(context->GetServer());
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
status->status = grpc_server->master_env()->worker_cache->GetEagerClientCache(
@ -830,7 +849,7 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
// Send a rpc request to the worker to check aliveness.
tensorflow::eager::KeepAliveRequest request;
request.set_context_id(ctx->context->GetContextId());
request.set_context_id(context->GetContextId());
tensorflow::eager::KeepAliveResponse response;
tensorflow::Status keep_alive_status;
@ -885,108 +904,180 @@ void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
if (h == nullptr) return;
tensorflow::profiler::TraceMe activity(
"TFE_DeleteTensorHandle", tensorflow::profiler::TraceMeLevel::kInfo);
VLOG(1) << "Deleting tensor handle " << h << " with internal handle "
<< h->handle;
if (h->handle) {
h->handle->Unref();
}
delete h;
}
tensorflow::TensorHandleInterface::~TensorHandleInterface() {
VLOG(1) << "Deleting tensor handle " << this << " with internal handle "
<< handle_;
if (handle_) {
handle_->Unref();
}
}
bool tensorflow::TensorHandleInterface::IsValid(Status* status) const {
if (handle_ == nullptr) {
*status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return false;
}
return true;
}
TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
return static_cast<TF_DataType>(h->handle->dtype);
return h->handle->DataType();
}
TF_DataType tensorflow::TensorHandleInterface::DataType() const {
return static_cast<TF_DataType>(handle_->dtype);
}
int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return -1;
}
return h->handle->NumDims(&status->status);
}
int tensorflow::TensorHandleInterface::NumDims(Status* status) const {
if (!IsValid(status)) {
return -1;
}
int result;
status->status = h->handle->NumDims(&result);
*status = handle_->NumDims(&result);
return result;
}
int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return -1;
}
return h->handle->NumElements(&status->status);
}
int64_t tensorflow::TensorHandleInterface::NumElements(Status* status) const {
if (!IsValid(status)) {
return -1;
}
tensorflow::int64 result;
status->status = h->handle->NumElements(&result);
*status = handle_->NumElements(&result);
return result;
}
int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return -1;
}
return h->handle->Dim(dim_index, &status->status);
}
int64_t tensorflow::TensorHandleInterface::Dim(int dim_index,
Status* status) const {
if (!IsValid(status)) {
return -1;
}
tensorflow::int64 result;
status->status = h->handle->Dim(dim_index, &result);
*status = handle_->Dim(dim_index, &result);
return result;
}
const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
tensorflow::Device* d = h->handle->op_device();
return h->handle->DeviceName(&status->status);
}
const char* tensorflow::TensorHandleInterface::DeviceName(
Status* status) const {
if (!IsValid(status)) {
return nullptr;
}
tensorflow::Device* d = handle_->op_device();
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
: d->name().c_str();
}
const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
tensorflow::Device* d = h->handle->device();
return h->handle->BackingDeviceName(&status->status);
}
const char* tensorflow::TensorHandleInterface::BackingDeviceName(
Status* status) const {
if (!IsValid(status)) {
return nullptr;
}
tensorflow::Device* d = handle_->device();
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
: d->name().c_str();
}
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr || !h->handle->IsValid(&status->status)) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
h->handle->Ref();
return new TFE_TensorHandle{
std::unique_ptr<AbstractTensorHandleInterface>(h->handle->Copy())};
}
return new TFE_TensorHandle(h->handle);
AbstractTensorHandleInterface* tensorflow::TensorHandleInterface::Copy() {
handle_->Ref();
return new TensorHandleInterface(handle_);
}
TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
tensorflow::TensorHandle* handle = h->handle;
return h->handle->Resolve(&status->status);
}
TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
if (!IsValid(status)) {
return nullptr;
}
// TODO(agarwal): move this implementation inside TFE_TensorHandle.
if (handle->IsRemote()) {
if (handle_->IsRemote()) {
const tensorflow::Tensor* t = nullptr;
tensorflow::TensorHandle* h_cpu = nullptr;
status->status = EagerCopyToDevice(
handle, handle->Context(), &handle->Context()->Executor(),
handle->Context()->HostCPU(), false, &h_cpu);
if (!status->status.ok()) {
*status = EagerCopyToDevice(handle_, handle_->Context(),
&handle_->Context()->Executor(),
handle_->Context()->HostCPU(), false, &h_cpu);
if (!status->ok()) {
return nullptr;
}
status->status = h_cpu->Tensor(&t);
if (!status->status.ok()) {
*status = h_cpu->Tensor(&t);
if (!status->ok()) {
h_cpu->Unref();
return nullptr;
}
@ -995,28 +1086,30 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
return retval;
} else {
tensorflow::Tensor tensor;
if (IsCPU(handle->device())) {
if (IsCPU(handle_->device())) {
const tensorflow::Tensor* src = nullptr;
status->status = handle->Tensor(&src);
if (!status->status.ok()) return nullptr;
*status = handle_->Tensor(&src);
if (!status->ok()) return nullptr;
tensor = *src;
} else {
tensorflow::EagerContext* ctx = handle->Context();
tensorflow::EagerContext* ctx = handle_->Context();
CHECK_NE(ctx, nullptr);
status->status = h->handle->CopyToDevice(ctx, ctx->HostCPU(), &tensor);
if (!status->status.ok()) return nullptr;
*status = handle_->CopyToDevice(*ctx, ctx->HostCPU(), &tensor);
if (!status->ok()) return nullptr;
}
return tensorflow::TF_TensorFromTensor(tensor, status);
}
}
void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr || !h->handle->IsValid(&status->status)) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
tensorflow::TensorHandle* handle = h->handle;
tensorflow::TensorHandle* handle =
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle();
if (handle->IsRemote()) {
status->status = tensorflow::errors::InvalidArgument(
@ -1045,7 +1138,8 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
void (*deallocator)(void* data, size_t len, void* arg),
void* deallocator_arg, TF_Status* status) {
tensorflow::Device* device;
status->status = ctx->context->FindDeviceFromName(device_name, &device);
tensorflow::EagerContext* context = ctx->context;
status->status = context->FindDeviceFromName(device_name, &device);
if (!status->status.ok()) {
deallocator(data, len, deallocator_arg);
return nullptr;
@ -1073,11 +1167,12 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
buf->Unref();
tensorflow::TensorHandle* ret_handle;
status->status = tensorflow::TensorHandle::CreateLocalHandle(
t, device, ctx->context, &ret_handle);
t, device, context, &ret_handle);
if (!status->status.ok()) {
return nullptr;
}
return new TFE_TensorHandle(ret_handle);
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(ret_handle)};
}
// This function will block till the operation that produces `h` has
@ -1085,12 +1180,14 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
// bytes of the memory pointed to by the device pointer returned above.
size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr || !h->handle->IsValid(&status->status)) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return 0;
}
tensorflow::TensorHandle* handle = h->handle;
tensorflow::TensorHandle* handle =
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle();
if (handle->IsRemote()) {
status->status = tensorflow::errors::InvalidArgument(
@ -1108,8 +1205,14 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status) {
return NewOrResetOp(ctx, op_or_function_name, nullptr, status,
/* op_to_reset= */ nullptr);
std::unique_ptr<TFE_Op> new_op(
new TFE_Op{tensorflow::EagerOperation(ctx->context)});
status->status =
new_op->operation.Reset(op_or_function_name, nullptr, false, nullptr);
if (!status->status.ok()) {
new_op.reset();
}
return new_op.release();
}
void TFE_DeleteOp(TFE_Op* op) { delete op; }
@ -1120,7 +1223,7 @@ void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
tensorflow::Device* device = (op->operation.Device() == nullptr)
? op->operation.EagerContext()->HostCPU()
? op->operation.EagerContext().HostCPU()
: op->operation.Device();
return device->name().c_str();
}
@ -1134,20 +1237,23 @@ void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
}
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
op->operation.AddInput(input->handle);
if (op->inference_ctx) {
status->status = OpInferSingleInputAttrs(op, input);
}
tensorflow::TensorHandle* h =
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
input->handle.get())
->Handle();
op->operation.AddInput(h);
status->status = op->operation.MaybeInferSingleInputAttrs(h);
}
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
TF_Status* status) {
for (int i = 0; i < num_inputs; ++i) {
op->operation.AddInput(inputs[i]->handle);
}
if (op->inference_ctx) {
status->status = OpInferInputListAttrs(op, inputs, num_inputs);
op->operation.AddInput(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
inputs[i]->handle.get())
->Handle());
}
status->status = op->operation.InferInputListAttrs(num_inputs);
}
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
@ -1380,15 +1486,16 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
VLOG(1) << "Calling TFE_Execute() on op " << op;
absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
VLOG(1) << "Calling TFE_Execute() on op " << op;
status->status = tensorflow::EagerExecute(&op->operation,
handle_retvals.data(), num_retvals);
if (!status->status.ok()) {
return;
}
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = new TFE_TensorHandle(handle_retvals[i]);
retvals[i] = new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle_retvals[i])};
}
}
@ -1398,15 +1505,18 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
TF_Status* status) {
tensorflow::TensorHandle* handle = nullptr;
tensorflow::Device* device;
status->status = ctx->context->FindDeviceFromName(device_name, &device);
tensorflow::EagerContext* context = ctx->context;
status->status = context->FindDeviceFromName(device_name, &device);
if (!status->status.ok()) {
return nullptr;
}
status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context,
&ctx->context->Executor(),
device, false, &handle);
status->status = tensorflow::EagerCopyToDevice(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle(),
context, &context->Executor(), device, false, &handle);
if (status->status.ok()) {
return new TFE_TensorHandle(handle);
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
}
return nullptr;
}
@ -1454,11 +1564,12 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
TF_Status* status) {
status->status = ctx->context->Executor().WaitForAllPendingNodes();
tensorflow::EagerContext* context = ctx->context;
status->status = context->Executor().WaitForAllPendingNodes();
if (!status->status.ok()) return;
tensorflow::mutex_lock ml(*ctx->context->MetadataMu());
status->status = MessageToBuffer(*ctx->context->RunMetadataProto(), buf);
ctx->context->ClearRunMetadata();
tensorflow::mutex_lock ml(*context->MetadataMu());
status->status = MessageToBuffer(*context->RunMetadataProto(), buf);
context->ClearRunMetadata();
}
namespace {

View File

@ -206,14 +206,14 @@ typedef struct TFE_TensorDebugInfo TFE_TensorDebugInfo;
// error and nullptr is returned. This function can block till the operation
// that produces `handle` has completed.
TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
TFE_TensorHandle* handle, TF_Status* status);
TFE_TensorHandle* h, TF_Status* status);
// Deletes `debug_info`.
TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo(
TFE_TensorDebugInfo* debug_info);
// Returns the number of dimensions used to represent the tensor on its device.
// The number of dimensions used to reprensent the tensor on device can be
// The number of dimensions used to represent the tensor on device can be
// different from the number returned by TFE_TensorHandleNumDims.
// The return value was current at the time of TFE_TensorDebugInfo creation.
TF_CAPI_EXPORT extern int TFE_TensorDebugInfoOnDeviceNumDims(

View File

@ -28,19 +28,22 @@ using tensorflow::string;
namespace {
std::vector<int64> TensorShapeAsVector(TFE_TensorHandle* handle,
TF_Status* status) {
std::vector<int64> TensorShapeAsVector(const tensorflow::TensorHandle& handle,
tensorflow::Status* status) {
std::vector<int64> shape;
int rank = TFE_TensorHandleNumDims(handle, status);
if (TF_GetCode(status) != TF_OK) {
int rank = -1;
*status = handle.NumDims(&rank);
if (!status->ok()) {
return shape;
}
shape.reserve(rank);
for (int i = 0; i < rank; ++i) {
shape.push_back(TFE_TensorHandleDim(handle, i, status));
if (TF_GetCode(status) != TF_OK) {
tensorflow::int64 dim;
*status = handle.Dim(i, &dim);
if (!status->ok()) {
return shape;
}
shape.push_back(dim);
}
return shape;
}
@ -50,15 +53,20 @@ std::vector<int64> TensorShapeAsVector(TFE_TensorHandle* handle,
extern "C" {
TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
TFE_TensorHandle* handle, TF_Status* status) {
TFE_TensorHandle* h, TF_Status* status) {
return h->handle->TensorDebugInfo(&status->status);
}
TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
Status* status) {
const tensorflow::Tensor* tensor;
status->status = handle->handle->Tensor(&tensor);
if (TF_GetCode(status) != TF_OK) {
*status = handle_->Tensor(&tensor);
if (!status->ok()) {
return nullptr;
}
#ifdef TENSORFLOW_EAGER_USE_XLA
tensorflow::Device* device = handle->handle->device();
tensorflow::Device* device = handle_->device();
// If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
tensorflow::XlaDevice* xla_device =
@ -67,15 +75,15 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
tensorflow::XlaDevice::PaddedShapeFn shape_fn =
xla_device->metadata().padded_shape_fn();
xla::Shape padded_shape;
status->status = shape_fn(*tensor, &padded_shape);
if (!status->status.ok()) {
*status = shape_fn(*tensor, &padded_shape);
if (!status->ok()) {
return nullptr;
}
if (VLOG_IS_ON(3)) {
std::vector<int64> shape_to_log = TensorShapeAsVector(handle, status);
if (!status->status.ok()) {
std::vector<int64> shape_to_log = TensorShapeAsVector(*handle_, status);
if (!status->ok()) {
// Ignore the status here as we are simply logging.
status->status = tensorflow::Status::OK();
*status = tensorflow::Status::OK();
} else {
VLOG(3) << "Fully padded shape of ["
<< absl::StrJoin(shape_to_log, ", ") << "] is "
@ -88,7 +96,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
// Currently, the only case of XlaTensor containing a tuple shape is to
// represent 64 bit ints, doubles, and complex numbers (we don't support
// 64bit complex numbers).
status->status = tensorflow::errors::InvalidArgument(
*status = tensorflow::errors::InvalidArgument(
"XlaTensors should only contain tuples of size 2. Shape: ",
padded_shape.DebugString());
return nullptr;
@ -100,13 +108,13 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
const xla::Shape& shape1 =
xla::ShapeUtil::GetTupleElementShape(padded_shape, 1);
if (shape0.IsTuple() || shape1.IsTuple()) {
status->status = tensorflow::errors::InvalidArgument(
*status = tensorflow::errors::InvalidArgument(
"XlaTensors should not contain nested tuples. Shape: ",
padded_shape.DebugString());
return nullptr;
}
if (!xla::ShapeUtil::Equal(shape0, shape1)) {
status->status = tensorflow::errors::InvalidArgument(
*status = tensorflow::errors::InvalidArgument(
"Subshapes of XlaTensors should be the same. Shape: ",
padded_shape.DebugString());
return nullptr;
@ -131,15 +139,15 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
dev_dims.push_back(padded_shape.dimensions(dim_index));
}
}
status->status = tensorflow::Status::OK();
*status = tensorflow::Status::OK();
return new TFE_TensorDebugInfo(dev_dims);
}
#endif // TENSORFLOW_EAGER_USE_XLA
// If the tensor is not an XLA tensor, the device shape is
// the same as regular tensor shape.
std::vector<int64> dev_dims = TensorShapeAsVector(handle, status);
if (TF_GetCode(status) != TF_OK) {
std::vector<int64> dev_dims = TensorShapeAsVector(*handle_, status);
if (!status->ok()) {
return nullptr;
}
return new TFE_TensorDebugInfo(dev_dims);

View File

@ -18,22 +18,23 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/profiler/rpc/client/capture_profile.h"
#include "tensorflow/core/profiler/rpc/profiler_server.h"
using tensorflow::string;
void TFE_OpReset(TFE_Context* ctx, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status,
TFE_Op* op_to_reset) {
void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status) {
if (op_to_reset) {
NewOrResetOp(ctx, op_or_function_name, raw_device_name, status,
op_to_reset);
status->status = op_to_reset->operation.Reset(
op_or_function_name, raw_device_name, false, nullptr);
} else {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"op_to_reset should not be nullptr");
@ -41,7 +42,9 @@ void TFE_OpReset(TFE_Context* ctx, const char* op_or_function_name,
}
void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
op->operation.ConsumeInput(h->handle);
op->operation.ConsumeInput(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle());
}
TFE_Profiler* TFE_NewProfiler() { return new TFE_Profiler(); }
@ -85,14 +88,14 @@ bool TFE_ProfilerClientStartTracing(const char* service_addr,
int num_tracing_attempts,
TF_Status* status) {
tensorflow::Status s =
tensorflow::profiler::client::ValidateHostPortPair(service_addr);
tensorflow::profiler::ValidateHostPortPair(service_addr);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return false;
}
s = tensorflow::profiler::client::StartTracing(
service_addr, logdir, worker_list, include_dataset_ops, duration_ms,
num_tracing_attempts);
s = tensorflow::profiler::Trace(service_addr, logdir, worker_list,
include_dataset_ops, duration_ms,
num_tracing_attempts);
tensorflow::Set_TF_Status_from_Status(status, s);
return s.ok();
}
@ -101,14 +104,14 @@ void TFE_ProfilerClientMonitor(const char* service_addr, int duration_ms,
int monitoring_level, bool display_timestamp,
TF_Buffer* result, TF_Status* status) {
tensorflow::Status s =
tensorflow::profiler::client::ValidateHostPortPair(service_addr);
tensorflow::profiler::ValidateHostPortPair(service_addr);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return;
}
string content;
s = tensorflow::profiler::client::Monitor(
service_addr, duration_ms, monitoring_level, display_timestamp, &content);
s = tensorflow::profiler::Monitor(service_addr, duration_ms, monitoring_level,
display_timestamp, &content);
void* data = tensorflow::port::Malloc(content.length());
content.copy(static_cast<char*>(data), content.length(), 0);
result->data = data;
@ -616,3 +619,16 @@ void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
return new TFE_Executor(&ctx->context->Executor());
}
void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
ctx->context->HostCPU()->parsed_name());
auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
void* data = tensorflow::port::Malloc(str.length());
str.copy(static_cast<char*>(data), str.length(), 0);
buf->data = data;
buf->length = str.length();
buf->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
}

View File

@ -29,10 +29,10 @@ extern "C" {
// and set the device name. It's effectively `TFE_OpSetDevice`, but it is faster
// than seperately calling it because if the existing op has the same
// `raw_device_name`, it skips parsing and just leave as it is.
TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Context* ctx,
TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Op* op_to_reset,
const char* op_or_function_name,
const char* raw_device_name,
TF_Status* status, TFE_Op* op_to_reset);
TF_Status* status);
TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,
TF_Status* status);
@ -458,6 +458,11 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
void (*deallocator)(void* data, size_t len, void* arg),
void* deallocator_arg, TF_Status* status);
// Retrieves the address space (i.e. job, replia, task) of the local host and
// saves it in the buffer.
TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
TF_Buffer* buf);
#ifdef __cplusplus
} /* end extern "C" */
#endif

View File

@ -1,66 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/host_info.h"
TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status,
TFE_Op* op_to_reset) {
const char* name = op_or_function_name; // Shorthand
const tensorflow::AttrTypeMap* types;
bool is_function = false;
status->status = tensorflow::AttrTypeMapForOp(name, &types, &is_function);
if (!status->status.ok()) {
return nullptr;
}
if (op_to_reset && op_to_reset->ctx != ctx) {
status->status = tensorflow::errors::Internal(
"Cannot reset a TFE_Op from another TFE_Context");
return nullptr;
}
std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
if (!is_function) {
const tensorflow::OpDef* op_def;
status->status = tensorflow::OpDefForOp(op_or_function_name, &op_def);
if (!status->status.ok()) {
return nullptr;
}
inference_ctx.reset(new TFE_OpInferenceContext(op_def));
} else if (!ctx->context->FindFunctionByName(name)) {
status->status = tensorflow::errors::NotFound(
"'", name,
"' is neither a type of a primitive operation nor a name "
"of a function registered in binary running on ",
tensorflow::port::Hostname(),
". Make sure the operation or function is "
"registered in the binary running in this process.");
return nullptr;
}
if (op_to_reset) {
status->status = op_to_reset->Reset(
name, is_function, types, raw_device_name, std::move(inference_ctx));
return op_to_reset;
}
TFE_Op* new_op =
new TFE_Op(ctx, name, is_function, types, std::move(inference_ctx));
status->status = new_op->operation.SetDeviceName(raw_device_name);
return new_op;
}

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
@ -62,36 +63,10 @@ struct TFE_ContextOptions {
};
struct TFE_Context {
TFE_Context(const tensorflow::SessionOptions& opts,
TFE_ContextDevicePlacementPolicy default_device_placement_policy,
TFE_ContextMirroringPolicy default_mirroring_policy, bool async,
const bool lazy_remote_inputs_copy,
const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned,
tensorflow::Rendezvous* rendezvous,
const tensorflow::CustomKernelCreator* custom_kernel_creator)
: context(new tensorflow::EagerContext(
opts,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
default_device_placement_policy),
static_cast<tensorflow::ContextMirroringPolicy>(
default_mirroring_policy),
async, lazy_remote_inputs_copy, device_mgr, device_mgr_owned,
rendezvous, custom_kernel_creator)) {}
~TFE_Context() {
// TODO(iga): Add a separate API method to shutdown TFE_Context so that we
// don't send RPCs and block in destructor.
context->WaitForAndCloseRemoteContexts();
// context->RefCountIsOne() should be true here.
// TODO(iga): Remove EagerContext refcounting.
context->Unref();
}
tensorflow::EagerContext* context;
};
struct TFE_TensorHandle {
explicit TFE_TensorHandle(tensorflow::TensorHandle* h) : handle(h) {}
static TFE_TensorHandle* CreateLocalHandle(const class tensorflow::Tensor& t,
TF_Status* s) {
tensorflow::TensorHandle* handle;
@ -99,10 +74,11 @@ struct TFE_TensorHandle {
if (!s->status.ok()) {
return nullptr;
}
return new TFE_TensorHandle(handle);
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
}
tensorflow::TensorHandle* handle;
std::unique_ptr<AbstractTensorHandleInterface> handle;
};
struct TFE_TensorDebugInfo {
@ -113,46 +89,10 @@ struct TFE_TensorDebugInfo {
std::vector<tensorflow::int64> dev_dims;
};
struct TFE_OpInferenceContext {
explicit TFE_OpInferenceContext(const tensorflow::OpDef* op_def)
: op_def(op_def) {}
const tensorflow::OpDef* op_def; // op definition from protobuf
int input_arg_idx = 0; // arg definition index for the next input to be added
tensorflow::gtl::FlatSet<std::string> attrs; // attributes inferred so far
};
struct TFE_Op {
TFE_Op(TFE_Context* ctx, const char* op, bool is_function,
const tensorflow::AttrTypeMap* t,
std::unique_ptr<TFE_OpInferenceContext> inference_ctx)
: ctx(ctx),
operation(ctx->context, op, is_function, t),
inference_ctx(std::move(inference_ctx)) {}
void Clear() {
operation.Clear();
inference_ctx.reset();
}
tensorflow::Status Reset(const char* op, bool is_function,
const tensorflow::AttrTypeMap* t,
const char* raw_device_name,
std::unique_ptr<TFE_OpInferenceContext> infer_ctx) {
inference_ctx = std::move(infer_ctx);
return operation.Reset(ctx->context, op, is_function, t, raw_device_name,
nullptr);
}
TFE_Context* ctx;
tensorflow::EagerOperation operation;
std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
};
TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status,
TFE_Op* op_to_reset = nullptr);
struct TFE_Profiler {
explicit TFE_Profiler() { profiler = tensorflow::ProfilerSession::Create(); }

View File

@ -1362,10 +1362,11 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) {
TFE_TensorHandle* inputs[] = {input1, input2};
TFE_OpAddInput(concatOp, dim, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
CHECK(concatOp->inference_ctx);
CHECK(concatOp->operation.OpDef());
TFE_OpAddInput(concatOp, inputs[0], status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_FALSE(concatOp->inference_ctx) << "Inference context is still present";
EXPECT_FALSE(concatOp->operation.OpDef())
<< "Inference context is still present";
TFE_OpAddInput(concatOp, inputs[1], status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

View File

@ -284,7 +284,7 @@ class ForwardAccumulator {
// Temporarily push or pop transient state for this accumulator.
//
// Allows an accumulator which is currently processing an operation to
// temporarily reset its state. Without pushing and poping, accumulators
// temporarily reset its state. Without pushing and popping, accumulators
// ignore operations executed as a direct result of their own jvp
// computations.
void PushState() { call_state_.emplace(nullptr, false); }

View File

@ -0,0 +1,90 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
#define TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
// Abstract interface to a TensorHandle.
//
// A TensorHandle is management class around a Tensor which may track additional
// metadata and synchronization.
//
// This allows us to hide concrete implementations of TensorHandle from header
// files. The interface lists the common functionality that must be provided by
// any concrete implementation. However, in cases where the true concrete class
// is needed a static_cast can be applied.
class AbstractTensorHandleInterface {
public:
virtual ~AbstractTensorHandleInterface() {}
// Check if the handle is in a valid initialized state.
virtual bool IsValid(tensorflow::Status* status) const = 0;
// Returns tensor dtype.
virtual TF_DataType DataType() const = 0;
// Returns number of dimensions.
virtual int NumDims(tensorflow::Status* status) const = 0;
// Returns number of elements across all dimensions.
virtual int64_t NumElements(tensorflow::Status* status) const = 0;
// Returns size of specified dimension
virtual int64_t Dim(int dim_index, tensorflow::Status* status) const = 0;
// Returns the device which created the handle.
virtual const char* DeviceName(tensorflow::Status* status) const = 0;
// Returns the device where the tensor was placed.
virtual const char* BackingDeviceName(tensorflow::Status* status) const = 0;
// Returns a tensor for the handle. If tensor is remote, it will be copied.
virtual TF_Tensor* Resolve(tensorflow::Status* status) = 0;
// Returns debug information about the tensor.
virtual TFE_TensorDebugInfo* TensorDebugInfo(tensorflow::Status* status) = 0;
// Return a copy of the handle.
virtual AbstractTensorHandleInterface* Copy() = 0;
};
namespace tensorflow {
class TensorHandleInterface : public AbstractTensorHandleInterface {
public:
explicit TensorHandleInterface(TensorHandle* h) : handle_(h) {}
~TensorHandleInterface() override;
bool IsValid(Status* status) const override;
TF_DataType DataType() const override;
int NumDims(Status* status) const override;
int64_t NumElements(Status* status) const override;
int64_t Dim(int dim_index, Status* status) const override;
const char* DeviceName(Status* status) const override;
const char* BackingDeviceName(Status* status) const override;
TF_Tensor* Resolve(Status* status) override;
TFE_TensorDebugInfo* TensorDebugInfo(Status* status) override;
AbstractTensorHandleInterface* Copy() override;
// TODO(gjn): This is not a very generic interface, but is needed for specific
// use cases.
TensorHandle* Handle() { return handle_; }
private:
TensorHandle* handle_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_

View File

@ -18,37 +18,23 @@ cc_library(
],
)
# Core TensorFlow depends on this, this will be included in main library
cc_library(
name = "filesystem_interface_impl",
srcs = ["filesystem_interface.cc"],
hdrs = ["filesystem_interface.h"],
deps = [
":modular_filesystem",
"//tensorflow/c:tf_file_statistics",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_internal",
"//tensorflow/core:ptr_util",
"//tensorflow/core/platform:env",
"//tensorflow/core/platform:logging",
"//tensorflow/core/platform:strcat",
"//tensorflow/core/platform:stringpiece",
],
alwayslink = 1,
)
# Core TensorFlow depends on this, will be included in main library
cc_library(
name = "modular_filesystem",
srcs = ["modular_filesystem.cc"],
srcs = [
"modular_filesystem.cc",
"modular_filesystem_registration.cc",
"modular_filesystem_registration.h",
],
hdrs = ["modular_filesystem.h"],
deps = [
":filesystem_interface",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core:lib",
"//tensorflow/c:tf_status_internal",
"//tensorflow/core:ptr_util",
"//tensorflow/core/platform:env",
"//tensorflow/core/platform:strcat",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:status",
],
)
@ -63,16 +49,12 @@ tf_cc_test(
"notap", # b/139060984, requires implementing modular support for Google filesystem
],
deps = [
":filesystem_interface_impl",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_internal",
":modular_filesystem",
"//tensorflow/core:framework_internal",
"//tensorflow/core/lib/io:path",
"//tensorflow/core/platform:env",
"//tensorflow/core/platform:error",
"//tensorflow/core/platform:stacktrace_handler",
"//tensorflow/core/platform:str_util",
"//tensorflow/core/platform:strcat",
"//tensorflow/core/platform:test",
],
)

View File

@ -1,366 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/util/ptr_util.h"
/// This translation unit is linked in core TensorFlow and provides the
/// functionality needed for plugin registration to check ABI/API compatibility,
/// to ensure required methods are present, to ensure plugins are not allowed to
/// change functionality after being loaded and to register the filesystems
/// provided by a plugin. Consult the header file for more information about
/// how this is achieved.
namespace tensorflow {
namespace {
// Checks if the plugin and core ABI numbers match, filling in `status`.
//
// If the numbers don't match, plugin cannot be loaded.
static bool CheckABIHelper(int pluginABI, int coreABI, StringPiece where,
TF_Status* status) {
if (pluginABI != coreABI) {
TF_SetStatus(
status, TF_FAILED_PRECONDITION,
strings::StrCat("Plugin ABI (", pluginABI, ") for ", where,
" operations doesn't match expected core ABI (",
coreABI, "). Plugin cannot be loaded.")
.c_str());
return false;
}
return true;
}
// Checks if the plugin and core ABI numbers match, for all operations.
//
// If the numbers don't match, plugin cannot be loaded.
//
// Uses the simpler `CheckABIHelper(int, int, StringPiece, TF_Status*)`
static bool CheckABI(
int plugin_filesystem_ops_ABI,
const TF_RandomAccessFileOps* plugin_random_access_file_ops,
int plugin_random_access_file_ops_ABI,
const TF_WritableFileOps* plugin_writable_file_ops,
int plugin_writable_file_ops_ABI,
const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
int plugin_read_only_memory_region_ops_ABI, TF_Status* status) {
if (!CheckABIHelper(plugin_filesystem_ops_ABI, TF_FILESYSTEM_OPS_ABI,
"filesystem", status))
return false;
if (plugin_random_access_file_ops != nullptr &&
!CheckABIHelper(plugin_random_access_file_ops_ABI,
TF_RANDOM_ACCESS_FILE_OPS_ABI, "random access file",
status))
return false;
if (plugin_writable_file_ops != nullptr &&
!CheckABIHelper(plugin_writable_file_ops_ABI, TF_WRITABLE_FILE_OPS_ABI,
"writable file", status))
return false;
if (plugin_read_only_memory_region_ops != nullptr &&
!CheckABIHelper(plugin_read_only_memory_region_ops_ABI,
TF_READ_ONLY_MEMORY_REGION_OPS_ABI,
"read only memory region", status))
return false;
return true;
}
// Checks if the plugin and core API numbers match, logging mismatches.
static void CheckAPIHelper(int plugin_API, int core_API, StringPiece where) {
if (plugin_API != core_API) {
VLOG(0) << "Plugin API (" << plugin_API << ") for " << where
<< " operations doesn't match expected core API (" << core_API
<< "). Plugin will be loaded but functionality might be missing.";
}
}
// Checks if the plugin and core API numbers match, for all operations.
//
// Uses the simpler `CheckAPIHelper(int, int, StringPiece)`.
static void CheckAPI(
int plugin_filesystem_ops_API,
const TF_RandomAccessFileOps* plugin_random_access_file_ops,
int plugin_random_access_file_ops_API,
const TF_WritableFileOps* plugin_writable_file_ops,
int plugin_writable_file_ops_API,
const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
int plugin_read_only_memory_region_ops_API) {
CheckAPIHelper(plugin_filesystem_ops_API, TF_FILESYSTEM_OPS_API,
"filesystem");
if (plugin_random_access_file_ops != nullptr)
CheckAPIHelper(plugin_random_access_file_ops_API,
TF_RANDOM_ACCESS_FILE_OPS_API, "random access file");
if (plugin_writable_file_ops != nullptr)
CheckAPIHelper(plugin_writable_file_ops_API, TF_WRITABLE_FILE_OPS_API,
"writable file");
if (plugin_read_only_memory_region_ops != nullptr)
CheckAPIHelper(plugin_read_only_memory_region_ops_API,
TF_READ_ONLY_MEMORY_REGION_OPS_API,
"read only memory region");
}
// Validates the filesystem operations supplied by the plugin.
static bool ValidateHelper(const TF_FilesystemOps* ops, TF_Status* status) {
if (ops == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without operations");
return false;
}
if (ops->init == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `init` operation");
return false;
}
if (ops->cleanup == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `cleanup` operation");
return false;
}
return true;
}
// Validates the random access file operations supplied by the plugin.
static bool ValidateHelper(const TF_RandomAccessFileOps* ops,
TF_Status* status) {
if (ops == nullptr) {
// We allow filesystems where files can only be written to (from TF code)
return true;
}
if (ops->cleanup == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `cleanup` operation on "
"random access files");
return false;
}
return true;
}
// Validates the writable file operations supplied by the plugin.
static bool ValidateHelper(const TF_WritableFileOps* ops, TF_Status* status) {
if (ops == nullptr) {
// We allow read-only filesystems
return true;
}
if (ops->cleanup == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `cleanup` operation on "
"writable files");
return false;
}
return true;
}
// Validates the read only memory region operations given by the plugin.
static bool ValidateHelper(const TF_ReadOnlyMemoryRegionOps* ops,
TF_Status* status) {
if (ops == nullptr) {
// read only memory region support is always optional
return true;
}
if (ops->cleanup == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `cleanup` operation on "
"read only memory regions");
return false;
}
if (ops->data == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `data` operation on "
"read only memory regions");
return false;
}
if (ops->length == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `length` operation on "
"read only memory regions");
return false;
}
return true;
}
// Validates the operations supplied by the plugin.
//
// Uses the 4 simpler `ValidateHelper(const TF_..., TF_Status*)` to validate
// each individual function table and then checks that the function table for a
// specific file type exists if the plugin offers support for creating that
// type of files.
static bool Validate(
const TF_FilesystemOps* plugin_filesystem_ops,
const TF_RandomAccessFileOps* plugin_random_access_file_ops,
const TF_WritableFileOps* plugin_writable_file_ops,
const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
TF_Status* status) {
if (!ValidateHelper(plugin_filesystem_ops, status)) return false;
if (!ValidateHelper(plugin_random_access_file_ops, status)) return false;
if (!ValidateHelper(plugin_writable_file_ops, status)) return false;
if (!ValidateHelper(plugin_read_only_memory_region_ops, status)) return false;
if (plugin_filesystem_ops->new_random_access_file != nullptr &&
plugin_random_access_file_ops == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Filesystem allows creation of random access files but no "
"operations on them have been supplied.");
return false;
}
if ((plugin_filesystem_ops->new_writable_file != nullptr ||
plugin_filesystem_ops->new_appendable_file != nullptr) &&
plugin_writable_file_ops == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Filesystem allows creation of writable files but no "
"operations on them have been supplied.");
return false;
}
if (plugin_filesystem_ops->new_read_only_memory_region_from_file != nullptr &&
plugin_read_only_memory_region_ops == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Filesystem allows creation of readonly memory regions but no "
"operations on them have been supplied.");
return false;
}
return true;
}
// Copies a function table from plugin memory space to core memory space.
//
// This has three benefits:
// * allows having newer plugins than the current core TensorFlow: the
// additional entries in the plugin's table are just discarded;
// * allows having older plugins than the current core TensorFlow (though
// we are still warning users): the entries that core TensorFlow expects
// but plugins didn't provide will be set to `nullptr` values and core
// TensorFlow will know to not call these on behalf of users;
// * increased security as plugins will not be able to alter function table
// after loading up. Thus, malicious plugins can't alter functionality to
// probe for gadgets inside core TensorFlow. We can even protect the area
// of memory where the copies reside to not allow any more writes to it
// after all copies are created.
template <typename T>
static std::unique_ptr<const T> CopyToCore(const T* plugin_ops,
size_t plugin_size) {
if (plugin_ops == nullptr) return nullptr;
size_t copy_size = sizeof(T);
if (plugin_size < copy_size) {
copy_size = plugin_size;
}
auto core_ops = tensorflow::MakeUnique<T>();
memcpy(const_cast<T*>(core_ops.get()), plugin_ops, copy_size);
return core_ops;
}
} // namespace
} // namespace tensorflow
void RegisterFilesystemPlugin(
int plugin_filesystem_ops_ABI, int plugin_filesystem_ops_API,
size_t plugin_filesystem_ops_size, int plugin_random_access_file_ops_ABI,
int plugin_random_access_file_ops_API,
size_t plugin_random_access_file_ops_size, int plugin_writable_file_ops_ABI,
int plugin_writable_file_ops_API, size_t plugin_writable_file_ops_size,
int plugin_read_only_memory_region_ops_ABI,
int plugin_read_only_memory_region_ops_API,
size_t plugin_read_only_memory_region_ops_size, const char* scheme,
const TF_FilesystemOps* plugin_filesystem_ops,
const TF_RandomAccessFileOps* plugin_random_access_file_ops,
const TF_WritableFileOps* plugin_writable_file_ops,
const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
TF_Status* status) {
if (scheme == nullptr) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"`scheme` argument must not be `nullptr`.");
return;
}
// ABI numbers must match exactly for plugin to be loaded
if (!tensorflow::CheckABI(
plugin_filesystem_ops_ABI, plugin_random_access_file_ops,
plugin_random_access_file_ops_ABI, plugin_writable_file_ops,
plugin_writable_file_ops_ABI, plugin_read_only_memory_region_ops,
plugin_read_only_memory_region_ops_ABI, status)) {
return;
}
// API numbers should match but mismatch doesn't block plugin load
tensorflow::CheckAPI(plugin_filesystem_ops_API, plugin_random_access_file_ops,
plugin_random_access_file_ops_API,
plugin_writable_file_ops, plugin_writable_file_ops_API,
plugin_read_only_memory_region_ops,
plugin_read_only_memory_region_ops_API);
// Plugin can only be loaded if all supplied ops are valid
if (!tensorflow::Validate(plugin_filesystem_ops,
plugin_random_access_file_ops,
plugin_writable_file_ops,
plugin_read_only_memory_region_ops, status)) {
return;
}
// Copy all the function tables to core TensorFlow memory space
auto core_filesystem_ops = tensorflow::CopyToCore<TF_FilesystemOps>(
plugin_filesystem_ops, plugin_filesystem_ops_size);
auto core_random_access_file_ops =
tensorflow::CopyToCore<TF_RandomAccessFileOps>(
plugin_random_access_file_ops, plugin_random_access_file_ops_size);
auto core_writable_file_ops = tensorflow::CopyToCore<TF_WritableFileOps>(
plugin_writable_file_ops, plugin_writable_file_ops_size);
auto core_read_only_memory_region_ops =
tensorflow::CopyToCore<TF_ReadOnlyMemoryRegionOps>(
plugin_read_only_memory_region_ops,
plugin_read_only_memory_region_ops_size);
// Initialize the opaque filesystem structure
auto filesystem = tensorflow::MakeUnique<TF_Filesystem>();
core_filesystem_ops->init(filesystem.get(), status);
if (!status->status.ok()) {
core_filesystem_ops->cleanup(filesystem.get());
return;
}
// Register new filesystem
status->status = tensorflow::Env::Default()->RegisterFileSystem(
scheme, tensorflow::MakeUnique<tensorflow::ModularFileSystem>(
std::move(filesystem), std::move(core_filesystem_ops),
std::move(core_random_access_file_ops),
std::move(core_writable_file_ops),
std::move(core_read_only_memory_region_ops)));
}

View File

@ -56,7 +56,7 @@ extern "C" {
/// Lifetime: The wrapper data structures are owned by core TensorFlow. The data
/// pointed to by the `void*` members is always owned by the plugin. The plugin
/// will provide functions to call to allocate and deallocate this data (see
/// next section) and core TensorFlow ensures to call these at the proper time.
/// next sections) and core TensorFlow ensures to call these at the proper time.
///
/// Plugins will never receive a `TF_*` pointer that is `nullptr`. Core
/// TensorFlow will never touch the `void*` wrapped by these structures, except
@ -529,7 +529,7 @@ typedef struct TF_FilesystemOps {
/// If `statuses` is not null, plugins must fill each element with detailed
/// status for each file, as if calling `path_exists` on each one. Core
/// TensorFlow initializes the `statuses` array and plugins must use
/// `TF_SetStatus` to set each element instead of dirrectly assigning.
/// `TF_SetStatus` to set each element instead of directly assigning.
///
/// DEFAULT IMPLEMENTATION: Checks existence of every file. Needs
/// `path_exists`.
@ -601,6 +601,10 @@ typedef struct TF_FilesystemOps {
///
/// Plugins must not return `nullptr`. Returning empty strings is allowed.
///
/// The allocation and freeing of memory must happen via the functions sent to
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
/// structure in Section 4).
///
/// This function will be called by core TensorFlow to clean up all path
/// arguments for all other methods in the filesystem API.
///
@ -618,6 +622,10 @@ typedef struct TF_FilesystemOps {
/// In case of error, plugins must set `status` to a value different than
/// `TF_OK`, free memory allocated for `entries` and return -1.
///
/// The allocation and freeing of memory must happen via the functions sent to
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
/// structure in Section 4).
///
/// Plugins:
/// * Must set `status` to `TF_OK` if all children were returned.
/// * Must set `status` to `TF_NOT_FOUND` if `path` doesn't point to a
@ -654,6 +662,10 @@ typedef struct TF_FilesystemOps {
/// different than `TF_OK`, free any memory that might have been allocated for
/// `entries` and return -1.
///
/// The allocation and freeing of memory must happen via the functions sent to
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
/// structure in Section 4).
///
/// Plugins:
/// * Must set `status` to `TF_OK` if all matches were returned.
/// * Might use any other error value for `status` to signal other errors.
@ -736,95 +748,132 @@ constexpr size_t TF_FILESYSTEM_OPS_SIZE = sizeof(TF_FilesystemOps);
/// SECTION 4. Plugin registration and initialization
/// ----------------------------------------------------------------------------
///
/// In this section we define two functions:
/// * `TF_InitPlugin`: must be present in the plugin shared object as it will
/// be called by core TensorFlow when the filesystem plugin is loaded;
/// * `RegisterFilesystemPlugin`: it is implemented by core TensorFlow but
/// plugins must call it in their `TF_InitPlugin`, usually using the macro
/// `TF_REGISTER_FILESYSTEM_PLUGIN`.
/// In this section we define the API used by core TensorFlow to initialize a
/// filesystem provided by a plugin. That is, we define the following:
/// * `TF_InitPlugin` function: must be present in the plugin shared object as
/// it will be called by core TensorFlow when the filesystem plugin is
/// loaded;
/// * `TF_FilesystemPluginOps` struct: used to transfer information between
/// plugins and core TensorFlow about the operations provided and metadata;
/// * `TF_FilesystemPluginInfo` struct: similar to the above structure, but
/// collects information about all the file schemes that the plugin provides
/// support for, as well as about the plugin's memory handling routines;
/// * `TF_SetFilesystemVersionMetadata` function: must be called by plugins in
/// their `TF_InitPlugin` to record the versioning information the plugins
/// are compiled against.
///
/// The `TF_InitPlugin` function is used by plugins to set up the data
/// structures that implement this interface, as presented in Section 2.
///
/// The `RegisterFilesystemPlugin` is used by core TensorFlow to check that
/// plugins satisfy the requirements expected by core TensorFlow, as follows:
/// 1. If ABI numbers don't match we don't load the plugin, else we continue.
/// 2. If the API numbers are mismatched, we warn the user and continue
/// loading the plugin.
/// 3. If any required operation is missing, we stop loading the plugin.
///
/// If all these checks succeed, we copy the plugin operations to a different
/// memory location so that core TensorFlow has the guarantee that they won't be
/// changed by plugins at a later time. Finally, we initialize the opaque
/// pointer of `TF_Filesystem` by calling the required `init` function of
/// `TF_FilesystemOps` and if that succeeds we register the filesystem.
/// structures that implement this interface, as presented in Section 2. In
/// order to not have plugin shared objects call back symbols defined in core
/// TensorFlow, `TF_InitPlugin` has a `TF_FilesystemPluginInfo` argument which
/// the plugin must fill (using the `TF_SetFilesystemVersionMetadata` for the
/// metadata and setting up all the supported operations and the URI schemes
/// that are supported).
// Initializes a TensorFlow plugin.
//
// Must be implemented by the plugin DSO. It is called by TensorFlow runtime.
//
// Filesystem plugins can be loaded on demand by users via
// `Env::LoadLibrary` or during TensorFlow's startup if they are on certain
// paths (although this has a security risk if two plugins register for the
// same filesystem and the malicious one loads before the legimitate one -
// but we consider this to be something that users should care about and
// manage themselves). In both of these cases, core TensorFlow looks for
// the `TF_InitPlugin` symbol and calls that function.
//
// A plugin is loaded only if this `status` is `TF_OK` after the call.
TF_CAPI_EXPORT extern void TF_InitPlugin(TF_Status* status);
/// This structure incorporates the operations defined in Section 2 and the
/// metadata defined in section 3, allowing plugins to define different ops
/// for different URI schemes.
///
/// Every URI scheme is of the form "fs" for URIs of form "fs:///path/to/file".
/// For local filesystems (i.e., when the URI is "/path/to/file"), the scheme
/// must be "". The scheme must never be `nullptr`.
///
/// Every plugin fills this in `TF_InitPlugin`, using the alocator passed as
/// argument to allocate memory. After `TF_InitPlugin` finishes, core
/// TensorFlow uses the information present in this to initialize filesystems
/// for the URI schemes that the plugin requests.
///
/// All pointers defined in this structure point to memory allocated by the DSO
/// using an allocator provided by core TensorFlow when calling `TF_InitPlugin`.
///
/// IMPORTANT: To maintain binary compatibility, the layout of this structure
/// must not change! In the unlikely case that a new type of file needs to be
/// supported, add the new ops and metadata at the end of the structure.
typedef struct TF_FilesystemPluginOps {
char* scheme;
int filesystem_ops_abi;
int filesystem_ops_api;
size_t filesystem_ops_size;
TF_FilesystemOps* filesystem_ops;
int random_access_file_ops_abi;
int random_access_file_ops_api;
size_t random_access_file_ops_size;
TF_RandomAccessFileOps* random_access_file_ops;
int writable_file_ops_abi;
int writable_file_ops_api;
size_t writable_file_ops_size;
TF_WritableFileOps* writable_file_ops;
int read_only_memory_region_ops_abi;
int read_only_memory_region_ops_api;
size_t read_only_memory_region_ops_size;
TF_ReadOnlyMemoryRegionOps* read_only_memory_region_ops;
} TF_FilesystemPluginOps;
/// Registers a filesystem plugin so that core TensorFlow can use it.
/// This structure gathers together all the operations provided by the plugin.
///
/// Must be called by the plugin during `TF_InitPlugin`, usually by using the
/// convenience `TF_REGISTER_FILESYSTEM_PLUGIN` macro.
/// Plugins must provide exactly `num_schemes` elements in the `ops` array.
///
/// Arguments (grouped by category):
/// * `..ABI`: ABI compatibility numbers (see Section 3.).
/// * `..API`: API compatibility numbers (see Section 3.).
/// * `..Size`: Sizes of the operation tables (see Section 3.).
/// * `scheme`: The URI scheme that plugin is registering filesystems for.
/// Must be of the form "fs" for URIs of form "fs:///path/to/file". For
/// local filesystems (i.e., when the URI is "/path/to/file"), `scheme`
/// must be "". Must never be `nullptr`.
/// * `..Ops`: The function tables provided by the plugin. Owned by the
/// plugin, but core TensorFlow makes a copy of these.
/// * `status`: The output variable for representing success/failure.
/// Since memory that is allocated by the DSO gets transferred to core
/// TensorFlow, we need to provide a way for the allocation and deallocation to
/// match. This is why this structure also defines `plugin_memory_allocate` and
/// `plugin_memory_free` members.
///
/// Sets `status` to `TF_OK` if plugin was registered and filesystem operations
/// can be invoked from anywhere during TensorFlow's runtime. Any other value of
/// `status` means that plugin failed to load properly and as such the
/// operations it provides cannot be used at all (i.e., core TensorFlow will
/// never run them, returning early with `TF_UNIMPLEMENTED` or similar error
/// values).
TF_CAPI_EXPORT extern void RegisterFilesystemPlugin(
int pluginFilesystemOpsABI, int pluginFilesystemOpsAPI,
size_t pluginFilesystemOpsSize, int pluginRandomAccessFileOpsABI,
int pluginRandomAccessFileOpsAPI, size_t pluginRandomAccessFileOpsSize,
int pluginWritableFileOpsABI, int pluginWritableFileOpsAPI,
size_t pluginWritableFileOpsSize, int pluginReadOnlyMemoryRegionOpsABI,
int pluginReadOnlyMemoryRegionOpsAPI,
size_t pluginReadOnlyMemoryRegionOpsSize, const char* scheme,
const TF_FilesystemOps* pluginFilesystemOps,
const TF_RandomAccessFileOps* pluginRandomAccessFileOps,
const TF_WritableFileOps* pluginWritableFileOps,
const TF_ReadOnlyMemoryRegionOps* pluginReadOnlyMemoryRegionOps,
TF_Status* status);
/// All memory allocated by the plugin that will be owned by core TensorFlow
/// must be allocated using the allocator in this structure. Core TensorFlow
/// will use the deallocator to free this memory once it no longer needs it.
///
/// IMPORTANT: To maintain binary compatibility, the layout of this structure
/// must not change! In the unlikely case that new global operations must be
/// provided, add them at the end of the structure.
typedef struct TF_FilesystemPluginInfo {
size_t num_schemes;
TF_FilesystemPluginOps* ops;
void* (*plugin_memory_allocate)(size_t size);
void (*plugin_memory_free)(void* ptr);
} TF_FilesystemPluginInfo;
/// This macro is just a convenience wrapper around `RegisterFilesystemPlugin`.
/// Plugins should prefer using this macro instead of a direct call.
#define TF_REGISTER_FILESYSTEM_PLUGIN( \
scheme, pluginFilesystemOps, pluginRandomAccessFileOps, \
pluginWritableFileOps, pluginReadOnlyMemoryRegionOps, status) \
RegisterFilesystemPlugin( \
TF_FILESYSTEM_OPS_ABI, TF_FILESYSTEM_OPS_API, TF_FILESYSTEM_OPS_SIZE, \
TF_RANDOM_ACCESS_FILE_OPS_ABI, TF_RANDOM_ACCESS_FILE_OPS_API, \
TF_RANDOM_ACCESS_FILE_OPS_SIZE, TF_WRITABLE_FILE_OPS_ABI, \
TF_WRITABLE_FILE_OPS_API, TF_WRITABLE_FILE_OPS_SIZE, \
TF_READ_ONLY_MEMORY_REGION_OPS_ABI, TF_READ_ONLY_MEMORY_REGION_OPS_API, \
TF_READ_ONLY_MEMORY_REGION_OPS_SIZE, scheme, pluginFilesystemOps, \
pluginRandomAccessFileOps, pluginWritableFileOps, \
pluginReadOnlyMemoryRegionOps, status)
/// Convenience function for setting the versioning metadata.
///
/// The argument is guaranteed to not be `nullptr`.
///
/// We want this to be defined in the plugin's memory space and we guarantee
/// that core TensorFlow will never call this.
static inline void TF_SetFilesystemVersionMetadata(
TF_FilesystemPluginOps* ops) {
ops->filesystem_ops_abi = TF_FILESYSTEM_OPS_ABI;
ops->filesystem_ops_api = TF_FILESYSTEM_OPS_API;
ops->filesystem_ops_size = TF_FILESYSTEM_OPS_SIZE;
ops->random_access_file_ops_abi = TF_RANDOM_ACCESS_FILE_OPS_ABI;
ops->random_access_file_ops_api = TF_RANDOM_ACCESS_FILE_OPS_API;
ops->random_access_file_ops_size = TF_RANDOM_ACCESS_FILE_OPS_SIZE;
ops->writable_file_ops_abi = TF_WRITABLE_FILE_OPS_ABI;
ops->writable_file_ops_api = TF_WRITABLE_FILE_OPS_API;
ops->writable_file_ops_size = TF_WRITABLE_FILE_OPS_SIZE;
ops->read_only_memory_region_ops_abi = TF_READ_ONLY_MEMORY_REGION_OPS_ABI;
ops->read_only_memory_region_ops_api = TF_READ_ONLY_MEMORY_REGION_OPS_API;
ops->read_only_memory_region_ops_size = TF_READ_ONLY_MEMORY_REGION_OPS_SIZE;
}
/// Initializes a TensorFlow plugin.
///
/// Must be implemented by the plugin DSO. It is called by TensorFlow runtime.
///
/// Filesystem plugins can be loaded on demand by users via
/// `Env::LoadLibrary` or during TensorFlow's startup if they are on certain
/// paths (although this has a security risk if two plugins register for the
/// same filesystem and the malicious one loads before the legimitate one -
/// but we consider this to be something that users should care about and
/// manage themselves). In both of these cases, core TensorFlow looks for
/// the `TF_InitPlugin` symbol and calls this function.
///
/// For every filesystem URI scheme that this plugin supports, the plugin must
/// add one `TF_FilesystemPluginInfo` entry in `plugin_info->ops` and call
/// `TF_SetFilesystemVersionMetadata` for that entry.
///
/// Plugins must also initialize `plugin_info->plugin_memory_allocate` and
/// `plugin_info->plugin_memory_free` to ensure memory allocated by plugin is
/// freed in a compatible way.
TF_CAPI_EXPORT extern void TF_InitPlugin(TF_FilesystemPluginInfo* plugin_info);
#ifdef __cplusplus
} // end extern "C"

View File

@ -18,11 +18,10 @@ limitations under the License.
#include <string>
#include <utility>
#include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/file_system_helper.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/util/ptr_util.h"
// TODO(mihaimaruseac): After all filesystems are converted, all calls to
@ -165,16 +164,18 @@ Status ModularFileSystem::GetChildren(const std::string& dir,
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
std::string translated_name = TranslateName(dir);
char** children;
// Note that `children` is allocated by the plugin and freed by core
// TensorFlow, so we need to use `plugin_memory_free_` here.
char** children = nullptr;
const int num_children =
ops_->get_children(filesystem_.get(), translated_name.c_str(), &children,
plugin_status.get());
if (num_children >= 0) {
for (int i = 0; i < num_children; i++) {
result->push_back(std::string(children[i]));
free(children[i]);
plugin_memory_free_(children[i]);
}
free(children);
plugin_memory_free_(children);
}
return StatusFromTF_Status(plugin_status.get());
@ -186,15 +187,17 @@ Status ModularFileSystem::GetMatchingPaths(const std::string& pattern,
return internal::GetMatchingPaths(this, Env::Default(), pattern, result);
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
char** matches;
// Note that `matches` is allocated by the plugin and freed by core
// TensorFlow, so we need to use `plugin_memory_free_` here.
char** matches = nullptr;
const int num_matches = ops_->get_matching_paths(
filesystem_.get(), pattern.c_str(), &matches, plugin_status.get());
if (num_matches >= 0) {
for (int i = 0; i < num_matches; i++) {
result->push_back(std::string(matches[i]));
free(matches[i]);
plugin_memory_free_(matches[i]);
}
free(matches);
plugin_memory_free_(matches);
}
return StatusFromTF_Status(plugin_status.get());
@ -358,7 +361,8 @@ std::string ModularFileSystem::TranslateName(const std::string& name) const {
CHECK(p != nullptr) << "TranslateName(" << name << ") returned nullptr";
std::string ret(p);
free(p);
// Since `p` is allocated by plugin, free it using plugin's method.
plugin_memory_free_(p);
return ret;
}
@ -435,4 +439,8 @@ Status ModularWritableFile::Tell(int64* position) {
return StatusFromTF_Status(plugin_status.get());
}
Status RegisterFilesystemPlugin(const std::string& dso_path) {
return filesystem_registration::RegisterFilesystemPluginImpl(dso_path);
}
} // namespace tensorflow

View File

@ -32,7 +32,7 @@ namespace tensorflow {
// TODO(b/143949615): After all filesystems are converted, this file will be
// moved to core/platform, and this class can become a singleton and replace the
// need for `Env::Default()`. At that time, we might decide to remove the need
// for `Env::Default()` altoghether, but that's a different project, not in
// for `Env::Default()` altogether, but that's a different project, not in
// scope for now. I'm just mentioning this here as that transition will mean
// removal of the registration part from `Env` and adding it here instead: we
// will need tables to hold for each scheme the function tables that implement
@ -46,12 +46,16 @@ class ModularFileSystem final : public FileSystem {
std::unique_ptr<const TF_RandomAccessFileOps> random_access_file_ops,
std::unique_ptr<const TF_WritableFileOps> writable_file_ops,
std::unique_ptr<const TF_ReadOnlyMemoryRegionOps>
read_only_memory_region_ops)
read_only_memory_region_ops,
std::function<void*(size_t)> plugin_memory_allocate,
std::function<void(void*)> plugin_memory_free)
: filesystem_(std::move(filesystem)),
ops_(std::move(filesystem_ops)),
random_access_file_ops_(std::move(random_access_file_ops)),
writable_file_ops_(std::move(writable_file_ops)),
read_only_memory_region_ops_(std::move(read_only_memory_region_ops)) {}
read_only_memory_region_ops_(std::move(read_only_memory_region_ops)),
plugin_memory_allocate_(std::move(plugin_memory_allocate)),
plugin_memory_free_(std::move(plugin_memory_free)) {}
~ModularFileSystem() override { ops_->cleanup(filesystem_.get()); }
@ -93,6 +97,8 @@ class ModularFileSystem final : public FileSystem {
std::unique_ptr<const TF_WritableFileOps> writable_file_ops_;
std::unique_ptr<const TF_ReadOnlyMemoryRegionOps>
read_only_memory_region_ops_;
std::function<void*(size_t)> plugin_memory_allocate_;
std::function<void(void*)> plugin_memory_free_;
TF_DISALLOW_COPY_AND_ASSIGN(ModularFileSystem);
};
@ -156,6 +162,9 @@ class ModularReadOnlyMemoryRegion final : public ReadOnlyMemoryRegion {
TF_DISALLOW_COPY_AND_ASSIGN(ModularReadOnlyMemoryRegion);
};
// Registers a filesystem plugin so that core TensorFlow can use it.
Status RegisterFilesystemPlugin(const std::string& dso_path);
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_H_

View File

@ -0,0 +1,346 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
// Checks that all schemes provided by a plugin are valid.
// TODO(mihaimaruseac): More validation could be done here, based on supported
// charset, maximum length, etc. Punting it for later.
static Status ValidateScheme(const char* scheme) {
if (scheme == nullptr)
return errors::InvalidArgument(
"Attempted to register filesystem with `nullptr` URI scheme");
return Status::OK();
}
// Checks if the plugin and core ABI numbers match.
//
// If the numbers don't match, plugin cannot be loaded.
static Status CheckABI(int pluginABI, int coreABI, StringPiece where) {
if (pluginABI != coreABI)
return errors::FailedPrecondition(
strings::StrCat("Plugin ABI (", pluginABI, ") for ", where,
" operations doesn't match expected core ABI (",
coreABI, "). Plugin cannot be loaded."));
return Status::OK();
}
// Checks if the plugin and core ABI numbers match, for all operations.
//
// If the numbers don't match, plugin cannot be loaded.
//
// Uses the simpler `CheckABI(int, int, StringPiece)`.
static Status ValidateABI(const TF_FilesystemPluginOps* ops) {
TF_RETURN_IF_ERROR(
CheckABI(ops->filesystem_ops_abi, TF_FILESYSTEM_OPS_ABI, "filesystem"));
if (ops->random_access_file_ops != nullptr)
TF_RETURN_IF_ERROR(CheckABI(ops->random_access_file_ops_abi,
TF_RANDOM_ACCESS_FILE_OPS_ABI,
"random access file"));
if (ops->writable_file_ops != nullptr)
TF_RETURN_IF_ERROR(CheckABI(ops->writable_file_ops_abi,
TF_WRITABLE_FILE_OPS_ABI, "writable file"));
if (ops->read_only_memory_region_ops != nullptr)
TF_RETURN_IF_ERROR(CheckABI(ops->read_only_memory_region_ops_abi,
TF_READ_ONLY_MEMORY_REGION_OPS_ABI,
"read only memory region"));
return Status::OK();
}
// Checks if the plugin and core API numbers match, logging mismatches.
static void CheckAPI(int plugin_API, int core_API, StringPiece where) {
if (plugin_API != core_API) {
VLOG(0) << "Plugin API (" << plugin_API << ") for " << where
<< " operations doesn't match expected core API (" << core_API
<< "). Plugin will be loaded but functionality might be missing.";
}
}
// Checks if the plugin and core API numbers match, for all operations.
//
// Uses the simpler `CheckAPIHelper(int, int, StringPiece)`.
static void ValidateAPI(const TF_FilesystemPluginOps* ops) {
CheckAPI(ops->filesystem_ops_api, TF_FILESYSTEM_OPS_API, "filesystem");
if (ops->random_access_file_ops != nullptr)
CheckAPI(ops->random_access_file_ops_api, TF_RANDOM_ACCESS_FILE_OPS_API,
"random access file");
if (ops->writable_file_ops != nullptr)
CheckAPI(ops->writable_file_ops_api, TF_WRITABLE_FILE_OPS_API,
"writable file");
if (ops->read_only_memory_region_ops != nullptr)
CheckAPI(ops->read_only_memory_region_ops_api,
TF_READ_ONLY_MEMORY_REGION_OPS_API, "read only memory region");
}
// Validates the filesystem operations supplied by the plugin.
static Status ValidateHelper(const TF_FilesystemOps* ops) {
if (ops == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without operations");
if (ops->init == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `init` operation");
if (ops->cleanup == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `cleanup` operation");
return Status::OK();
}
// Validates the random access file operations supplied by the plugin.
static Status ValidateHelper(const TF_RandomAccessFileOps* ops) {
if (ops == nullptr) {
// We allow filesystems where files can only be written to (from TF code)
return Status::OK();
}
if (ops->cleanup == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `cleanup` operation on random "
"access files");
return Status::OK();
}
// Validates the writable file operations supplied by the plugin.
static Status ValidateHelper(const TF_WritableFileOps* ops) {
if (ops == nullptr) {
// We allow read-only filesystems
return Status::OK();
}
if (ops->cleanup == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `cleanup` operation on writable "
"files");
return Status::OK();
}
// Validates the read only memory region operations given by the plugin.
static Status ValidateHelper(const TF_ReadOnlyMemoryRegionOps* ops) {
if (ops == nullptr) {
// read only memory region support is always optional
return Status::OK();
}
if (ops->cleanup == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `cleanup` operation on read "
"only memory regions");
if (ops->data == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `data` operation on read only "
"memory regions");
if (ops->length == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `length` operation on read only "
"memory regions");
return Status::OK();
}
// Validates the operations supplied by the plugin.
//
// Uses the 4 simpler `ValidateHelper(const TF_...*)` to validate each
// individual function table and then checks that the function table for a
// specific file type exists if the plugin offers support for creating that
// type of files.
static Status ValidateOperations(const TF_FilesystemPluginOps* ops) {
TF_RETURN_IF_ERROR(ValidateHelper(ops->filesystem_ops));
TF_RETURN_IF_ERROR(ValidateHelper(ops->random_access_file_ops));
TF_RETURN_IF_ERROR(ValidateHelper(ops->writable_file_ops));
TF_RETURN_IF_ERROR(ValidateHelper(ops->read_only_memory_region_ops));
if (ops->filesystem_ops->new_random_access_file != nullptr &&
ops->random_access_file_ops == nullptr)
return errors::FailedPrecondition(
"Filesystem allows creation of random access files but no "
"operations on them have been supplied.");
if ((ops->filesystem_ops->new_writable_file != nullptr ||
ops->filesystem_ops->new_appendable_file != nullptr) &&
ops->writable_file_ops == nullptr)
return errors::FailedPrecondition(
"Filesystem allows creation of writable files but no "
"operations on them have been supplied.");
if (ops->filesystem_ops->new_read_only_memory_region_from_file != nullptr &&
ops->read_only_memory_region_ops == nullptr)
return errors::FailedPrecondition(
"Filesystem allows creation of readonly memory regions but no "
"operations on them have been supplied.");
return Status::OK();
}
// Copies a function table from plugin memory space to core memory space.
//
// This has three benefits:
// * allows having newer plugins than the current core TensorFlow: the
// additional entries in the plugin's table are just discarded;
// * allows having older plugins than the current core TensorFlow (though
// we are still warning users): the entries that core TensorFlow expects
// but plugins didn't provide will be set to `nullptr` values and core
// TensorFlow will know to not call these on behalf of users;
// * increased security as plugins will not be able to alter function table
// after loading up. Thus, malicious plugins can't alter functionality to
// probe for gadgets inside core TensorFlow. We can even protect the area
// of memory where the copies reside to not allow any more writes to it
// after all copies are created.
template <typename T>
static std::unique_ptr<const T> CopyToCore(const T* plugin_ops,
size_t plugin_size) {
if (plugin_ops == nullptr) return nullptr;
size_t copy_size = std::min(plugin_size, sizeof(T));
auto core_ops = tensorflow::MakeUnique<T>();
memset(core_ops.get(), 0, sizeof(T));
memcpy(core_ops.get(), plugin_ops, copy_size);
return core_ops;
}
// Registers one filesystem from the plugin.
//
// Must be called only with `index` a valid index in `info->ops`.
static Status RegisterFileSystem(const TF_FilesystemPluginInfo* info,
int index) {
// Step 1: Copy all the function tables to core TensorFlow memory space
auto core_filesystem_ops = CopyToCore<TF_FilesystemOps>(
info->ops[index].filesystem_ops, info->ops[index].filesystem_ops_size);
auto core_random_access_file_ops = CopyToCore<TF_RandomAccessFileOps>(
info->ops[index].random_access_file_ops,
info->ops[index].random_access_file_ops_size);
auto core_writable_file_ops =
CopyToCore<TF_WritableFileOps>(info->ops[index].writable_file_ops,
info->ops[index].writable_file_ops_size);
auto core_read_only_memory_region_ops =
CopyToCore<TF_ReadOnlyMemoryRegionOps>(
info->ops[index].read_only_memory_region_ops,
info->ops[index].read_only_memory_region_ops_size);
// Step 2: Initialize the opaque filesystem structure
auto filesystem = tensorflow::MakeUnique<TF_Filesystem>();
TF_Status* c_status = TF_NewStatus();
Status status = Status::OK();
core_filesystem_ops->init(filesystem.get(), c_status);
status = Status(c_status->status);
TF_DeleteStatus(c_status);
if (!status.ok()) return status;
// Step 3: Actual registration
return Env::Default()->RegisterFileSystem(
info->ops[index].scheme,
tensorflow::MakeUnique<tensorflow::ModularFileSystem>(
std::move(filesystem), std::move(core_filesystem_ops),
std::move(core_random_access_file_ops),
std::move(core_writable_file_ops),
std::move(core_read_only_memory_region_ops),
info->plugin_memory_allocate, info->plugin_memory_free));
}
// Registers filesystem at `index`, if plugin is providing valid information.
//
// Extracted to a separate function so that pointers inside `info` are freed
// by the caller regardless of whether validation/registration failed or not.
//
// Must be called only with `index` a valid index in `info->ops`.
static Status ValidateAndRegisterFilesystems(
const TF_FilesystemPluginInfo* info, int index) {
TF_RETURN_IF_ERROR(ValidateScheme(info->ops[index].scheme));
TF_RETURN_IF_ERROR(ValidateABI(&info->ops[index]));
ValidateAPI(&info->ops[index]); // we just warn on API number mismatch
TF_RETURN_IF_ERROR(ValidateOperations(&info->ops[index]));
TF_RETURN_IF_ERROR(RegisterFileSystem(info, index));
return Status::OK();
}
// Ensures that the plugin provides the required memory management operations.
static Status ValidatePluginMemoryRoutines(
const TF_FilesystemPluginInfo* info) {
if (info->plugin_memory_allocate == nullptr)
return errors::FailedPrecondition(
"Cannot load filesystem plugin which does not provide "
"`plugin_memory_allocate`");
if (info->plugin_memory_free == nullptr)
return errors::FailedPrecondition(
"Cannot load filesystem plugin which does not provide "
"`plugin_memory_free`");
return Status::OK();
}
namespace filesystem_registration {
Status RegisterFilesystemPluginImpl(const std::string& dso_path) {
// Step 1: Load plugin
Env* env = Env::Default();
void* dso_handle;
TF_RETURN_IF_ERROR(env->LoadLibrary(dso_path.c_str(), &dso_handle));
// Step 2: Load symbol for `TF_InitPlugin`
void* dso_symbol;
TF_RETURN_IF_ERROR(
env->GetSymbolFromLibrary(dso_handle, "TF_InitPlugin", &dso_symbol));
// Step 3: Call `TF_InitPlugin`
TF_FilesystemPluginInfo info;
memset(&info, 0, sizeof(info));
auto TF_InitPlugin =
reinterpret_cast<int (*)(TF_FilesystemPluginInfo*)>(dso_symbol);
TF_InitPlugin(&info);
// Step 4: Ensure plugin provides the memory management functions.
TF_RETURN_IF_ERROR(ValidatePluginMemoryRoutines(&info));
// Step 5: Validate and register all filesystems
// Try to register as many filesystems as possible.
// Free memory once we no longer need it
Status status;
for (int i = 0; i < info.num_schemes; i++) {
status.Update(ValidateAndRegisterFilesystems(&info, i));
info.plugin_memory_free(info.ops[i].scheme);
info.plugin_memory_free(info.ops[i].filesystem_ops);
info.plugin_memory_free(info.ops[i].random_access_file_ops);
info.plugin_memory_free(info.ops[i].writable_file_ops);
info.plugin_memory_free(info.ops[i].read_only_memory_region_ops);
}
info.plugin_memory_free(info.ops);
return status;
}
} // namespace filesystem_registration
} // namespace tensorflow

View File

@ -0,0 +1,28 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
namespace filesystem_registration {
Status RegisterFilesystemPluginImpl(const std::string& dso_path);
} // namespace filesystem_registration
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_

View File

@ -1,35 +1,47 @@
# Experimental posix filesystem plugin.
load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
package(
default_visibility = ["//visibility:private"],
licenses = ["notice"], # Apache 2.0
)
# Although this target results in a shared object that will be loaded at
# runtime, this target must be a `cc_library` instead of a `cc_binary`. Making
# it a `cc_binary` requires `linkshared = True`. In turn, this brings in several
# TensorFlow symbols under `tensorflow::` namespace, for which we have no ABI
# guarantees. Hence, in order to maintain ABI compatibility, this is marked as a
# `cc_library` for now and we will revisit in the future.
# TODO(mihaimaruseac): Determine if `cc_binary` makes more sense (when all
# filesystems are converted and BUILD files are refactored to be modular).
# TODO(b/144585140): The helpers should be separated into a different BUILD target
# but doing that would result in symbols not being visible when loading plugin.
# Revisit this once POSIX filesystem completely lands. See also the other TODO.
# This also has the unfortunate effect that both versions of copy_file get
# compiled, regardless of which one actually gets used!
# Filesystem implementation for POSIX environments: Linux, MacOS, Android, etc.
tf_cc_shared_object(
name = "libposix_filesystem.so",
framework_so = [],
linkstatic = False,
visibility = ["//visibility:public"],
deps = [":posix_filesystem_impl"],
)
# The real implementation of the filesystem.
cc_library(
name = "posix_filesystem",
srcs = [
"posix_filesystem.cc",
"posix_filesystem_helper.cc",
"posix_filesystem_helper.h",
"copy_file.h",
] + select({
"//tensorflow:linux_x86_64": ["copy_file_linux.cc"],
"//conditions:default": ["copy_file_portable.cc"],
}),
name = "posix_filesystem_impl",
srcs = ["posix_filesystem.cc"],
deps = [
":posix_filesystem_helper",
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
],
)
# Library implementing helper functionality, so that the above only contains
# the API implementation for modular filesystems.
cc_library(
name = "posix_filesystem_helper",
srcs = ["posix_filesystem_helper.cc"],
hdrs = ["posix_filesystem_helper.h"],
deps = [":copy_file"],
)
# On Linux, we can copy files faster using `sendfile`. But not elsewhere.
# Hence, this private library to select which implementation to use.
cc_library(
name = "copy_file",
srcs = select({
"//tensorflow:linux_x86_64": ["copy_file_linux.cc"],
"//conditions:default": ["copy_file_portable.cc"],
}),
hdrs = ["copy_file.h"],
)

View File

@ -24,8 +24,6 @@ limitations under the License.
#include <sys/stat.h>
#include <unistd.h>
#include <vector>
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem_helper.h"
#include "tensorflow/c/tf_status.h"
@ -33,6 +31,9 @@ limitations under the License.
// Implementation of a filesystem for POSIX environments.
// This filesystem will support `file://` and empty (local) URI schemes.
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); }
// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
@ -45,7 +46,9 @@ typedef struct PosixFile {
static void Cleanup(TF_RandomAccessFile* file) {
auto posix_file = static_cast<PosixFile*>(file->plugin_file);
close(posix_file->fd);
free(const_cast<char*>(posix_file->filename));
// This would be safe to free using `free` directly as it is only opaque.
// However, it is better to be consistent everywhere.
plugin_memory_free(const_cast<char*>(posix_file->filename));
delete posix_file;
}
@ -100,7 +103,7 @@ typedef struct PosixFile {
static void Cleanup(TF_WritableFile* file) {
auto posix_file = static_cast<PosixFile*>(file->plugin_file);
free(const_cast<char*>(posix_file->filename));
plugin_memory_free(const_cast<char*>(posix_file->filename));
delete posix_file;
}
@ -383,12 +386,13 @@ static int GetChildren(const TF_Filesystem* filesystem, const char* path,
if (num_entries < 0) {
TF_SetStatusFromIOError(status, errno, path);
} else {
*entries = static_cast<char**>(calloc(num_entries, sizeof((*entries)[0])));
*entries = static_cast<char**>(
plugin_memory_allocate(num_entries * sizeof((*entries)[0])));
for (int i = 0; i < num_entries; i++) {
(*entries)[i] = strdup(dir_entries[i]->d_name);
free(dir_entries[i]);
plugin_memory_free(dir_entries[i]);
}
free(dir_entries);
plugin_memory_free(dir_entries);
}
return num_entries;
@ -396,48 +400,59 @@ static int GetChildren(const TF_Filesystem* filesystem, const char* path,
} // namespace tf_posix_filesystem
void TF_InitPlugin(TF_Status* status) {
TF_RandomAccessFileOps random_access_file_ops = {
tf_random_access_file::Cleanup,
tf_random_access_file::Read,
};
TF_WritableFileOps writable_file_ops = {
tf_writable_file::Cleanup, tf_writable_file::Append,
tf_writable_file::Tell, tf_writable_file::Flush,
tf_writable_file::Sync, tf_writable_file::Close,
};
TF_ReadOnlyMemoryRegionOps read_only_memory_region_ops = {
tf_read_only_memory_region::Cleanup,
tf_read_only_memory_region::Data,
tf_read_only_memory_region::Length,
};
TF_FilesystemOps filesystem_ops = {
tf_posix_filesystem::Init,
tf_posix_filesystem::Cleanup,
tf_posix_filesystem::NewRandomAccessFile,
tf_posix_filesystem::NewWritableFile,
tf_posix_filesystem::NewAppendableFile,
tf_posix_filesystem::NewReadOnlyMemoryRegionFromFile,
tf_posix_filesystem::CreateDir,
/*recursively_create_dir=*/nullptr,
tf_posix_filesystem::DeleteFile,
tf_posix_filesystem::DeleteDir,
/*delete_recursively=*/nullptr,
tf_posix_filesystem::RenameFile,
tf_posix_filesystem::CopyFile,
tf_posix_filesystem::PathExists,
/*paths_exist=*/nullptr,
tf_posix_filesystem::Stat,
/*is_directory=*/nullptr,
/*get_file_size=*/nullptr,
/*translate_name=*/nullptr,
tf_posix_filesystem::GetChildren,
/*get_matching_paths=*/nullptr,
/*flush_caches=*/nullptr,
};
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
const char* uri) {
TF_SetFilesystemVersionMetadata(ops);
ops->scheme = strdup(uri);
for (const char* scheme : {"", "file"})
TF_REGISTER_FILESYSTEM_PLUGIN(scheme, &filesystem_ops,
&random_access_file_ops, &writable_file_ops,
&read_only_memory_region_ops, status);
ops->random_access_file_ops = static_cast<TF_RandomAccessFileOps*>(
plugin_memory_allocate(TF_RANDOM_ACCESS_FILE_OPS_SIZE));
ops->random_access_file_ops->cleanup = tf_random_access_file::Cleanup;
ops->random_access_file_ops->read = tf_random_access_file::Read;
ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
ops->writable_file_ops->append = tf_writable_file::Append;
ops->writable_file_ops->tell = tf_writable_file::Tell;
ops->writable_file_ops->flush = tf_writable_file::Flush;
ops->writable_file_ops->sync = tf_writable_file::Sync;
ops->writable_file_ops->close = tf_writable_file::Close;
ops->read_only_memory_region_ops = static_cast<TF_ReadOnlyMemoryRegionOps*>(
plugin_memory_allocate(TF_READ_ONLY_MEMORY_REGION_OPS_SIZE));
ops->read_only_memory_region_ops->cleanup =
tf_read_only_memory_region::Cleanup;
ops->read_only_memory_region_ops->data = tf_read_only_memory_region::Data;
ops->read_only_memory_region_ops->length = tf_read_only_memory_region::Length;
ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
ops->filesystem_ops->init = tf_posix_filesystem::Init;
ops->filesystem_ops->cleanup = tf_posix_filesystem::Cleanup;
ops->filesystem_ops->new_random_access_file =
tf_posix_filesystem::NewRandomAccessFile;
ops->filesystem_ops->new_writable_file = tf_posix_filesystem::NewWritableFile;
ops->filesystem_ops->new_appendable_file =
tf_posix_filesystem::NewAppendableFile;
ops->filesystem_ops->new_read_only_memory_region_from_file =
tf_posix_filesystem::NewReadOnlyMemoryRegionFromFile;
ops->filesystem_ops->create_dir = tf_posix_filesystem::CreateDir;
ops->filesystem_ops->delete_file = tf_posix_filesystem::DeleteFile;
ops->filesystem_ops->delete_dir = tf_posix_filesystem::DeleteDir;
ops->filesystem_ops->rename_file = tf_posix_filesystem::RenameFile;
ops->filesystem_ops->copy_file = tf_posix_filesystem::CopyFile;
ops->filesystem_ops->path_exists = tf_posix_filesystem::PathExists;
ops->filesystem_ops->stat = tf_posix_filesystem::Stat;
ops->filesystem_ops->get_children = tf_posix_filesystem::GetChildren;
}
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
info->plugin_memory_allocate = plugin_memory_allocate;
info->plugin_memory_free = plugin_memory_free;
info->num_schemes = 2;
info->ops = static_cast<TF_FilesystemPluginOps*>(
plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
ProvideFilesystemSupportFor(&info->ops[0], "");
ProvideFilesystemSupportFor(&info->ops[1], "file");
}

View File

@ -44,7 +44,7 @@ int TransferFileContents(const char* src, const char* dst, mode_t mode,
}
// Both files have been opened, do the transfer.
// Since errno would be overriden by `close` below, save it here.
// Since errno would be overridden by `close` below, save it here.
int error_code = 0;
if (CopyFileContents(dst_fd, src_fd, size) < 0) error_code = errno;

View File

@ -0,0 +1,36 @@
# Experimental windows filesystem plugin.
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object")
package(
licenses = ["notice"], # Apache 2.0
)
# Filesystem implementation for Windows environment
tf_cc_shared_object(
name = "windows_filesystem.dll",
framework_so = [],
linkstatic = False,
tags = [
"manual",
"nobuilder",
"notap",
],
visibility = ["//visibility:public"],
deps = [":windows_filesystem_impl"],
)
# The real implementation of the filesystem.
cc_library(
name = "windows_filesystem_impl",
srcs = ["windows_filesystem.cc"],
copts = get_win_copts(),
tags = [
"manual",
"nobuilder",
"notap",
],
deps = [
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
],
)

View File

@ -0,0 +1,73 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdlib.h>
#include <string.h>
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h"
// Implementation of a filesystem for POSIX environments.
// This filesystem will support `file://` and empty (local) URI schemes.
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); }
// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
// TODO(mihaimaruseac): Implement later
} // namespace tf_random_access_file
// SECTION 2. Implementation for `TF_WritableFile`
// ----------------------------------------------------------------------------
namespace tf_writable_file {
// TODO(mihaimaruseac): Implement later
} // namespace tf_writable_file
// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion`
// ----------------------------------------------------------------------------
namespace tf_read_only_memory_region {
// TODO(mihaimaruseac): Implement later
} // namespace tf_read_only_memory_region
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
// ----------------------------------------------------------------------------
namespace tf_windows_filesystem {
// TODO(mihaimaruseac): Implement later
} // namespace tf_windows_filesystem
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
const char* uri) {
TF_SetFilesystemVersionMetadata(ops);
ops->scheme = strdup(uri);
}
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
info->plugin_memory_allocate = plugin_memory_allocate;
info->plugin_memory_free = plugin_memory_free;
info->num_schemes = 2;
info->ops = static_cast<TF_FilesystemPluginOps*>(
plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
ProvideFilesystemSupportFor(&info->ops[0], "");
ProvideFilesystemSupportFor(&info->ops[1], "file");
}

View File

@ -181,7 +181,8 @@ void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor,
return;
}
const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i));
TF_Tensor* result = ::tensorflow::TF_TensorFromTensor(cc_tensor, status);
TF_Tensor* result =
::tensorflow::TF_TensorFromTensor(cc_tensor, &status->status);
if (TF_GetCode(status) == TF_OK) {
*tensor = result;
}

View File

@ -18,19 +18,36 @@ limitations under the License.
#include "tensorflow/c/kernels.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#include <memory>
#include <string>
#include "absl/container/inlined_vector.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
struct MyCustomKernel {
bool created;

View File

@ -133,7 +133,7 @@ TEST(OpsTest, TestShapeInference_VectorizeFunction) {
TEST(OpsTest, AttributeAccessors) {
TF_OpDefinitionBuilder* builder =
TF_NewOpDefinitionBuilder("AttributeAccesorsOp");
TF_NewOpDefinitionBuilder("AttributeAccessorsOp");
TF_OpDefinitionBuilderAddAttr(builder, "foo1: int >= 2");
TF_OpDefinitionBuilderAddAttr(builder, "foo2: string=\"my string\"");
TF_OpDefinitionBuilderSetIsCommutative(builder, true);
@ -151,7 +151,7 @@ TEST(OpsTest, AttributeAccessors) {
op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length);
bool found = false;
for (const auto& op : op_list.op()) {
if (op.name() == "AttributeAccesorsOp") {
if (op.name() == "AttributeAccessorsOp") {
ASSERT_TRUE(op.is_commutative());
ASSERT_TRUE(op.is_aggregate());
ASSERT_TRUE(op.allows_uninitialized_input());

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/c/tf_tensor.h"
#include <memory>
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor_internal.h"
@ -103,49 +105,35 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
}
TF_Tensor* ret =
new TF_Tensor{Tensor(static_cast<tensorflow::DataType>(dtype),
tensorflow::TensorShape(dimvec), buf)};
// TODO(gjn): Make the choice of interface a compile-time configuration.
tensorflow::TensorInterface ret(
Tensor(static_cast<tensorflow::DataType>(dtype),
tensorflow::TensorShape(dimvec), buf));
buf->Unref();
size_t elem_size = TF_DataTypeSize(dtype);
if (elem_size > 0 && len < (elem_size * ret->tensor.NumElements())) {
delete ret;
if (elem_size > 0 && len < (elem_size * ret.NumElements())) {
return nullptr;
}
return ret;
return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(ret)};
}
TF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor) {
// It is safe to move the Tensor if and only if we own the unique reference to
// it. In that case, we might as well not delete and reallocate, but a future
// implementation might need to do so.
TensorBuffer* buf = tensorflow::TensorCApi::Buffer(tensor->tensor);
if (buf->RefCountIsOne() && buf->root_buffer()->RefCountIsOne() &&
buf->OwnsMemory()) {
return tensor;
}
return nullptr;
TF_Tensor* TF_TensorMaybeMove(TF_Tensor* t) {
return t->tensor->CanMove() ? t : nullptr;
}
void TF_DeleteTensor(TF_Tensor* t) { delete t; }
TF_DataType TF_TensorType(const TF_Tensor* t) {
return static_cast<TF_DataType>(t->tensor.dtype());
}
TF_DataType TF_TensorType(const TF_Tensor* t) { return t->tensor->Type(); }
int TF_NumDims(const TF_Tensor* t) { return t->tensor.dims(); }
int TF_NumDims(const TF_Tensor* t) { return t->tensor->NumDims(); }
int64_t TF_Dim(const TF_Tensor* t, int dim_index) {
return static_cast<int64_t>(t->tensor.dim_size(dim_index));
return t->tensor->Dim(dim_index);
}
size_t TF_TensorByteSize(const TF_Tensor* t) {
return tensorflow::TensorCApi::Buffer(t->tensor)->size();
}
size_t TF_TensorByteSize(const TF_Tensor* t) { return t->tensor->ByteSize(); }
void* TF_TensorData(const TF_Tensor* t) {
return tensorflow::TensorCApi::Buffer(t->tensor)->data();
}
void* TF_TensorData(const TF_Tensor* t) { return t->tensor->Data(); }
int64_t TF_TensorElementCount(const TF_Tensor* t) {
int64_t result = 1;
@ -160,16 +148,69 @@ void TF_TensorBitcastFrom(const TF_Tensor* from, TF_DataType type,
TF_Tensor* to, const int64_t* new_dims,
int num_new_dims, TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
Status cc_status(
static_cast<tensorflow::TensorInterface*>(to->tensor.get())
->BitcastFrom(*static_cast<const tensorflow::TensorInterface*>(
from->tensor.get()),
type, new_dims, num_new_dims));
Set_TF_Status_from_Status(status, cc_status);
}
namespace tensorflow {
bool TensorInterface::CanMove() const {
// It is safe to move the Tensor if and only if we own the unique reference to
// it. In that case, we might as well not delete and reallocate, but a future
// implementation might need to do so.
TensorBuffer* buf = tensorflow::TensorCApi::Buffer(tensor_);
if (buf->RefCountIsOne() && buf->root_buffer()->RefCountIsOne() &&
buf->OwnsMemory()) {
return true;
}
return false;
}
TF_DataType TensorInterface::Type() const {
return static_cast<TF_DataType>(tensor_.dtype());
}
int TensorInterface::NumDims() const { return tensor_.dims(); }
int64_t TensorInterface::Dim(int dim_index) const {
return static_cast<int64_t>(tensor_.dim_size(dim_index));
}
int64_t TensorInterface::NumElements() const {
return static_cast<int64_t>(tensor_.NumElements());
}
size_t TensorInterface::ByteSize() const {
return tensorflow::TensorCApi::Buffer(tensor_)->size();
}
void* TensorInterface::Data() const {
return tensorflow::TensorCApi::Buffer(tensor_)->data();
}
Status TensorInterface::BitcastFrom(const TensorInterface& from,
TF_DataType type, const int64_t* new_dims,
int num_new_dims) {
tensorflow::TensorShape s;
for (int i = 0; i < num_new_dims; ++i) {
s.AddDim(new_dims[i]);
}
Status cc_status(to->tensor.BitcastFrom(
from->tensor, static_cast<tensorflow::DataType>(type), s));
Set_TF_Status_from_Status(status, cc_status);
return tensor_.BitcastFrom(from.tensor_,
static_cast<tensorflow::DataType>(type), s);
}
} // namespace tensorflow
// --------------------------------------------------------------------------
void StringEncode(const char* src, size_t src_len, char* dst) {
dst = tensorflow::core::EncodeVarint64(dst, src_len);
memcpy(dst, src, src_len);
}
size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
size_t dst_len, TF_Status* status) {
const size_t sz = TF_StringEncodedSize(src_len);
@ -185,8 +226,7 @@ size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
src_len, "-byte string"));
return 0;
}
dst = tensorflow::core::EncodeVarint64(dst, src_len);
memcpy(dst, src, src_len);
StringEncode(src, src_len, dst);
return sz;
}
@ -245,13 +285,11 @@ static TF_Tensor* EmptyTensor(TF_DataType dtype,
namespace tensorflow {
// Non-static for testing.
TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status) {
*status = tensorflow::Status::OK();
if (!src.IsInitialized()) {
Set_TF_Status_from_Status(
status, FailedPrecondition(
"attempt to use a tensor with an uninitialized value"));
*status = FailedPrecondition(
"attempt to use a tensor with an uninitialized value");
return nullptr;
}
if (src.NumElements() == 0) {
@ -259,14 +297,13 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
}
if (src.dtype() == tensorflow::DT_RESOURCE) {
if (src.shape().dims() != 0) {
Set_TF_Status_from_Status(
status, InvalidArgument(
"Unexpected non-scalar DT_RESOURCE tensor seen (shape: ",
src.shape().DebugString(),
"). Please file a bug at "
"https://github.com/tensorflow/tensorflow/issues/new, "
"ideally with a "
"short code snippet that reproduces this error."));
*status = InvalidArgument(
"Unexpected non-scalar DT_RESOURCE tensor seen (shape: ",
src.shape().DebugString(),
"). Please file a bug at "
"https://github.com/tensorflow/tensorflow/issues/new, "
"ideally with a "
"short code snippet that reproduces this error.");
return nullptr;
}
const string str =
@ -276,12 +313,11 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
return t;
}
if (src.dtype() != tensorflow::DT_STRING) {
auto* result = new TF_Tensor();
if (!result->tensor.CopyFrom(src, src.shape())) {
delete result;
Tensor tensor;
if (!tensor.CopyFrom(src, src.shape())) {
return nullptr;
}
return result;
return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(tensor)};
}
// DT_STRING tensors require a copying since TF_Tensor.buffer expects a flatly
// encoded sequence of strings.
@ -305,23 +341,15 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
*offsets = (dst - data_start);
offsets++;
const string& s = srcarray(i);
size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status);
if (TF_GetCode(status) != TF_OK) {
Set_TF_Status_from_Status(
status,
InvalidArgument("invalid string tensor encoding (string #", i, " of ",
srcarray.size(), "): ", TF_Message(status)));
delete[] base;
return nullptr;
}
const size_t consumed = TF_StringEncodedSize(s.size());
StringEncode(s.data(), s.size(), dst);
dst += consumed;
dst_len -= consumed;
}
if (dst != base + size) {
Set_TF_Status_from_Status(
status, InvalidArgument(
"invalid string tensor encoding (decoded ", (dst - base),
" bytes, but the tensor is encoded in ", size, " bytes"));
*status = InvalidArgument(
"invalid string tensor encoding (decoded ", (dst - base),
" bytes, but the tensor is encoded in ", size, " bytes");
delete[] base;
return nullptr;
}
@ -339,31 +367,35 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
}
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
if (src->tensor.dtype() == DT_RESOURCE) {
if (src->tensor.dims() != 0) {
return static_cast<const tensorflow::TensorInterface*>(src->tensor.get())
->ToTensor(dst);
}
Status TensorInterface::ToTensor(Tensor* dst) const {
if (tensor_.dtype() == DT_RESOURCE) {
if (tensor_.dims() != 0) {
return InvalidArgument(
"Malformed TF_RESOURCE tensor: expected a scalar, got a tensor with "
"shape ",
src->tensor.shape().DebugString());
tensor_.shape().DebugString());
}
*dst = Tensor(tensorflow::DT_RESOURCE, src->tensor.shape());
*dst = Tensor(tensorflow::DT_RESOURCE, tensor_.shape());
if (!dst->scalar<tensorflow::ResourceHandle>()().ParseFromString(
string(static_cast<const char*>(TF_TensorData(src)),
TF_TensorByteSize(src)))) {
string(static_cast<const char*>(Data()), ByteSize()))) {
return InvalidArgument(
"Malformed TF_RESOUCE tensor: unable to parse resource handle");
"Malformed TF_RESOURCE tensor: unable to parse resource handle");
}
return Status::OK();
}
if (src->tensor.dtype() != DT_STRING) {
*dst = src->tensor;
if (tensor_.dtype() != DT_STRING) {
*dst = tensor_;
return Status::OK();
}
// TF_STRING tensors require copying since Tensor class expects a sequence of
// string objects.
const tensorflow::int64 num_elements = src->tensor.NumElements();
const char* input = reinterpret_cast<const char*>(TF_TensorData(src));
const size_t src_size = TF_TensorByteSize(src);
const tensorflow::int64 num_elements = tensor_.NumElements();
const char* input = reinterpret_cast<const char*>(Data());
const size_t src_size = ByteSize();
if (static_cast<tensorflow::int64>(src_size / sizeof(tensorflow::uint64)) <
num_elements) {
return InvalidArgument(
@ -372,7 +404,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
const char* limit = input + src_size;
*dst = Tensor(src->tensor.dtype(), src->tensor.shape());
*dst = Tensor(tensor_.dtype(), tensor_.shape());
auto dstarray = dst->flat<tstring>();
for (tensorflow::int64 i = 0; i < num_elements; ++i) {
tensorflow::uint64 offset =
@ -391,8 +423,8 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
return Status::OK();
}
bool TensorInterface::IsAligned() const { return tensor_.IsAligned(); }
} // namespace tensorflow
bool TF_TensorIsAligned(const TF_Tensor* tensor) {
return tensor->tensor.IsAligned();
}
bool TF_TensorIsAligned(const TF_Tensor* t) { return t->tensor->IsAligned(); }

View File

@ -16,9 +16,12 @@ limitations under the License.
#ifndef TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
#define TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
#include <memory>
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_interface.h"
#include "tensorflow/core/framework/tensor_shape.h"
// Internal structures used by the C API. These are likely to change and should
@ -28,7 +31,7 @@ limitations under the License.
// passed to or returned from C functions *by pointer*. Otherwise, changes to
// its internal structure will break the C API's binary interface.
typedef struct TF_Tensor {
::tensorflow::Tensor tensor;
std::unique_ptr<AbstractTensorInterface> tensor;
} TF_Tensor;
class TF_ManagedBuffer : public tensorflow::TensorBuffer {
@ -83,4 +86,5 @@ void* allocate_tensor(const char* operation, size_t len, Allocator* allocator);
// a different Allocator as `arg`.
void deallocate_buffer(void* data, size_t len, void* arg);
} // namespace tensorflow
#endif // TENSORFLOW_C_TF_TENSOR_INTERNAL_H_

View File

@ -96,7 +96,7 @@ class SymbolicGradientBuilder {
// Used to identify nodes at which to stop backprop.
std::unordered_set<int> GetStopBackpropNodes(
const std::vector<bool>& reachable_nodes,
const std::unordered_set<int>& output_nodes);
const std::unordered_set<int>& output_nodes) const;
const Scope& scope_;
const ops::GradOpRegistry* registry_;
@ -190,7 +190,7 @@ std::vector<bool> SymbolicGradientBuilder::GetReachableNodes() {
std::unordered_set<int> SymbolicGradientBuilder::GetStopBackpropNodes(
const std::vector<bool>& reachable_nodes,
const std::unordered_set<int>& output_nodes) {
const std::unordered_set<int>& output_nodes) const {
// Output nodes that get transitively consumed by other `outputs_` are stored
// in `internal_outputs`.
std::unordered_set<int> internal_outputs;
@ -346,8 +346,8 @@ Status SymbolicGradientBuilder::SumGradients(const Output& src, Output* grad) {
"Unable to find backprop list for node.id ", src.node()->name());
}
const auto& grads = iter->second;
// Filter any backproped 'NoGradient' Outputs from 'grads' (if needed).
// Return any valid backproped gradients that remain after filtering,
// Filter any backpropped 'NoGradient' Outputs from 'grads' (if needed).
// Return any valid backpropped gradients that remain after filtering,
// or 'NoGradient' otherwise.
std::vector<Output> grads_to_keep;
for (const Output& o : grads) {
@ -519,7 +519,7 @@ Status SymbolicGradientBuilder::AddGradients() {
// Backprop along the in edges.
// TODO(andydavis) Find cleaner way to map each grad output returned by
// gradient function to the src node/output to which it should be
// backproped. Maybe grad functions can return a vector of Output pairs to
// backpropped. Maybe grad functions can return a vector of Output pairs to
// make this association explicit.
size_t dx_index = 0;
for (const Edge* e : n->in_edges()) {

View File

@ -64,7 +64,7 @@ bool IsZero(const Scope& scope, const Output& grad) {
// Multiply after broadcasting vec to match dimensions of mat.
// Args:
// vec: A 1-D tensor of dimension [D0]
// mat: A 2-D tensor of dimesnion [D0, D1]
// mat: A 2-D tensor of dimension [D0, D1]
//
// Returns:
// A tensor of dimension [D0, D1], the result fo vec * mat.

View File

@ -259,6 +259,9 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) {
RunTest(x, x_init_value, y, y_shape);
}
// TODO(rocm):
// Re-enable this test once 3D pooling is supported on ROCm platform
#ifndef TENSORFLOW_USE_ROCM
TEST_F(NNGradTest, MaxPool3DGradHelper) {
TensorShape x_shape({1, 3, 3, 3, 1});
TensorShape y_shape({1, 1, 1, 1, 1});
@ -271,6 +274,7 @@ TEST_F(NNGradTest, MaxPool3DGradHelper) {
SetRandomValuesForMaxPooling<float>(&x_init_value);
RunTest(x, x_init_value, y, y_shape);
}
#endif
TEST_F(NNGradTest, AvgPoolGradHelper) {
TensorShape x_shape({1, 2, 2, 1});
@ -283,6 +287,9 @@ TEST_F(NNGradTest, AvgPoolGradHelper) {
RunTest(x, x_shape, y, y_shape);
}
// TODO(rocm):
// Re-enable this test once 3D pooling is supported on ROCm platform
#ifndef TENSORFLOW_USE_ROCM
TEST_F(NNGradTest, AvgPool3DGradHelper) {
TensorShape x_shape({1, 3, 3, 3, 1});
TensorShape y_shape({1, 1, 1, 1, 1});
@ -293,6 +300,7 @@ TEST_F(NNGradTest, AvgPool3DGradHelper) {
auto y = AvgPool3D(scope_, x, ksize, strides, "SAME");
RunTest(x, x_shape, y, y_shape);
}
#endif
TEST_F(NNGradTest, LRN) {
TensorShape x_shape({1, 1, 2, 1});

View File

@ -124,13 +124,12 @@ cc_library(
hdrs = ["bundle_v2.h"],
deps = [
":constants",
"@com_google_absl//absl/container:flat_hash_set",
] + if_not_mobile([
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:strcat",
"//tensorflow/core/util/tensor_bundle",
]),
"@com_google_absl//absl/container:flat_hash_set",
],
)
tf_cc_test(

View File

@ -1,5 +1,6 @@
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
load("//tensorflow/core/platform:build_config.bzl", "if_llvm_aarch64_available")
package(
default_visibility = ["//visibility:private"],
@ -27,9 +28,15 @@ cc_library(
"compile.h",
"flags.h",
],
defines = if_llvm_aarch64_available(["TF_LLVM_AARCH64_AVAILABLE=1"]),
visibility = ["//tensorflow/python:__pkg__"],
deps = [
":aot_only_var_handle_op",
":embedded_protocol_buffers",
"@com_google_absl//absl/base",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"//tensorflow/compiler/tf2xla",
"//tensorflow/compiler/tf2xla:mlir_tf2xla",
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
@ -53,10 +60,13 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
"@llvm-project//llvm:target",
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
] + if_llvm_aarch64_available([
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
]),
)
tf_cc_test(
@ -86,6 +96,19 @@ tf_cc_binary(
deps = [":tfcompile_main"],
)
cc_library(
name = "llvm_targets",
visibility = ["//tensorflow/python:__pkg__"],
deps = [
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
"@llvm-project//llvm:target",
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
] + if_llvm_aarch64_available([
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
]),
)
cc_library(
name = "tfcompile_main",
srcs = ["tfcompile_main.cc"],
@ -104,11 +127,6 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:aarch64_code_gen", # fixdeps: keep
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
"@llvm-project//llvm:target",
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
],
)
@ -214,8 +232,13 @@ cc_library(
cc_library(
name = "aot_only_var_handle_op",
srcs = ["aot_only_var_handle_op.cc"],
hdrs = ["aot_only_var_handle_op.h"],
visibility = [
"//tensorflow/compiler/tf2xla:__pkg__",
],
deps = [
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/core:framework",
],
alwayslink = 1,
)

View File

@ -13,9 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/aot/aot_only_var_handle_op.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace {
@ -51,6 +54,31 @@ void XlaAotOnlyVarHandleOp::Compile(XlaOpKernelContext* context) {
}
} // namespace
REGISTER_XLA_OP(Name("VarHandleOp").CompilationOnly(), XlaAotOnlyVarHandleOp);
REGISTER_OP(tfcompile::kXlaAotOnlyVarHandleOp)
.Doc(R"doc(
Internal VarHandleOp registration used for XLA AOT compilation.
)doc")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.Attr("dtype: type")
.Attr("shape: shape")
.Output("resource: resource")
.SetIsStateful()
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->Scalar());
DataType t;
TF_RETURN_IF_ERROR(c->GetAttr("dtype", &t));
PartialTensorShape p;
TF_RETURN_IF_ERROR(c->GetAttr("shape", &p));
shape_inference::ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s));
c->set_output_handle_shapes_and_types(
0, std::vector<shape_inference::ShapeAndType>{{s, t}});
return Status::OK();
});
REGISTER_XLA_OP(Name(tfcompile::kXlaAotOnlyVarHandleOp).CompilationOnly(),
XlaAotOnlyVarHandleOp);
} // namespace tensorflow

View File

@ -0,0 +1,27 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_
#define TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_
namespace tensorflow {
namespace tfcompile {
static constexpr const char* const kXlaAotOnlyVarHandleOp =
"_XlaAotOnlyVarHandleOp";
} // namespace tfcompile
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_

View File

@ -74,16 +74,16 @@ void DumpStatsToStdout(const Stats& stats) {
const int kBufSize = 1000;
char buf[kBufSize];
snprintf(buf, kBufSize, "Mean with %2.0f%% trimmed:", trim_ratio * 100);
const string label_trimmed(buf);
std::string label_trimmed(buf);
snprintf(buf, kBufSize, "Mean of %2.0f%% best:", best_ratio * 100);
const string label_best(buf);
std::vector<std::pair<string, double>> groups = {
std::string label_best(buf);
std::vector<std::pair<std::string, double>> groups = {
{"Best:", sorted_us.front()},
{"Worst:", sorted_us.back()},
{"Median:", sorted_us[count_us / 2]},
{"Mean:", sum_us / count_us},
{label_trimmed, sum_us_trimmed / count_us_trimmed},
{label_best, sum_us_best / count_us_best},
{std::move(label_trimmed), sum_us_trimmed / count_us_trimmed},
{std::move(label_best), sum_us_best / count_us_best},
};
int max_label_size = 0;
double max_us = 0;
@ -102,7 +102,7 @@ void DumpStatsToStdout(const Stats& stats) {
}
// Dump stats out.
printf("Benchmark ran %zu iterations over %lld us\n", count_us,
stats.total_us);
static_cast<long long>(stats.total_us)); // NOLINT
for (const auto& g : groups) {
printf(" %-*s %*.3f us\n", max_label_size, g.first.c_str(), max_digits + 4,
g.second);
@ -114,7 +114,8 @@ void Benchmark(const Options& options, const BenchmarkFn& fn, Stats* stats) {
const int64 max_us = (options.max_micros <= 0 && options.max_iters <= 0)
? Options::kDefaultMicros
: options.max_micros;
printf("Running benchmark for %lld us\n", max_us);
// NOLINTNEXTLINE
printf("Running benchmark for %lld us\n", static_cast<long long>(max_us));
const int64 start_us = NowMicros();
int64 iters = 0;
while (true) {

View File

@ -423,8 +423,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
GenNameToIndexCode(config.fetch(), opts.gen_name_to_index);
const string include_xla_data_proto =
opts.gen_program_shape
?
R"(#include "tensorflow/compiler/xla/xla_data.pb.h")"
? R"(#include "tensorflow/compiler/xla/xla_data.pb.h")"
: "";
const string include_hlo_profile_printer_data_proto =

View File

@ -20,6 +20,9 @@ limitations under the License.
#include <utility>
#include <vector>
#include "absl/base/call_once.h"
#include "llvm-c/Target.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
@ -90,7 +93,7 @@ Status CompileXla(xla::CompileOnlyClient* client,
} // namespace
Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
const MainFlags& flags, CompileResult* compile_result) {
// Converts the graph into an XLA computation, and compiles the
// computation.
@ -108,8 +111,8 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
if (!flags.mlir_components.empty()) {
return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
}
TF_RETURN_IF_ERROR(
ConvertGraphDefToXla(graph_def, config, client, &computation));
TF_RETURN_IF_ERROR(ConvertGraphDefToXla(std::move(graph_def), config,
client, &computation));
}
if (!flags.out_session_module.empty()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
@ -132,5 +135,96 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
return CompileXla(client, computation, aot_opts, compile_result);
}
static Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
if (absl::EndsWith(fname, ".pbtxt")) {
return ReadTextProto(Env::Default(), fname, proto);
} else {
return ReadBinaryProto(Env::Default(), fname, proto);
}
}
static absl::once_flag targets_init;
static void InitializeTargets() {
// Initialize all LLVM targets so we can cross compile.
#if TF_LLVM_AARCH64_AVAILABLE
LLVMInitializeAArch64Target();
LLVMInitializeAArch64TargetInfo();
LLVMInitializeAArch64TargetMC();
LLVMInitializeAArch64AsmPrinter();
#endif
LLVMInitializeARMTarget();
LLVMInitializeARMTargetInfo();
LLVMInitializeARMTargetMC();
LLVMInitializeARMAsmPrinter();
LLVMInitializePowerPCTarget();
LLVMInitializePowerPCTargetInfo();
LLVMInitializePowerPCTargetMC();
LLVMInitializePowerPCAsmPrinter();
LLVMInitializeX86Target();
LLVMInitializeX86TargetInfo();
LLVMInitializeX86TargetMC();
LLVMInitializeX86AsmPrinter();
}
Status Main(const MainFlags& flags) {
absl::call_once(targets_init, &InitializeTargets);
// Process config.
tf2xla::Config config;
if (flags.config.empty()) {
return errors::InvalidArgument("Must specify --config");
}
TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
TF_RETURN_IF_ERROR(ValidateConfig(config));
if (flags.dump_fetch_nodes) {
std::set<string> nodes;
for (const tf2xla::Fetch& fetch : config.fetch()) {
nodes.insert(fetch.id().node_name());
}
std::cout << absl::StrJoin(nodes, ",");
return Status::OK();
}
// Read and initialize the graph.
if (flags.graph.empty()) {
return errors::InvalidArgument("Must specify --graph");
}
GraphDef graph_def;
TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
CompileResult compile_result;
TF_RETURN_IF_ERROR(
CompileGraph(std::move(graph_def), config, flags, &compile_result));
// Write output files.
Env* env = Env::Default();
const std::vector<char>& obj = compile_result.aot->object_file_data();
TF_RETURN_IF_ERROR(
WriteStringToFile(env, flags.out_function_object,
absl::string_view(obj.data(), obj.size())));
CodegenOpts codegen_opts;
codegen_opts.gen_name_to_index = flags.gen_name_to_index;
codegen_opts.gen_program_shape = flags.gen_program_shape;
codegen_opts.target_triple = flags.target_triple;
if (flags.cpp_class.empty()) {
return errors::InvalidArgument("Must specify --cpp_class");
}
codegen_opts.gen_hlo_profile_printer_data =
xla::GetDebugOptionsFromFlags().xla_hlo_profile();
TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
&codegen_opts.namespaces));
MetadataResult metadata_result;
TF_RETURN_IF_ERROR(
GenerateMetadata(codegen_opts, compile_result, &metadata_result));
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_metadata_object,
metadata_result.object_file_data));
string header;
TF_RETURN_IF_ERROR(GenerateHeader(codegen_opts, config, compile_result,
metadata_result, &header));
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header));
return Status::OK();
}
} // namespace tfcompile
} // namespace tensorflow

View File

@ -42,9 +42,12 @@ struct CompileResult {
// that performs the graph operations.
//
// The XLA compilation options are specified in the flags.
Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
const MainFlags& flags, CompileResult* compile_result);
// The full compilation method, for reuse in a library setting.
Status Main(const MainFlags& flags);
} // namespace tfcompile
} // namespace tensorflow

View File

@ -25,6 +25,7 @@ namespace tensorflow {
namespace tfcompile {
// Flags for the tfcompile binary. See *.cc file for descriptions.
struct MainFlags {
string graph;
string config;

View File

@ -25,6 +25,7 @@ test_suite(
":test_graph_tfmatmulandadd_test",
":test_graph_tfsplits_test",
":test_graph_tftop_k_test",
":test_graph_tfvariable_readonly_test",
":test_graph_tfvariable_sequential_updates_test",
":test_graph_tfvariable_test",
":tfcompile_test",
@ -73,6 +74,7 @@ genrule(
"test_graph_tfsplits.pb",
"test_graph_tftop_k.pb",
"test_graph_tfvariable.pb",
"test_graph_tfvariable_readonly.pb",
"test_graph_tfvariable_sequential_updates.pb",
],
# Set CUDA_VISIBLE_DEVICES='' to prevent the code we launch from using any
@ -238,6 +240,17 @@ tf_library(
],
)
tf_library(
name = "test_graph_tfvariable_readonly",
testonly = 1,
config = "test_graph_tfvariable_readonly.config.pbtxt",
cpp_class = "VariableReadonlyComp",
graph = "test_graph_tfvariable_readonly.pb",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tfvariable_sequential_updates",
testonly = 1,
@ -269,6 +282,7 @@ tf_cc_test(
":test_graph_tfsplits",
":test_graph_tftop_k",
":test_graph_tfvariable",
":test_graph_tfvariable_readonly",
":test_graph_tfvariable_sequential_updates",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@ -323,6 +337,42 @@ tf_library(
],
)
tf_library(
name = "test_graph_tfcond_mlir_bridge",
testonly = 1,
config = "test_graph_tfcond.config.pbtxt",
cpp_class = "CondComp",
graph = "test_graph_tfcond.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tfassert_eq_mlir_bridge",
testonly = 1,
config = "test_graph_tfassert_eq.config.pbtxt",
cpp_class = "AssertComp",
graph = "test_graph_tfassert_eq.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tfgather_mlir_bridge",
testonly = 1,
config = "test_graph_tfgather.config.pbtxt",
cpp_class = "GatherComp",
graph = "test_graph_tfgather.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tfmatmul_mlir_bridge",
testonly = 1,
@ -361,6 +411,66 @@ tf_library(
],
)
tf_library(
name = "test_graph_tfsplits_mlir_bridge",
testonly = 1,
config = "test_graph_tfsplits.config.pbtxt",
cpp_class = "SplitsComp",
graph = "test_graph_tfsplits.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tftop_k_mlir_bridge",
testonly = 1,
config = "test_graph_tftop_k.config.pbtxt",
cpp_class = "TopKComp",
graph = "test_graph_tftop_k.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tfvariable_readonly_mlir_bridge",
testonly = 1,
config = "test_graph_tfvariable_readonly.config.pbtxt",
cpp_class = "VariableReadonlyComp",
graph = "test_graph_tfvariable_readonly.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tfvariable_mlir_bridge",
testonly = 1,
config = "test_graph_tfvariable.config.pbtxt",
cpp_class = "VariableComp",
graph = "test_graph_tfvariable.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tfvariable_sequential_updates_mlir_bridge",
testonly = 1,
config = "test_graph_tfvariable_sequential_updates.config.pbtxt",
cpp_class = "VariableSequentialUpdatesComp",
graph = "test_graph_tfvariable_sequential_updates.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_cc_test(
name = "tfcompile_test_mlir_bridge",
srcs = ["tfcompile_test.cc"],
@ -372,9 +482,17 @@ tf_cc_test(
":test_graph_tfadd_mlir_bridge",
":test_graph_tfadd_with_ckpt_mlir_bridge",
":test_graph_tfadd_with_ckpt_saver_mlir_bridge",
":test_graph_tfassert_eq_mlir_bridge",
":test_graph_tfcond_mlir_bridge",
":test_graph_tfgather_mlir_bridge",
":test_graph_tfmatmul_mlir_bridge",
":test_graph_tfmatmulandadd_mlir_bridge",
":test_graph_tfmatmulandadd_with_profiling_mlir_bridge",
":test_graph_tfsplits_mlir_bridge",
":test_graph_tftop_k_mlir_bridge",
":test_graph_tfvariable_mlir_bridge",
":test_graph_tfvariable_readonly_mlir_bridge",
":test_graph_tfvariable_sequential_updates_mlir_bridge",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:xla_data_proto_cc",

View File

@ -34,6 +34,7 @@ from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables
@ -153,11 +154,21 @@ def tftop_k(_):
array_ops.identity(output[1], name='indices')
def tfvariable(_):
def tfvariable_readonly(_):
x = variables.Variable(1000.0, name='x')
old_x = x.value()
with ops.control_dependencies([old_x]):
new_x = x.assign_add(42.0)
new_value = math_ops.add(old_x, 42.0)
array_ops.identity(new_value, name='result')
# TODO(b/147908587): Change x and the two constants back to have a scalar shape
# when the bug is fixed.
def tfvariable(_):
x = variables.Variable([1000.0], name='x', shape=[1])
old_x = x.value()
with ops.control_dependencies([old_x]):
new_x = x.assign_add([42.0])
array_ops.stack([old_x, new_x], name='result')
@ -184,6 +195,7 @@ def write_graph(build_graph, out_dir):
def main(_):
control_flow_util.enable_control_flow_v2()
write_graph(tfadd, FLAGS.out_dir)
write_graph(tfadd_with_ckpt, FLAGS.out_dir)
write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir)
@ -196,6 +208,7 @@ def main(_):
write_graph(tfsplits, FLAGS.out_dir)
write_graph(tftop_k, FLAGS.out_dir)
write_graph(tfvariable, FLAGS.out_dir)
write_graph(tfvariable_readonly, FLAGS.out_dir)
write_graph(tfvariable_sequential_updates, FLAGS.out_dir)

View File

@ -0,0 +1,12 @@
# Text form of tensorflow.tf2xla.Config proto.
fetch {
id { node_name: "result" }
}
variable {
node_name: "x"
shape {
}
type: DT_FLOAT
readonly: true
}

View File

@ -30,9 +30,17 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfcond_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfgather_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tftop_k_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates_mlir_bridge.h"
#else
#include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
@ -47,6 +55,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h"
#include "tensorflow/compiler/aot/tests/test_graph_tftop_k.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates.h"
#endif
@ -167,8 +176,6 @@ TEST(TFCompileTest, AddWithCkptSaver) {
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
}
// TODO(bixia): the following tests failed with MLIR bridge.
#if !defined(ENABLE_MLIR_BRIDGE_TEST)
TEST(TFCompileTest, Cond) {
CondComp cond;
EXPECT_EQ(cond.arg0_data(), cond.arg_data(0));
@ -233,7 +240,6 @@ TEST(TFCompileTest, Gather) {
EXPECT_EQ(gather_const.result0_data(), gather.results()[0]);
}
}
#endif
TEST(TFCompileTest, MatMul2) {
Eigen::ThreadPool tp(2);
@ -439,6 +445,7 @@ TEST(TFCompileTest, Function) {
EXPECT_EQ(add_fn.result0_data()[0], 3);
EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]);
}
#endif
TEST(TFCompileTest, Splits) {
Eigen::ThreadPool tp(1);
@ -492,6 +499,20 @@ TEST(TFCompileTest, TopK) {
EXPECT_EQ(expected_indices[1], fn.result1(1));
}
TEST(TFCompileTest, VariableReadonly) {
Eigen::ThreadPool tp(1);
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
VariableReadonlyComp fn;
float x = 23;
fn.set_var_x_data(&x);
fn.set_thread_pool(&device);
fn.Run();
EXPECT_EQ(fn.result0(), 65);
EXPECT_EQ(fn.var_x(), 23);
}
TEST(TFCompileTest, Variable) {
Eigen::ThreadPool tp(1);
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
@ -665,6 +686,11 @@ TEST(TFCompileTest, HloProfiling) {
/*clock_rate_ghz=*/1.0);
VLOG(1) << "Original HLO profile string:\n" << hlo_profile_as_string;
// Replace Arg_n with argn when the MLIR bridge is used.
#if defined(ENABLE_MLIR_BRIDGE_TEST)
RE2::GlobalReplace(&hlo_profile_as_string, "(Arg_)([0-9].)", "arg\\2");
#endif
// Strip away identifier details from the profile string to avoid this test
// being a change detector for xla internals. Identifiers such as '%dot.0.7'
// just become '%dot'.
@ -690,7 +716,6 @@ TEST(TFCompileTest, HloProfiling) {
IsSupersetOf({header, total_cycles_profile_line, dot_profile_line,
add_profile_line, tuple_profile_line}));
}
#endif
} // namespace
} // namespace tfcompile

View File

@ -407,6 +407,7 @@ def target_llvm_triple():
"//tensorflow:android_arm64": "aarch64-none-android",
"//tensorflow:android_x86": "i686-none-android",
"//tensorflow:ios": "arm64-none-ios",
"//tensorflow:ios_x86_64": "x86_64-apple-ios",
"//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu",
"//tensorflow:macos": "x86_64-none-darwin",
"//conditions:default": "x86_64-pc-linux",

View File

@ -21,7 +21,6 @@ limitations under the License.
#include "absl/strings/match.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "llvm-c/Target.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/compile.h"
#include "tensorflow/compiler/aot/flags.h"
@ -56,88 +55,6 @@ const char kUsageHeader[] =
"--cpp_class=\"mynamespace::MyComputation\"\n"
"\n";
Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
if (absl::EndsWith(fname, ".pbtxt")) {
return ReadTextProto(Env::Default(), fname, proto);
} else {
return ReadBinaryProto(Env::Default(), fname, proto);
}
}
Status Main(const MainFlags& flags) {
// Initialize all LLVM targets so we can cross compile.
LLVMInitializeAArch64Target();
LLVMInitializeAArch64TargetInfo();
LLVMInitializeAArch64TargetMC();
LLVMInitializeAArch64AsmPrinter();
LLVMInitializeARMTarget();
LLVMInitializeARMTargetInfo();
LLVMInitializeARMTargetMC();
LLVMInitializeARMAsmPrinter();
LLVMInitializePowerPCTarget();
LLVMInitializePowerPCTargetInfo();
LLVMInitializePowerPCTargetMC();
LLVMInitializePowerPCAsmPrinter();
LLVMInitializeX86Target();
LLVMInitializeX86TargetInfo();
LLVMInitializeX86TargetMC();
LLVMInitializeX86AsmPrinter();
// Process config.
tf2xla::Config config;
if (flags.config.empty()) {
return errors::InvalidArgument("Must specify --config");
}
TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
TF_RETURN_IF_ERROR(ValidateConfig(config));
if (flags.dump_fetch_nodes) {
std::set<string> nodes;
for (const tf2xla::Fetch& fetch : config.fetch()) {
nodes.insert(fetch.id().node_name());
}
std::cout << absl::StrJoin(nodes, ",");
return Status::OK();
}
// Read and initialize the graph.
if (flags.graph.empty()) {
return errors::InvalidArgument("Must specify --graph");
}
GraphDef graph_def;
TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
CompileResult compile_result;
TF_RETURN_IF_ERROR(CompileGraph(graph_def, config, flags, &compile_result));
// Write output files.
Env* env = Env::Default();
const std::vector<char>& obj = compile_result.aot->object_file_data();
TF_RETURN_IF_ERROR(
WriteStringToFile(env, flags.out_function_object,
absl::string_view(obj.data(), obj.size())));
CodegenOpts codegen_opts;
codegen_opts.gen_name_to_index = flags.gen_name_to_index;
codegen_opts.gen_program_shape = flags.gen_program_shape;
codegen_opts.target_triple = flags.target_triple;
if (flags.cpp_class.empty()) {
return errors::InvalidArgument("Must specify --cpp_class");
}
codegen_opts.gen_hlo_profile_printer_data =
xla::GetDebugOptionsFromFlags().xla_hlo_profile();
TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
&codegen_opts.namespaces));
MetadataResult metadata_result;
TF_RETURN_IF_ERROR(
GenerateMetadata(codegen_opts, compile_result, &metadata_result));
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_metadata_object,
metadata_result.object_file_data));
string header;
TF_RETURN_IF_ERROR(GenerateHeader(codegen_opts, config, compile_result,
metadata_result, &header));
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header));
return Status::OK();
}
} // end namespace tfcompile
} // end namespace tensorflow

View File

@ -2,14 +2,10 @@ load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "tf_cc_
load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library", "tf_jit_compilation_passes_extra_deps")
load("//tensorflow/core/platform:build_config.bzl", "tf_additional_all_protos", "tf_proto_library")
load("//tensorflow/core/platform:build_config_root.bzl", "tf_cuda_tests_tags")
package(
default_visibility = [
":internal",
# BEGIN-GOOGLE-INTERNAL
"//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__",
# END-GOOGLE-INTERNAL
],
default_visibility = [":internal"],
licenses = ["notice"], # Apache 2.0
)
@ -61,6 +57,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":jit_compilation_passes",
":xla_kernel_creator", # buildcleaner: keep
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
@ -74,6 +71,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = if_cuda_or_rocm([
":jit_compilation_passes",
":xla_kernel_creator", # buildcleaner: keep
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
@ -82,19 +80,6 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "xla_mlir_gpu_jit",
visibility = ["//visibility:public"],
deps = if_cuda_or_rocm([
":jit_compilation_passes",
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/xla/service:mlir_gpu_plugin",
]),
alwayslink = 1,
)
cc_library(
name = "xla_cpu_device",
srcs = ["xla_cpu_device.cc"],
@ -120,6 +105,7 @@ cc_library(
srcs = ["xla_gpu_device.cc"],
visibility = [":friends"],
deps = [
":flags",
":jit_compilation_passes",
":xla_device",
":xla_kernel_creator", # buildcleaner: keep
@ -128,6 +114,7 @@ cc_library(
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:gpu_init",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
@ -172,7 +159,9 @@ XLA_DEVICE_DEPS = [
":common",
":xla_launch_util",
":xla_tensor",
"@com_google_absl//absl/base",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:optional",
"//tensorflow/compiler/jit/ops:xla_ops",
@ -265,13 +254,26 @@ cc_library(
}),
)
# Internal targets below this point.
cc_library(
name = "flags",
srcs = ["flags.cc"],
hdrs = ["flags.h"],
visibility = [":friends"],
deps = [
"//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"@com_google_absl//absl/base",
"@com_google_absl//absl/strings",
],
)
# Header-only version of "flags" library, for linking from the shared object
# without ODR violations.
cc_library(
name = "flags_headers_only",
hdrs = ["flags.h"],
visibility = [":friends"],
deps = [
"//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal",
@ -291,6 +293,8 @@ cc_library(
visibility = [":friends"],
)
# Internal targets below this point.
cc_library(
name = "xla_launch_util",
srcs = ["xla_launch_util.cc"],
@ -412,6 +416,7 @@ cc_library(
"xla_kernel_creator.h",
],
deps = [
":flags",
":jit_compilation_passes",
":xla_kernel_creator_util",
"//tensorflow/core:core_cpu_internal",
@ -500,6 +505,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
@ -639,6 +645,7 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
@ -770,7 +777,7 @@ tf_cc_test(
],
# TODO(b/141643254) Re-enable msan after fixing use-of-uninitialized-value
# error.
tags = ["nomsan"],
tags = ["nomsan"] + tf_cuda_tests_tags(),
deps = [
":common",
":compilation_passes",

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/hash/hash.h"
@ -1583,7 +1584,6 @@ DeadnessAnalysis::~DeadnessAnalysis() {}
absl::flat_hash_map<TensorId, string, TensorId::Hasher>
DeadnessAnalysisImpl::PredicateMapAsString() const {
absl::flat_hash_map<TensorId, string, TensorId::Hasher> result;
std::vector<TensorId> tensor_ids;
for (const auto& kv_pair : predicate_map_) {
CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second);
}

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"

View File

@ -374,39 +374,6 @@ xla::StatusOr<NodeDef> BuildXlaHostComputeNodeDef(
return new_def;
}
TF_ATTRIBUTE_NOINLINE Status
ValidateOutsideCompilationCallNode(Node* call_node) {
// DT_INT64 as input/output for outside compilation is not supported yet:
// b/120809951.
for (const Edge* e : call_node->in_edges()) {
if (e->IsControlEdge()) {
continue;
}
DataType dtype = e->src()->output_type(e->src_output());
if (dtype == DT_INT64) {
return errors::Unimplemented(
"int64 input for outside compilation is not supported yet: "
"b/120809951. Please cast output of node ",
e->src()->DebugString(),
" to int32 before feeding it into outside compilation.");
}
}
for (const Edge* e : call_node->out_edges()) {
if (e->IsControlEdge()) {
continue;
}
DataType dtype = e->dst()->input_type(e->dst_input());
if (dtype == DT_INT64) {
return errors::Unimplemented(
"int64 output for outside compilation is not supported yet: "
"b/120809951. Please cast input of node ",
e->dst()->DebugString(),
" to int32 before returning it from outside compilation.");
}
}
return Status::OK();
}
// Replace outside compilation function call node with XlaHostCompute node.
TF_ATTRIBUTE_NOINLINE xla::StatusOr<Node*> ReplaceOutsideCompilationCallNode(
Graph* g, Node* call_node, const std::map<string, int>& host_compute_core,
@ -2130,6 +2097,53 @@ Status ExtractOutsideCompilationForNodesWithAssociatedFunctions(
return Status::OK();
}
Status CopyOutsideCompilationConstNodes(
Graph* g, const string& outside_compilation_attr_name) {
for (Node* n : g->op_nodes()) {
if (!n->IsConstant() ||
!HasNodeAttr(n->def(), outside_compilation_attr_name)) {
continue;
}
std::vector<const Edge*> out_edges(n->out_edges().begin(),
n->out_edges().end());
bool has_non_oc_output = false;
for (const Edge* e : out_edges) {
if (!e->IsControlEdge() &&
!HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) {
has_non_oc_output = true;
break;
}
}
if (!has_non_oc_output) {
continue;
}
NodeDef copy_def = n->def();
copy_def.set_name(g->NewName(n->name()));
copy_def.mutable_attr()->erase(outside_compilation_attr_name);
Status s;
Node* copy_node = g->AddNode(copy_def, &s);
TF_RETURN_IF_ERROR(s);
for (const Edge* e : n->in_edges()) {
if (e->IsControlEdge()) {
g->AddControlEdge(e->src(), copy_node);
}
}
for (const Edge* e : out_edges) {
if (!e->IsControlEdge() &&
!HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) {
Node* dst = e->dst();
int dst_input = e->dst_input();
g->RemoveEdge(e);
g->AddEdge(copy_node, 0, dst, dst_input);
}
}
}
return Status::OK();
}
} // namespace
Status RewriteOutsideCompilationSubgraphFn::operator()(
@ -2279,6 +2293,10 @@ Status ExtractOutsideCompilationForFunction(
std::vector<string> outside_compilation_host_graphs;
std::vector<string> shape_inference_graphs_to_rewrite;
if (*has_outside_compilation) {
// Copy outside compilation Const nodes with non outside compilation users.
TF_RETURN_IF_ERROR(CopyOutsideCompilationConstNodes(
fbody->graph, outside_compilation_attr_name));
// Find dependencies between outside compilation clusters.
TF_ASSIGN_OR_RETURN(auto cluster_deps,
OutsideCompilationClusterDependencies(
@ -2333,7 +2351,6 @@ Status ExtractOutsideCompilationForFunction(
}
std::map<string, Node*> host_compute_nodes;
for (Node* n : outside_compilation_nodes) {
TF_RETURN_IF_ERROR(ValidateOutsideCompilationCallNode(n));
auto host_compute_node_or = ReplaceOutsideCompilationCallNode(
graph_out.get(), n, host_compute_core, *cluster_deps);
TF_RETURN_IF_ERROR(host_compute_node_or.status());

View File

@ -13,13 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/flags.h"
#include <mutex> // NOLINT
#include "absl/base/call_once.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "absl/strings/strip.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
@ -32,7 +35,7 @@ XlaOpsCommonFlags* ops_flags;
IntroduceFloatingPointJitterPassFlags* jitter_flags;
std::vector<Flag>* flag_list;
std::once_flag flags_init;
absl::once_flag flags_init;
bool SetterForXlaAutoJitFlag(const string& value) {
int32 opt_level;
@ -155,6 +158,7 @@ void AllocateAndParseFlags() {
device_flags = new XlaDeviceFlags;
device_flags->tf_xla_compile_on_demand = false;
device_flags->tf_xla_enable_xla_devices = true;
ops_flags = new XlaOpsCommonFlags;
ops_flags->tf_xla_always_defer_compilation = false;
@ -187,6 +191,12 @@ void AllocateAndParseFlags() {
"Switch a device into 'on-demand' mode, where instead of "
"autoclustering ops are compiled one by one just-in-time."),
Flag("tf_xla_enable_xla_devices",
&device_flags->tf_xla_enable_xla_devices,
"Generate XLA_* devices, where placing a computation on such a "
"device"
"forces compilation by XLA. Deprecated."),
Flag("tf_xla_always_defer_compilation",
&ops_flags->tf_xla_always_defer_compilation, ""),
@ -206,38 +216,45 @@ void AllocateAndParseFlags() {
} // namespace
bool SetXlaAutoJitFlagFromFlagString(const string& value) {
std::call_once(flags_init, &AllocateAndParseFlags);
absl::call_once(flags_init, &AllocateAndParseFlags);
return SetterForXlaAutoJitFlag(value);
}
BuildXlaOpsPassFlags* GetBuildXlaOpsPassFlags() {
std::call_once(flags_init, &AllocateAndParseFlags);
absl::call_once(flags_init, &AllocateAndParseFlags);
return build_ops_flags;
}
MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() {
std::call_once(flags_init, &AllocateAndParseFlags);
absl::call_once(flags_init, &AllocateAndParseFlags);
return mark_for_compilation_flags;
}
XlaDeviceFlags* GetXlaDeviceFlags() {
std::call_once(flags_init, &AllocateAndParseFlags);
absl::call_once(flags_init, &AllocateAndParseFlags);
return device_flags;
}
const XlaOpsCommonFlags& GetXlaOpsCommonFlags() {
std::call_once(flags_init, &AllocateAndParseFlags);
absl::call_once(flags_init, &AllocateAndParseFlags);
return *ops_flags;
}
const IntroduceFloatingPointJitterPassFlags&
GetIntroduceFloatingPointJitterPassFlags() {
std::call_once(flags_init, &AllocateAndParseFlags);
absl::call_once(flags_init, &AllocateAndParseFlags);
return *jitter_flags;
}
void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
std::call_once(flags_init, &AllocateAndParseFlags);
absl::call_once(flags_init, &AllocateAndParseFlags);
AppendMarkForCompilationPassFlagsInternal(flag_list);
}
static bool xla_is_enabled = false;
void SetXlaIsEnabled() { xla_is_enabled = true; }
bool IsXlaEnabled() { return xla_is_enabled; }
} // namespace tensorflow

View File

@ -87,6 +87,9 @@ struct XlaDeviceFlags {
// Enabling this mode by a legacy flag is a temporary mechanism. When this
// feature is battle-tested, we will switch this to be a session option.
bool tf_xla_compile_on_demand;
// Enables "XLA" devices if this flag is set.
bool tf_xla_enable_xla_devices;
};
// Flags common to the _Xla* ops and their kernels.
@ -151,6 +154,15 @@ GetIntroduceFloatingPointJitterPassFlags();
// Has the side-effect of parsing TF_XLA_FLAGS if that hasn't happened yet.
void AppendMarkForCompilationPassFlags(
std::vector<tensorflow::Flag>* flag_list);
// Makes all future calls to `IsXlaEnabled()` return `true`.
//
// Should only be called when XLA is linked in.
void SetXlaIsEnabled();
// Returns whether XLA is enabled.
bool IsXlaEnabled();
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_FLAGS_H_

View File

@ -21,6 +21,7 @@ limitations under the License.
#include <unordered_map>
#include <unordered_set>
#include "absl/base/call_once.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
@ -1616,8 +1617,8 @@ StatusOr<bool> MarkForCompilationPassImpl::ShouldCompileClusterImpl(
if (!should_compile && global_jit_level_ != OptimizerOptions::OFF &&
device_type.type_string() == DEVICE_CPU) {
static std::once_flag once;
std::call_once(once, [] {
static absl::once_flag once;
absl::call_once(once, [] {
LOG(WARNING)
<< "(One-time warning): Not using XLA:CPU for cluster because envvar "
"TF_XLA_FLAGS=--tf_xla_cpu_global_jit was not set. If you want "
@ -1776,9 +1777,9 @@ absl::flat_hash_map<string, std::vector<string>>* GetWhitelistTable() {
"Lgamma", "Digamma",
// Binary
"Add", "AddV2", "Sub", "Mul", "Div", "Atan2", "Complex", "DivNoNan",
"MulNoNan", "FloorDiv", "Xlogy", "Xdivy", "FloorMod", "BitwiseAnd",
"BitwiseOr", "BitwiseXor", "LeftShift", "RightShift", "LogicalAnd",
"LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
"MulNoNan", "FloorDiv", "Xlogy", "Xlog1py", "Xdivy", "FloorMod",
"BitwiseAnd", "BitwiseOr", "BitwiseXor", "LeftShift", "RightShift",
"LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
"ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "TruncateDiv",
"TruncateMod", "Equal", "NotEqual", "Greater", "GreaterEqual",
"Less", "LessEqual", "SigmoidGrad", "SoftplusGrad", "SoftsignGrad",
@ -1872,6 +1873,8 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"Einsum",
"EmptyTensorList",
"ExtractImagePatches",
"Igamma",
"Igammac",
"FFT",
"FFT2D",
"FFT3D",

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/node_matchers.h"
#include <utility>
#include "absl/algorithm/container.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
@ -24,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/graph_node_util.h"
namespace tensorflow {
namespace testing {

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/version.h"

View File

@ -17,7 +17,10 @@ limitations under the License.
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/util/dump_graph.h"
@ -39,7 +42,7 @@ Status ShapeHandleToTensorShape(shape_inference::InferenceContext* context,
return PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape);
}
Status PropagateShapes(const Graph& graph,
Status PropagateShapes(Graph* graph,
const std::map<int, InferredShape>& arg_shapes,
const std::vector<BackEdgeHelper::BackEdge>& back_edges,
ShapeRefiner* shape_refiner) {
@ -54,7 +57,7 @@ Status PropagateShapes(const Graph& graph,
// shapes.
// TODO(phawkins): handle cyclic graphs.
std::vector<Node*> order;
GetReversePostOrder(graph, &order);
GetReversePostOrder(*graph, &order);
for (Node* n : order) {
// Ignore the status returned by the shape_refiner. We want the best effort
@ -99,6 +102,67 @@ Status PropagateShapes(const Graph& graph,
}
}
// Sometimes we have VariableShape nodes in while loop (after Enter nodes).
// They won't be constant-folded because TensorFlow constant folding does
// not handle Enter nodes (and thus does not handle any nodes after Enter
// nodes). We try to replace such VariableShape nodes with Const nodes here.
if (n->type_string() == "VariableShape") {
shape_inference::InferenceContext* context = shape_refiner->GetContext(n);
auto handle_shapes_and_types = context->input_handle_shapes_and_types(0);
if (handle_shapes_and_types && !handle_shapes_and_types->empty()) {
shape_inference::ShapeHandle handle =
handle_shapes_and_types->at(0).shape;
TensorShapeProto shape_proto;
context->ShapeHandleToProto(handle, &shape_proto);
if (!shape_proto.unknown_rank()) {
NodeDef const_def;
const_def.set_op("Const");
Node* var_node;
TF_RETURN_IF_ERROR(n->input_node(0, &var_node));
const_def.set_name(
graph->NewName(absl::StrCat("var_shape_", var_node->name())));
DataType dtype = n->output_type(0);
AddNodeAttr("dtype", dtype, &const_def);
TensorProto value;
value.set_dtype(dtype);
value.mutable_tensor_shape()->add_dim()->set_size(
shape_proto.dim_size());
for (const auto& dim : shape_proto.dim()) {
if (dtype == DT_INT32) {
value.add_int_val(dim.size());
} else {
value.add_int64_val(dim.size());
}
}
AddNodeAttr("value", value, &const_def);
for (auto const& attr : n->attrs()) {
if (*attr.first.begin() == '_') {
AddNodeAttr(attr.first, attr.second, &const_def);
}
}
Status s;
Node* const_node = graph->AddNode(const_def, &s);
TF_RETURN_IF_ERROR(s);
graph->AddControlEdge(var_node, const_node);
std::vector<const Edge*> out_edges(n->out_edges().begin(),
n->out_edges().end());
for (const Edge* e : out_edges) {
if (e->IsControlEdge()) {
graph->AddControlEdge(const_node, e->dst());
graph->RemoveEdge(e);
} else {
Node* dst = e->dst();
int dst_input = e->dst_input();
graph->RemoveEdge(e);
graph->AddEdge(const_node, 0, dst, dst_input);
}
}
}
}
}
// Merge node causes a loop so we remove NextIteration->Merge edge before
// performing shape inference. But removing those edges also prevents us
// from inferring output shape for Merge node (we need shapes for all its
@ -196,7 +260,7 @@ Status InferShapes(Graph* graph, const std::map<int, InferredShape>& arg_shapes,
// the shape inference is complete.
BackEdgeHelper back_edge;
TF_RETURN_IF_ERROR(back_edge.Remove(graph));
TF_RETURN_IF_ERROR(PropagateShapes(*graph, arg_shapes,
TF_RETURN_IF_ERROR(PropagateShapes(graph, arg_shapes,
back_edge.RemovedEdges(), &shape_refiner));
TF_RETURN_IF_ERROR(back_edge.Replace());

View File

@ -163,12 +163,11 @@ Status XlaCompilationCache::BuildExecutable(
build_options.set_device_allocator(options.device_allocator);
build_options.set_alias_passthrough_params(options.alias_passthrough_params);
auto compile_result =
client_->Compile(*result.computation, argument_layouts, build_options);
if (!compile_result.ok()) {
return compile_result.status();
}
*executable = std::move(compile_result.ValueOrDie());
TF_ASSIGN_OR_RETURN(
auto executables,
client_->Compile(*result.computation, argument_layouts, build_options));
TF_RET_CHECK(executables.size() == 1);
*executable = std::move(executables[0]);
return Status::OK();
}

View File

@ -36,8 +36,13 @@ class XlaCpuDeviceFactory : public DeviceFactory {
};
Status XlaCpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
devices->push_back(absl::StrCat("/physical_device:", DEVICE_XLA_CPU, ":0"));
XlaDeviceFlags* flags = GetXlaDeviceFlags();
if (!flags->tf_xla_enable_xla_devices) {
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
return Status::OK();
}
devices->push_back(absl::StrCat("/physical_device:", DEVICE_XLA_CPU, ":0"));
return Status::OK();
}
@ -45,6 +50,10 @@ Status XlaCpuDeviceFactory::CreateDevices(
const SessionOptions& session_options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) {
XlaDeviceFlags* flags = GetXlaDeviceFlags();
if (!flags->tf_xla_enable_xla_devices) {
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
return Status::OK();
}
bool compile_on_demand = flags->tf_xla_compile_on_demand;
XlaOpRegistry::DeviceRegistration registration;

View File

@ -20,7 +20,9 @@ limitations under the License.
#include <unordered_set>
#include <utility>
#include "absl/base/call_once.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
#include "tensorflow/compiler/jit/xla_device_context.h"
@ -386,14 +388,33 @@ Status XlaDevice::TryGetDeviceContext(DeviceContext** out_context) {
return Status::OK();
}
// Warn about XLA_CPU/XLA_GPU exactly once.
static void ShowXlaDeviceDeprecationWarning(
absl::string_view compilation_device_name) {
static absl::once_flag once;
if (absl::StrContains(compilation_device_name, "CPU") ||
absl::StrContains(compilation_device_name, "GPU")) {
absl::call_once(once, [] {
LOG(WARNING)
<< "XLA_GPU and XLA_CPU devices are deprecated and will be "
"removed in subsequent releases. Instead, use either "
"@tf.function(experimental_compile=True) for must-compile "
"semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 "
"for auto-clustering best-effort compilation.";
});
}
}
void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":"
<< op_kernel->type_string();
ShowXlaDeviceDeprecationWarning(jit_device_name_.type_string());
op_kernel->Compute(context);
}
void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
AsyncOpKernel::DoneCallback done) {
ShowXlaDeviceDeprecationWarning(jit_device_name_.type_string());
VLOG(2) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":"
<< op_kernel->type_string();
op_kernel->ComputeAsync(context, done);

View File

@ -140,7 +140,6 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
// The device tensor should always be fresh.
TF_RET_CHECK(!xla_tensor->has_shaped_buffer());
xla_tensor->set_host_tensor(*cpu_tensor);
TF_RETURN_IF_ERROR(
xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
stream_->parent()->device_ordinal()));

View File

@ -14,17 +14,20 @@ limitations under the License.
==============================================================================*/
// Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
// operators using XLA via the XLA "CUDA" (GPU) backend.
// operators using XLA via the XLA "CUDA" or "ROCM" (GPU) backend.
#include <set>
#include "absl/memory/memory.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
@ -61,7 +64,14 @@ class XlaGpuDeviceFactory : public DeviceFactory {
};
Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
auto platform = se::MultiPlatformManager::PlatformWithName("CUDA");
XlaDeviceFlags* flags = GetXlaDeviceFlags();
if (!flags->tf_xla_enable_xla_devices) {
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
return Status::OK();
}
auto platform =
se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
if (!platform.ok()) {
// Treat failures as non-fatal; there might not be a GPU in the machine.
VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
@ -84,6 +94,12 @@ Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
Status XlaGpuDeviceFactory::CreateDevices(
const SessionOptions& session_options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) {
XlaDeviceFlags* flags = GetXlaDeviceFlags();
if (!flags->tf_xla_enable_xla_devices) {
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
return Status::OK();
}
XlaOpRegistry::DeviceRegistration registration;
registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
registration.autoclustering_policy =
@ -103,7 +119,8 @@ Status XlaGpuDeviceFactory::CreateDevices(
RegisterXlaDeviceKernels(DEVICE_XLA_GPU, DEVICE_GPU_XLA_JIT);
(void)registrations;
auto platform = se::MultiPlatformManager::PlatformWithName("CUDA");
auto platform =
se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
if (!platform.ok()) {
// Treat failures as non-fatal; there might not be a GPU in the machine.
VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_kernel_creator.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_kernel_creator_util.h"
#include "tensorflow/core/common_runtime/function.h"
@ -39,6 +40,10 @@ bool RegisterLaunchOpCreator() {
}
static bool register_me = RegisterLaunchOpCreator();
static bool register_xla = [] {
SetXlaIsEnabled();
return true;
}();
} // end namespace
} // namespace tensorflow

View File

@ -222,8 +222,9 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
OpKernelConstruction construction(
DeviceType(dev->device_type()), dev,
dev->GetAllocator(AllocatorAttributes()), &node_def,
&fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types,
fbody->ret_types, output_memory_types, flr->graph_def_version(), &s);
&fbody->fdef.signature(), flr, dev->resource_manager(), fbody->arg_types,
input_memory_types, fbody->ret_types, output_memory_types,
flr->graph_def_version(), &s);
*kernel = absl::make_unique<XlaLocalLaunchBase>(
&construction, constant_arg_indices, resource_arg_indices, function);

View File

@ -44,8 +44,11 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"@llvm-project//llvm:support",
"@llvm-project//mlir:AffineDialectRegistration",
"@llvm-project//mlir:LoopDialectRegistration",
"@llvm-project//mlir:MlirOptLib",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOpsDialectRegistration",
"@llvm-project//mlir:Support",
"@llvm-project//mlir/test:TestTransforms",
],
@ -63,6 +66,8 @@ cc_library(
"//tensorflow/compiler/mlir/lite:tensorflow_lite_optimize",
"//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
"//tensorflow/compiler/mlir/lite/quantization:quantization_passes",
"//tensorflow/compiler/mlir/lite/quantization/tensorflow:tf_to_quant",
"//tensorflow/compiler/mlir/lite/quantization/xla:hlo_xla_quantization_passes",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
@ -74,15 +79,16 @@ cc_library(
"//tensorflow/compiler/mlir/xla:lhlo_fuse_linalg",
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine",
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_gpu",
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_linalg",
"//tensorflow/compiler/mlir/xla:xla_dialect_registration",
"//tensorflow/compiler/mlir/xla:xla_legalize_control_flow",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
"//tensorflow/compiler/mlir/xla:xla_legalize_to_linalg",
"//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
"//tensorflow/compiler/mlir/xla:xla_lower",
"@llvm-project//mlir:AffineDialectRegistration",
"//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts",
"//tensorflow/compiler/mlir/xla:xla_test_passes",
"@llvm-project//mlir:AffineOps",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:QuantOpsDialectRegistration",
],
)

View File

@ -26,9 +26,11 @@ package_group(
filegroup(
name = "tensorflow_lite_ops_td_files",
srcs = [
"ir/tfl_op_interfaces.td",
"ir/tfl_ops.td",
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:include/mlir/Transforms/LoopLikeInterface.td",
],
)
@ -55,6 +57,25 @@ gentbl(
],
)
gentbl(
name = "tensorflow_lite_op_interfaces_inc_gen",
tbl_outs = [
(
"-gen-op-interface-decls",
"ir/tfl_ops_interface.h.inc",
),
(
"-gen-op-interface-defs",
"ir/tfl_ops_interface.cc.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "ir/tfl_op_interfaces.td",
td_srcs = [
":tensorflow_lite_ops_td_files",
],
)
gentbl(
name = "tensorflow_lite_prepare_tf_inc_gen",
tbl_outs = [
@ -177,11 +198,12 @@ cc_library(
"ir/tfl_ops.cc",
"ir/tfl_ops.cc.inc",
"ir/tfl_ops.h.inc",
"ir/tfl_ops_interface.cc.inc",
"ir/tfl_ops_interface.h.inc",
"utils/attribute_utils.cc",
],
hdrs = [
"ir/tfl_ops.h",
"ir/tfl_traits.h",
"transforms/passes.h",
"utils/attribute_utils.h",
"//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h",
@ -190,8 +212,6 @@ cc_library(
deps = [
":tensorflow_lite_ops_inc_gen",
":validators",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/lite/schema:schema_fbs",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:Dialect",
@ -200,6 +220,10 @@ cc_library(
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
# TODO(jpienaar): Move this out after splitting out LoopLikeOpInterface.
"@llvm-project//mlir:Transforms",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/lite/schema:schema_fbs",
],
alwayslink = 1,
)
@ -258,6 +282,7 @@ tf_cc_test(
cc_library(
name = "tensorflow_lite_legalize_tf",
srcs = [
"transforms/dilated_conv.cc",
"transforms/extract_ophint.cc",
"transforms/generated_legalize_tf.inc",
"transforms/generated_lower_static_tensor_list.inc",
@ -273,6 +298,7 @@ cc_library(
"transforms/unroll_batch_matmul.cc",
],
hdrs = [
"transforms/dilated_conv.h",
"transforms/passes.h",
"transforms/unroll_batch_matmul.h",
],
@ -284,13 +310,16 @@ cc_library(
":validators",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_tensor",
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:tensor_list",
"//tensorflow/core/platform:logging",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
@ -316,6 +345,7 @@ cc_library(
deps = [
":tensorflow_lite",
":validators",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
@ -330,6 +360,7 @@ cc_library(
cc_library(
name = "tensorflow_lite_quantize",
srcs = [
"transforms/default_quant_params.cc",
"transforms/generated_post_quantize.inc",
"transforms/generated_quantize.inc",
"transforms/load_quantization_recipe.cc",
@ -346,6 +377,7 @@ cc_library(
":validators",
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/lite/quantization/lite:tfl_to_std",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
@ -370,6 +402,8 @@ genrule(
name = "op_quant_spec_getters_inc",
srcs = [
"ir/tfl_ops.td",
"ir/tfl_op_interfaces.td",
"@llvm-project//mlir:include/mlir/Transforms/LoopLikeInterface.td",
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
],
outs = [
@ -436,8 +470,13 @@ cc_library(
deps = [
":tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:status",
"//tensorflow/lite/kernels/internal:kernel_utils",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@flatbuffers",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
@ -501,6 +540,7 @@ cc_library(
"//tensorflow/lite:schema_fbs_version",
"//tensorflow/lite:string_util",
"//tensorflow/lite/delegates/flex:whitelisted_flex_ops_lib",
"//tensorflow/lite/kernels/internal:kernel_utils",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/tools/versioning:op_version",
"@com_google_absl//absl/base",
@ -666,12 +706,16 @@ cc_library(
],
)
exports_files(
["transforms/passes.h"],
cc_library(
name = "empty_passes",
hdrs = ["transforms/passes.h"],
visibility = [
"//configs/devtools/hawkeye/tflite:__subpackages__",
"//learning/brain/models/app_benchmarks:__subpackages__",
"//tensorflow/compiler/mlir/lite:friends",
"//tensorflow/lite/experimental/mlir:__subpackages__",
],
deps = [
"@llvm-project//llvm:support",
],
)

View File

@ -31,10 +31,11 @@ struct PassConfig {
: emit_builtin_tflite_ops(true),
lower_tensor_list_ops(false),
trim_functions_whitelist({}),
quant_specs(specs),
quant_specs(std::move(specs)),
skip_control_dialect(false),
form_clusters(false),
inline_functions(false) {}
inline_functions(true),
unfold_batch_matmul(true) {}
// If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
// added, which produces TF Lite ops.
@ -57,6 +58,9 @@ struct PassConfig {
// Inline function calls within the main function in the MLIR module, prior
// to legalization to TFLite.
bool inline_functions;
// if `unfold_batch_matmul` is true, the tf.BatchMatMul is unfolded to a set
// of tfl.fully_connected ops.
bool unfold_batch_matmul;
};
} // namespace TFL

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <algorithm>
#include <cctype>
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
@ -103,12 +104,26 @@ using llvm::cl::opt;
// Commandline flag to enable the control of flatbuffer import.
bool use_external_constant;
// Commandline flag to enable graph pruning.
bool experimental_prune_unreachable_nodes_unconditionally;
// NOLINTNEXTLINE
static opt<bool, true> use_external_constant_flag(
"use-external-constant",
llvm::cl::desc("Use external constant during flatbuffer import"),
llvm::cl::location(use_external_constant), llvm::cl::init(false));
// TODO(b/147111261): After the importer supports generic custom ops, we should
// change the flag to a more lightwise flag, e.g.
// "import_custom_ops_as_side_effect_free_ops", and let the MLIR DCE to prune
// the operations.
// NOLINTNEXTLINE
static opt<bool, true> experimental_prune_unreachable_nodes_unconditionally_flg(
"experimental-prune-unreachable-nodes-unconditionally",
llvm::cl::desc("Prune nodes that are not ancestors of the output nodes."),
llvm::cl::location(experimental_prune_unreachable_nodes_unconditionally),
llvm::cl::init(false));
namespace {
bool IsScalar(const TensorT& tensor) {
// TODO(b/138222071) We can't distinguish scalars and unranked tensors
@ -217,7 +232,7 @@ mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b,
// min/max stats is just for comments, so ignore it.
if (!tensor.quantization || IsQuantized(tensor)) return nullptr;
// If the result isn't float and unquantizable, the min/max is ignored.
if (!res->getType()
if (!res.getType()
.cast<mlir::ShapedType>()
.getElementType()
.isa<mlir::FloatType>()) {
@ -255,10 +270,23 @@ mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b,
}
StatusOr<std::string> OpNameForOpCode(const tflite::OperatorCodeT opcode) {
// TODO(krzysd) Support custom ops
// TODO(b/143872630): Support custom ops
if (opcode.builtin_code == tflite::BuiltinOperator_CUSTOM) {
return errors::Unimplemented("unsupported custom operation: ",
opcode.custom_code);
// Adding some custom op supported on GPU.
const absl::string_view custom_name = opcode.custom_code;
if (custom_name == "MaxPoolingWithArgmax2D") {
return std::string("tfl.max_pooling_with_argmax_2d");
}
if (custom_name == "Convolution2DTransposeBias") {
return std::string("tfl.convolution_2d_transpose_bias");
}
if (custom_name == "MaxUnpooling2D") {
return std::string("tfl.max_unpooling_2d");
}
// Use an unsupported op name instead of throwing an error here in case the
// op is pruned during the import.
return std::string(
llvm::Twine("tfl.UNSUPPORTED_custom_", opcode.custom_code).str());
}
if (opcode.builtin_code == tflite::BuiltinOperator_IF) {
return std::string("tf.If");
@ -361,7 +389,6 @@ StatusOr<mlir::ElementsAttr> ConvertIntBuffer(
mlir::RankedTensorType shaped_type, mlir::Type elem_type,
const std::vector<uint8_t>& buffer) {
unsigned bit_width;
mlir::RankedTensorType buffer_type;
if (auto itype = elem_type.dyn_cast<mlir::IntegerType>()) {
bit_width = itype.getWidth();
} else if (auto qtype = elem_type.dyn_cast<QuantizedType>()) {
@ -495,6 +522,13 @@ bool IsBasicLSTMOp(tflite::BuiltinOptionsUnion op_union) {
}
}
// Returns true if this is a custom op.
bool IsCustomOp(const std::string& op_name) {
return op_name == "tfl.max_pooling_with_argmax_2d" ||
op_name == "tfl.max_unpooling_2d" ||
op_name == "tfl.convolution_2d_transpose_bias";
}
// TODO(krzysd) Handle function calls
StatusOr<Operation*> ConvertOp(
const tflite::OperatorT& op, const std::vector<Value>& vals_map,
@ -557,7 +591,15 @@ StatusOr<Operation*> ConvertOp(
}
llvm::SmallVector<mlir::NamedAttribute, 2> attrs;
mlir::BuiltinOptionsToAttributes(op.builtin_options, builder, attrs);
if (IsCustomOp(op_name)) {
auto status = mlir::CustomOptionsToAttributes(op_name, op.custom_options,
builder, loc, &attrs);
if (!status.ok()) {
return emitError(loc, status.ToString()), status;
}
} else {
mlir::BuiltinOptionsToAttributes(op.builtin_options, builder, attrs);
}
op_state.addAttributes(attrs);
// Handle the conversion from subgraph index to functions for If and While
@ -619,6 +661,49 @@ mlir::NamedAttribute BuildTFEntryFunctionAttribute(
name, builder->getStringAttr(llvm::join(tensor_names, ",")));
}
// Given a list of output indices, traverses the subgraph and returns the set of
// ops that are ancestors of the output tensors.
StatusOr<absl::flat_hash_set<const tflite::OperatorT*>> PruneSubgraph(
const tflite::SubGraphT& subgraph, ArrayRef<int32_t> output_indices) {
// Create a map from tensor index to defining op.
absl::flat_hash_map<int32_t, const tflite::OperatorT*> defining_op;
for (const auto& op : subgraph.operators) {
for (int32_t output : op->outputs) {
defining_op[output] = op.get();
}
}
std::vector<const tflite::OperatorT*> queue;
for (int32_t output : output_indices) {
if (auto& op = defining_op[output]) {
queue.push_back(op);
} else {
return errors::InvalidArgument("Output tensor doesn't have defining op");
}
}
// Traverse the graph towards inputs.
absl::flat_hash_set<const tflite::OperatorT*> visited;
while (!queue.empty()) {
const tflite::OperatorT* op = queue.back();
queue.pop_back();
if (!visited.insert(op).second) {
// The node has already been visited.
continue;
}
for (int32_t input : op->inputs) {
// Input tensor may not have a defining op in case it is a subgraph input
// or a constant tensor.
if (auto& op = defining_op[input]) {
queue.push_back(op);
}
}
}
return visited;
}
// Build a FuncOp from a tflite SubGraph
// The op_names are a mapping from indexes into the TFLite operators array to
// the operator name MLIR expects (tfl.foo_op). The buffers are directly taken
@ -635,7 +720,8 @@ StatusOr<FuncOp> ConvertSubgraph(
const std::vector<std::unique_ptr<tflite::BufferT>>& buffers,
Location base_loc, Builder builder,
const std::vector<std::string>& ordered_output_arrays, bool is_entry_point,
bool use_external_constant) {
bool use_external_constant,
bool experimental_prune_unreachable_nodes_unconditionally) {
llvm::SmallVector<mlir::Type, 2> ret_types;
llvm::SmallVector<mlir::Type, 4> input_types;
@ -731,8 +817,19 @@ StatusOr<FuncOp> ConvertSubgraph(
func.setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
}
absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;
if (experimental_prune_unreachable_nodes_unconditionally) {
TF_ASSIGN_OR_RETURN(pruned_subgraph_ops,
PruneSubgraph(subgraph, func_outputs));
}
// Construct MLIR operators from TFLite operators
for (auto& op : subgraph.operators) {
if (experimental_prune_unreachable_nodes_unconditionally &&
!pruned_subgraph_ops.contains(op)) {
continue;
}
for (auto input_num : op->inputs) {
// The operators in a graph are topologically sorted
// and so if no previous operation has produced a tensor
@ -822,22 +919,21 @@ StatusOr<FuncOp> ConvertSubgraph(
// represents TFLite, this entry point must be called "main"
// TODO(b/131175224,b/132239787) Support multiple entry points
std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) {
if (subgraph.name.empty()) {
if (index == 0) {
return "main";
} else {
return llvm::formatv("fn_{0}", index).str();
}
} else {
return subgraph.name;
if (index == 0) {
return "main";
}
if (subgraph.name.empty()) {
return llvm::formatv("fn_{0}", index).str();
}
return subgraph.name;
}
} // namespace
OwningModuleRef tflite::FlatBufferToMlir(
absl::string_view buffer, MLIRContext* context, Location base_loc,
const std::vector<std::string>& ordered_output_arrays,
bool use_external_constant) {
bool use_external_constant,
bool experimental_prune_unreachable_nodes_unconditionally) {
auto model_ptr =
FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length());
if (nullptr == model_ptr) {
@ -892,7 +988,8 @@ OwningModuleRef tflite::FlatBufferToMlir(
// TODO(b/131175224,b/132239787) Support multiple entry points
builder, ordered_output_arrays,
/*is_entry_point=*/e.index() == 0,
/*use_external_constant=*/use_external_constant);
/*use_external_constant=*/use_external_constant,
experimental_prune_unreachable_nodes_unconditionally);
if (!func_or_error.ok()) {
return emitError(base_loc, "could not translate function ")
<< subgraph->name,
@ -905,9 +1002,10 @@ OwningModuleRef tflite::FlatBufferToMlir(
return OwningModuleRef(module);
}
static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr,
MLIRContext* context,
bool use_external_constant) {
static OwningModuleRef FlatBufferFileToMlirTrans(
llvm::SourceMgr* source_mgr, MLIRContext* context,
bool use_external_constant,
bool experimental_prune_unreachable_nodes_unconditionally) {
const llvm::MemoryBuffer* input =
source_mgr->getMemoryBuffer(source_mgr->getMainFileID());
std::string error;
@ -924,12 +1022,14 @@ static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr,
return tflite::FlatBufferToMlir(
absl::string_view(input->getBufferStart(), input->getBufferSize()),
context, loc, outputs, use_external_constant);
context, loc, outputs, use_external_constant,
experimental_prune_unreachable_nodes_unconditionally);
}
static mlir::TranslateToMLIRRegistration FlatBufferFileToMlirTransReg(
"tflite-flatbuffer-to-mlir",
[](llvm::SourceMgr& source_mgr, MLIRContext* context) {
return FlatBufferFileToMlirTrans(&source_mgr, context,
use_external_constant);
return FlatBufferFileToMlirTrans(
&source_mgr, context, use_external_constant,
experimental_prune_unreachable_nodes_unconditionally);
});

View File

@ -31,11 +31,14 @@ namespace tflite {
// on failure, and more specific errors will be emitted via the context.
// If `use_external_constant` is true, it will create `tfl.external_const`
// instead of `tfl.const`.
// If `experimental_prune_unreachable_nodes_unconditionally` is true, nodes that
// are not ancestors of the output nodes will be pruned.
mlir::OwningModuleRef FlatBufferToMlir(
absl::string_view buffer, mlir::MLIRContext* context,
mlir::Location base_loc,
const std::vector<std::string>& ordered_output_arrays,
bool use_external_constant = false);
bool use_external_constant = false,
bool experimental_prune_unreachable_nodes_unconditionally = false);
} // namespace tflite
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_IMPORT_H_

View File

@ -17,6 +17,8 @@ limitations under the License.
#include <vector>
#include "absl/strings/str_cat.h"
#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSwitch.h"
#include "mlir/IR/Attributes.h" // TF:llvm-project
@ -24,8 +26,36 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace {
using ::tensorflow::Status;
using ::tensorflow::errors::InvalidArgument;
using ::xla::StatusOr;
StatusOr<mlir::StringAttr> GetPaddingAttr(TfLitePadding pad_params,
mlir::Builder builder,
mlir::Location loc) {
auto padding = tflite::Padding::Padding_VALID;
if (pad_params == TfLitePadding::kTfLitePaddingSame) {
padding = tflite::Padding_SAME;
} else if (pad_params == TfLitePadding::kTfLitePaddingValid) {
padding = tflite::Padding_VALID;
} else {
return InvalidArgument(
absl::StrCat("Invalid padding type", std::to_string(pad_params)));
}
const char* option_name = tflite::EnumNamePadding(padding);
return builder.getStringAttr(option_name);
}
} // namespace
// TODO(jpienaar): This is a placeholder. This should be done in more efficient
// way when part of the translation of module.
static tflite::ActivationFunctionType ConvertTFL_AFAttrForOptionWriter(
@ -212,5 +242,44 @@ static mlir::Attribute BuildTFL_PaddingAttr(tflite::Padding value,
return builder.getStringAttr(option_name);
}
Status mlir::CustomOptionsToAttributes(
const std::string& op_name, const std::vector<uint8_t>& custom_options,
mlir::Builder builder, mlir::Location loc,
llvm::SmallVectorImpl<mlir::NamedAttribute>* attributes) {
if (op_name == "tfl.max_pooling_with_argmax_2d" ||
op_name == "tfl.max_unpooling_2d") {
auto* pool_params =
reinterpret_cast<const TfLitePoolParams*>(custom_options.data());
TF_ASSIGN_OR_RETURN(auto padding_attribute,
GetPaddingAttr(pool_params->padding, builder, loc));
attributes->emplace_back(
builder.getNamedAttr("padding", padding_attribute));
attributes->emplace_back(builder.getNamedAttr(
"stride_h", builder.getI32IntegerAttr(pool_params->stride_height)));
attributes->emplace_back(builder.getNamedAttr(
"stride_w", builder.getI32IntegerAttr(pool_params->stride_width)));
attributes->emplace_back(builder.getNamedAttr(
"filter_h", builder.getI32IntegerAttr(pool_params->filter_height)));
attributes->emplace_back(builder.getNamedAttr(
"filter_w", builder.getI32IntegerAttr(pool_params->filter_width)));
return Status::OK();
} else if (op_name == "tfl.convolution_2d_transpose_bias") {
auto* conv_params = reinterpret_cast<const TfLiteTransposeConvParams*>(
custom_options.data());
TF_ASSIGN_OR_RETURN(auto padding_attribute,
GetPaddingAttr(conv_params->padding, builder, loc));
attributes->emplace_back(
builder.getNamedAttr("padding", padding_attribute));
attributes->emplace_back(builder.getNamedAttr(
"stride_h", builder.getI32IntegerAttr(conv_params->stride_height)));
attributes->emplace_back(builder.getNamedAttr(
"stride_w", builder.getI32IntegerAttr(conv_params->stride_width)));
return Status::OK();
}
return InvalidArgument(absl::StrCat("invalid custom op type: ", op_name));
}
// Pull in FlatBuffer writers for TFLite generated using TableGen
#include "tensorflow/compiler/mlir/lite/operator_converters.inc"

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace mlir {
@ -45,7 +46,7 @@ llvm::Optional<flatbuffers::Offset<tflite::Operator>> CreateFlatBufferOperator(
const std::vector<int32_t> &operands, const std::vector<int32_t> &results,
flatbuffers::FlatBufferBuilder *fbb);
// Populate the array of mlir::NamedAttributes corresponding to the given
// Populates the array of mlir::NamedAttributes corresponding to the given
// tflite::FlatbufferOptionsUnion.
// We use an out parameter per LLVM convention
void BuiltinOptionsToAttributes(
@ -53,6 +54,15 @@ void BuiltinOptionsToAttributes(
// NOLINTNEXTLINE
llvm::SmallVectorImpl<mlir::NamedAttribute> &attributes);
// Populates the array of mlir::NamedAttributes corresponding to the given
// custom_options.
// We use an out parameter per LLVM convention
tensorflow::Status CustomOptionsToAttributes(
const std::string &op_name, const std::vector<uint8_t> &custom_options,
mlir::Builder builder,
// NOLINTNEXTLINE
Location loc, llvm::SmallVectorImpl<mlir::NamedAttribute> *attributes);
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_OPERATOR_H_

View File

@ -71,6 +71,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/tools/versioning/op_version.h"
@ -89,6 +90,7 @@ using mlir::MLIRContext;
using mlir::ModuleOp;
using mlir::NoneType;
using mlir::Operation;
using mlir::Region;
using mlir::StringAttr;
using mlir::TensorType;
using mlir::TranslateFromMLIRRegistration;
@ -218,6 +220,13 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
auto qtype = type.cast<mlir::quant::UniformQuantizedPerAxisType>();
return GetTFLiteType(qtype.getStorageType(), qtype.isSigned());
}
case mlir::TF::TensorFlowTypes::RESOURCE: {
// Treat tf.resource values as integer values in flatbuffer.
// TODO(b/146131919): Maybe need to have a detailed design for supporting
// other resource types beyonds hash table resources and resource
// variables.
return tflite::TensorType_INT32;
}
default:
// TFLite export fills FLOAT32 for unknown data types. Returning an error
// for now for safety and this could be revisited when required.
@ -233,17 +242,17 @@ static bool IsConst(Operation* op) {
template <typename T>
static bool HasValidTFLiteType(Value value, T& error_handler) {
// None type is allowed to represent unspecified operands.
if (value->getType().isa<NoneType>()) return true;
if (value.getType().isa<NoneType>()) return true;
auto type = value->getType().dyn_cast<TensorType>();
auto type = value.getType().dyn_cast<TensorType>();
if (!type) {
if (auto op = value->getDefiningOp()) {
if (auto op = value.getDefiningOp()) {
error_handler.emitError()
<< '\'' << op << "' should produce value of tensor type instead of "
<< value->getType();
<< value.getType();
return false;
}
error_handler.emitError("expected tensor type, got ") << value->getType();
error_handler.emitError("expected tensor type, got ") << value.getType();
return false;
}
@ -282,7 +291,7 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) {
for (auto arg : bb.getArguments()) {
if (!HasValidTFLiteType(arg, fn))
return fn.emitError("invalid TFLite type: ") << arg->getType(), false;
return fn.emitError("invalid TFLite type: ") << arg.getType(), false;
}
// Verify that all operations except the terminator have exactly one
@ -292,7 +301,7 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) {
for (auto result : inst.getResults()) {
if (!HasValidTFLiteType(result, inst))
return fn.emitError("invalid TFLite type: ") << result->getType(),
return fn.emitError("invalid TFLite type: ") << result.getType(),
false;
}
}
@ -301,7 +310,7 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) {
return true;
}
static std::unique_ptr<::tensorflow::NodeDef> getTensorFlowNodeDef(
static std::unique_ptr<::tensorflow::NodeDef> GetTensorFlowNodeDef(
::mlir::Operation* inst) {
// We pass empty string for the original node_def name since Flex runtime
// does not care about this being set correctly on node_def. There is no
@ -317,6 +326,48 @@ static std::unique_ptr<::tensorflow::NodeDef> getTensorFlowNodeDef(
return std::move(status_or_node_def.ValueOrDie());
}
// Converts a mlir padding StringRef to TfLitePadding.
// Returns llvm::None if conversion fails.
static Optional<TfLitePadding> GetTflitePadding(Operation* inst,
llvm::StringRef padding) {
const tflite::Padding padding_attr =
std::move(llvm::StringSwitch<tflite::Padding>(padding)
.Case("SAME", tflite::Padding_SAME)
.Case("VALID", tflite::Padding_VALID));
if (padding_attr == tflite::Padding_SAME) {
return kTfLitePaddingSame;
}
if (padding_attr == tflite::Padding_VALID) {
return kTfLitePaddingValid;
}
return inst->emitOpError() << "Invalid padding attribute: " << padding,
llvm::None;
}
// Extracts TfLitePoolParams from a TFL custom op.
// Template parameter, TFLOp, should be a TFL custom op containing attributes
// generated from TfLitePoolParams.
// Returns llvm::None if conversion fails.
template <typename TFLOp>
static Optional<TfLitePoolParams> GetTflitePoolParams(Operation* inst,
TFLOp op) {
TfLitePoolParams pool_params;
pool_params.stride_height = op.stride_h().getSExtValue();
pool_params.stride_width = op.stride_w().getSExtValue();
pool_params.filter_height = op.filter_h().getSExtValue();
pool_params.filter_width = op.filter_w().getSExtValue();
const auto padding = GetTflitePadding(inst, op.padding());
if (padding) {
pool_params.padding = *padding;
pool_params.activation = kTfLiteActNone;
pool_params.computed.padding = TfLitePaddingValues{0, 0, 0, 0};
return pool_params;
}
return llvm::None;
}
namespace {
// Translates an MLIR module in TFLite dialect to TFLite FlatBuffer.
@ -375,9 +426,36 @@ class Translator {
mlir::TF::WhileOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
// Build while operator where cond & body are regions.
BufferOffset<tflite::Operator> BuildWhileOperator(
mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
// Builds custom operators.
// Templated on a) data type of custom_option to be stored into flatbuffer,
// and b) TFL custom op type.
template <typename CustomOptionType, typename TFLOp>
BufferOffset<tflite::Operator> BuildCustomOperator(
const CustomOptionType& custom_option, const std::string& opcode_name,
TFLOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
BufferOffset<tflite::Operator> BuildNumericVerifyOperator(
mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
Optional<BufferOffset<tflite::Operator>>
BuildConvolution2DTransposeBiasOperator(
Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op,
const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
Optional<BufferOffset<tflite::Operator>> BuildMaxPoolingWithArgMax2DOperator(
Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op,
const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
Optional<BufferOffset<tflite::Operator>> BuildMaxUnpooling2DOperator(
Operation* inst, mlir::TFL::MaxUnpooling2DOp op,
const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
Optional<CustomOptionsOffset> CreateFlexOpCustomOptions(
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);
@ -400,7 +478,10 @@ class Translator {
Operation* inst, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
Optional<BufferOffset<tflite::SubGraph>> BuildSubGraph(FuncOp fn);
// Build a subgraph with a given name out of the region either corresponding
// to a function's body or while op.
Optional<BufferOffset<tflite::SubGraph>> BuildSubGraph(
const std::string& name, Region* region);
// Builds Metadata with the given `name` and buffer `content`.
BufferOffset<tflite::Metadata> BuildMetadata(StringRef name,
@ -422,6 +503,12 @@ class Translator {
// Returns a unique name for `val`.
std::string UniqueName(mlir::Value val);
// Returns the names of the subgraphs corresponding the regions of the op. The
// names are supposed to be unique as the op name is unique and the suffix is
// not a valid name.
std::string GetWhileBodyName(mlir::TFL::WhileOp while_op);
std::string GetWhileCondName(mlir::TFL::WhileOp while_op);
ModuleOp module_;
tensorflow::OpOrArgNameMapper& name_mapper_;
@ -451,7 +538,7 @@ class Translator {
};
std::string Translator::UniqueName(mlir::Value val) {
return name_mapper_.GetUniqueName(val);
return std::string(name_mapper_.GetUniqueName(val));
}
Optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
@ -504,7 +591,7 @@ Optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
Value value, const std::string& name, unsigned buffer_idx) {
auto type = value->getType().cast<TensorType>();
auto type = value.getType().cast<TensorType>();
// TFLite requires tensor shape only for the inputs and constants.
// However, we output all known shapes for better round-tripping
@ -516,19 +603,20 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
if (std::any_of(shape_ref.begin(), shape_ref.end(), is_out_of_range))
return mlir::emitError(
value->getLoc(),
value.getLoc(),
"result shape dimensions out of 32 bit int type range");
return mlir::success();
};
std::vector<int32_t> shape;
std::vector<int32_t> shape_signature;
if (type.hasStaticShape()) {
llvm::ArrayRef<int64_t> shape_ref = type.getShape();
if (mlir::failed(check_shape(shape_ref))) return llvm::None;
shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
} else if (auto* inst = value->getDefiningOp()) {
} else if (auto* inst = value.getDefiningOp()) {
if (IsConst(inst)) {
// Const op can have a result of dynamic shaped type (e.g. due to constant
// folding), but we can still derive the shape of a constant tensor for
@ -540,7 +628,17 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
}
} else if (type.hasRank()) {
llvm::ArrayRef<int64_t> shape_ref = type.getShape();
if (mlir::failed(check_shape(shape_ref))) return llvm::None;
shape.reserve(shape_ref.size());
for (auto& dim : shape_ref) {
shape.push_back(dim == -1 ? 1 : dim);
}
shape_signature = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
}
Type element_type = type.getElementType();
tflite::TensorType tflite_element_type =
GetTFLiteType(type.getElementType()).ValueOrDie();
@ -571,16 +669,25 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
// marked as a stateful. If so, set the tensor's is_variable as true
// This is v1 ref variable semantics in the TFLite runtime.
bool is_variable = false;
for (auto& use : value->getUses()) {
for (auto& use : value.getUses()) {
is_variable = IsStatefulOperand(use.getOwner(), use.getOperandNumber());
if (is_variable) {
break;
}
}
return tflite::CreateTensor(
builder_, builder_.CreateVector(shape), tflite_element_type,
(is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
/*is_variable=*/is_variable);
if (shape_signature.empty()) {
return tflite::CreateTensor(
builder_, builder_.CreateVector(shape), tflite_element_type,
(is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
/*is_variable=*/is_variable);
} else {
return tflite::CreateTensor(
builder_, builder_.CreateVector(shape), tflite_element_type,
(is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
/*is_variable=*/is_variable, /*sparsity=*/0,
/*shape_signature=*/builder_.CreateVector(shape_signature));
}
}
BufferOffset<tflite::Operator> Translator::BuildIfOperator(
@ -615,19 +722,96 @@ BufferOffset<tflite::Operator> Translator::BuildWhileOperator(
builtin_options);
}
std::string Translator::GetWhileBodyName(mlir::TFL::WhileOp while_op) {
return (name_mapper_.GetUniqueName(while_op.getOperation()) + "$body").str();
}
std::string Translator::GetWhileCondName(mlir::TFL::WhileOp while_op) {
return (name_mapper_.GetUniqueName(while_op.getOperation()) + "$cond").str();
}
BufferOffset<tflite::Operator> Translator::BuildWhileOperator(
mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results) {
auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE);
int body_subgraph_index = subgraph_index_map_.at(GetWhileBodyName(op));
int cond_subgraph_index = subgraph_index_map_.at(GetWhileCondName(op));
auto builtin_options = tflite::CreateWhileOptions(
builder_, cond_subgraph_index, body_subgraph_index)
.Union();
auto inputs = builder_.CreateVector(operands);
auto outputs = builder_.CreateVector(results);
return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
tflite::BuiltinOptions_WhileOptions,
builtin_options);
}
template <typename CustomOptionType, typename TFLOp>
BufferOffset<tflite::Operator> Translator::BuildCustomOperator(
const CustomOptionType& custom_option, const std::string& opcode_name,
TFLOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results) {
std::vector<uint8_t> custom_option_vector(sizeof(CustomOptionType));
memcpy(custom_option_vector.data(), &custom_option, sizeof(CustomOptionType));
auto opcode_index =
GetOpcodeIndex(opcode_name, tflite::BuiltinOperator_CUSTOM);
return tflite::CreateOperator(
builder_, opcode_index, builder_.CreateVector(operands),
builder_.CreateVector(results), tflite::BuiltinOptions_NONE,
/*builtin_options=*/0,
builder_.CreateVector<uint8_t>(custom_option_vector),
tflite::CustomOptionsFormat_FLEXBUFFERS);
}
BufferOffset<tflite::Operator> Translator::BuildNumericVerifyOperator(
mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results) {
float tolerance = op.tolerance().convertToFloat();
std::vector<uint8_t> custom_options(sizeof(float));
memcpy(custom_options.data(), &tolerance, sizeof(float));
auto opcode_index =
GetOpcodeIndex("NumericVerify", tflite::BuiltinOperator_CUSTOM);
return tflite::CreateOperator(
builder_, opcode_index, builder_.CreateVector(operands),
builder_.CreateVector(results), tflite::BuiltinOptions_NONE,
/*builtin_options=*/0, builder_.CreateVector<uint8_t>(custom_options),
tflite::CustomOptionsFormat_FLEXBUFFERS);
return BuildCustomOperator(tolerance, "NumericVerify", op, operands, results);
}
Optional<BufferOffset<tflite::Operator>>
Translator::BuildConvolution2DTransposeBiasOperator(
Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op,
const std::vector<int32_t>& operands, const std::vector<int32_t>& results) {
TfLiteTransposeConvParams conv_params;
conv_params.stride_height = op.stride_h().getSExtValue();
conv_params.stride_width = op.stride_w().getSExtValue();
const auto padding = GetTflitePadding(inst, op.padding());
if (padding) {
conv_params.padding = *padding;
return BuildCustomOperator(conv_params, "Convolution2DTransposeBias", op,
operands, results);
}
return llvm::None;
}
Optional<BufferOffset<tflite::Operator>>
Translator::BuildMaxPoolingWithArgMax2DOperator(
Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op,
const std::vector<int32_t>& operands, const std::vector<int32_t>& results) {
const auto pool_params = GetTflitePoolParams(inst, op);
if (pool_params) {
return BuildCustomOperator(*pool_params, "MaxPoolingWithArgmax2D", op,
operands, results);
}
return llvm::None;
}
Optional<BufferOffset<tflite::Operator>>
Translator::BuildMaxUnpooling2DOperator(Operation* inst,
mlir::TFL::MaxUnpooling2DOp op,
const std::vector<int32_t>& operands,
const std::vector<int32_t>& results) {
const auto pool_params = GetTflitePoolParams(inst, op);
if (pool_params) {
return BuildCustomOperator(*pool_params, "MaxUnpooling2D", op, operands,
results);
}
return llvm::None;
}
Optional<CustomOptionsOffset> Translator::CreateFlexOpCustomOptions(
@ -769,6 +953,24 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
if (auto verify_op = dyn_cast<mlir::TFL::NumericVerifyOp>(inst)) {
return BuildNumericVerifyOperator(verify_op, operands, results);
}
if (auto conv_transpose_bias_op =
dyn_cast<mlir::TFL::Convolution2DTransposeBiasOp>(inst)) {
return BuildConvolution2DTransposeBiasOperator(
inst, conv_transpose_bias_op, operands, results);
}
if (auto max_pooling_with_arg_max_op =
dyn_cast<mlir::TFL::MaxPoolingWithArgMax2DOp>(inst)) {
return BuildMaxPoolingWithArgMax2DOperator(
inst, max_pooling_with_arg_max_op, operands, results);
}
if (auto max_unpooling_op = dyn_cast<mlir::TFL::MaxUnpooling2DOp>(inst)) {
return BuildMaxUnpooling2DOperator(inst, max_unpooling_op, operands,
results);
}
if (auto whileOp = dyn_cast<mlir::TFL::WhileOp>(inst)) {
return BuildWhileOperator(whileOp, operands, results);
}
inst->emitOpError("is not a supported TFLite op");
return llvm::None;
}
@ -805,7 +1007,7 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
// we emit op as flex.
// if custom is enabled
// we emit the op as custom.
auto node_def = getTensorFlowNodeDef(inst);
auto node_def = GetTensorFlowNodeDef(inst);
if (!node_def) {
return llvm::None;
}
@ -904,18 +1106,16 @@ void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) {
bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) {
std::vector<int> operand_indices;
// TODO(b/138254427): When the bug is addressed, we'll be able to inspect
// for the presence of a specific OpTrait using mlir::Operation, without
// having to cast it to specific ops like below.
// Until then, when a new RNN/LSTM op is added to TFLite and has stateful
// tensors as operands, they will need to be added here as well.
if (!mlir::TFL::IsStatefulOp(op, &operand_indices)) return false;
return absl::c_find(operand_indices, operand_index) != operand_indices.end();
}
Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
const std::string& name, Region* region) {
bool has_input_attr = false;
InitializeNamesFromAttribute(fn, &has_input_attr);
if (auto fn = dyn_cast<FuncOp>(region->getParentOp())) {
InitializeNamesFromAttribute(fn, &has_input_attr);
}
std::vector<BufferOffset<tflite::Tensor>> tensors;
llvm::DenseMap<Value, int> tensor_index_map;
@ -923,7 +1123,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
// on failure.
auto build_tensor_and_buffer = [&](Value value, const std::string& name) {
// NoneType represents optional and may be skipped here.
if (value->getType().isa<NoneType>()) {
if (value.getType().isa<NoneType>()) {
return true;
}
@ -936,7 +1136,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
// make the Buffer empty apart from setting the buffer_idx=0 in the Tensor.
// This does not seem to affect runtime behavior for RNN/LSTM, but would be
// good for reducing memory footprint.
if (auto* inst = value->getDefiningOp()) {
if (auto* inst = value.getDefiningOp()) {
auto buffer_or = BuildBuffer(inst);
if (!buffer_or) return false;
buffers_.push_back(*buffer_or);
@ -947,7 +1147,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
};
std::vector<BufferOffset<tflite::Operator>> operators;
auto& bb = fn.getBlocks().front();
auto& bb = region->front();
// Main function's arguments are first passed to `input` op so they don't
// have associated tensor and buffer. Build FlatBuffer tensor and buffer for
@ -955,7 +1155,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
for (unsigned i = 0, e = bb.getNumArguments(); i < e; ++i) {
mlir::BlockArgument arg = bb.getArgument(i);
std::string name;
if (has_input_attr) name = name_mapper_.GetUniqueName(arg);
if (has_input_attr) name = std::string(name_mapper_.GetUniqueName(arg));
if (name.empty()) name = absl::StrCat("arg", i);
if (!build_tensor_and_buffer(arg, name)) return llvm::None;
}
@ -976,7 +1176,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
std::vector<int32_t> operands;
operands.reserve(inst.getNumOperands());
for (auto operand : inst.getOperands()) {
if (operand->getType().isa<NoneType>())
if (operand.getType().isa<NoneType>())
operands.push_back(kTfLiteOptionalTensor);
else
operands.push_back(tensor_index_map.lookup(operand));
@ -1007,7 +1207,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
return tflite::CreateSubGraph(
builder_, builder_.CreateVector(tensors), builder_.CreateVector(inputs),
builder_.CreateVector(outputs), builder_.CreateVector(operators),
/*name=*/builder_.CreateString(fn.getName().str()));
/*name=*/builder_.CreateString(name));
}
BufferOffset<tflite::Metadata> Translator::BuildMetadata(StringRef name,
@ -1050,35 +1250,45 @@ Optional<std::string> Translator::Translate(
}
Optional<std::string> Translator::TranslateInternal() {
// Create a list of functions in the module with main function being the
// first function in the list. This is required as the first subgraph in the
// model is entry point for the model.
std::vector<FuncOp> functions;
functions.reserve(std::distance(module_.begin(), module_.end()));
// A list of named regions in the module with main function being the first in
// the list. The main function is required as the first subgraph in the model
// is entry point for the model.
std::vector<std::pair<std::string, Region*>> named_regions;
named_regions.reserve(std::distance(module_.begin(), module_.end()));
int subgraph_idx = 0;
FuncOp main_fn = module_.lookupSymbol<FuncOp>("main");
subgraph_index_map_[main_fn.getName().str()] = subgraph_idx++;
functions.push_back(main_fn);
for (auto fn : module_.getOps<FuncOp>()) {
if (fn == main_fn) continue;
named_regions.emplace_back("main", &main_fn.getBody());
// Walk over the module collection ops with functions and while ops.
module_.walk([&](Operation* op) {
if (auto fn = dyn_cast<FuncOp>(op)) {
if (fn != main_fn) {
subgraph_index_map_[fn.getName().str()] = subgraph_idx++;
named_regions.emplace_back(fn.getName().str(), &fn.getBody());
}
} else if (auto wo = dyn_cast<mlir::TFL::WhileOp>(op)) {
std::string name = GetWhileCondName(wo);
subgraph_index_map_[name] = subgraph_idx++;
named_regions.emplace_back(GetWhileCondName(wo), &wo.cond());
name = GetWhileBodyName(wo);
subgraph_index_map_[name] = subgraph_idx++;
named_regions.emplace_back(name, &wo.body());
}
});
subgraph_index_map_[fn.getName().str()] = subgraph_idx++;
functions.push_back(fn);
}
// Build subgraph for each of the functions.
// Build subgraph for each of the named regions.
std::vector<BufferOffset<tflite::SubGraph>> subgraphs;
subgraphs.reserve(functions.size());
subgraphs.reserve(named_regions.size());
int first_failed_func = -1;
for (int i = 0; i < functions.size(); ++i) {
auto subgraph_or = BuildSubGraph(functions[i]);
for (auto it : llvm::enumerate(named_regions)) {
auto subgraph_or = BuildSubGraph(it.value().first, it.value().second);
if (!subgraph_or) {
if (first_failed_func == -1)
// Record the index of the first function that cannot be converted.
// Record the index of the first region that cannot be converted.
// Keep looping through all subgraphs in the module to make sure that
// we collect the list of missing ops from the entire module.
first_failed_func = i;
first_failed_func = it.index();
} else {
subgraphs.push_back(*subgraph_or);
}
@ -1099,9 +1309,10 @@ Optional<std::string> Translator::TranslateInternal() {
"-emit-custom-ops flag): " +
failed_custom_ops_list;
return functions[first_failed_func].emitError("failed while converting: '")
<< functions[first_failed_func].getName() << "\'\n"
<< err,
auto& failed_region = named_regions[first_failed_func];
return failed_region.second->getParentOp()->emitError()
<< "failed while converting: '" << failed_region.first
<< "': " << err,
llvm::None;
}

View File

@ -0,0 +1,58 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the operation interface definition file for TensorFlow Lite.
#ifndef TFL_OP_INTERFACES
#define TFL_OP_INTERFACES
include "mlir/IR/OpBase.td"
//===----------------------------------------------------------------------===//
// TFL op interface for stateful operands.
def TFL_StatefulOp : OpInterface<"StatefulOpInterface"> {
let description = [{
Interface for ops that are stateful and need to identify stateful operands.
Stateful operands correspond to TF's variables semantics. An op that has 1
or more stateful operands is a stateful op.
}];
let methods = [
InterfaceMethod<
[{Returns the indices of stateful operands.}],
"std::vector<int>", "GetStatefulOperands", (ins)
>,
];
}
//===----------------------------------------------------------------------===//
// TFL op interface for output channel index.
def TFL_ChannelDimIndexInterface : OpInterface<"ChannelDimIndexInterface"> {
let description = [{
Interface for defining the index of out channel index.
}];
let methods = [
InterfaceMethod<
[{Returns the dimension index of the output channels.}],
"int", "GetChannelDimIndex", (ins)
>,
];
}
#endif // TFL_OP_INTERFACES

View File

@ -304,11 +304,11 @@ Attribute ConstFoldUnaryOp(Type result_type, Attribute operand,
void buildComparisonBinOp(Builder *builder, OperationState &result, Value lhs,
Value rhs) {
auto result_type =
OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType());
OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
if (!result_type)
emitError(result.location)
<< "non-broadcastable operands: " << lhs->getType() << " and "
<< rhs->getType();
<< "non-broadcastable operands: " << lhs.getType() << " and "
<< rhs.getType();
result.addOperands({lhs, rhs});
// Comparison binary ops always return i1 tensor.
if (auto shaped_type = result_type.dyn_cast<RankedTensorType>()) {
@ -324,12 +324,12 @@ void buildFusedBroadcastableBinOp(Builder *builder, OperationState &result,
Value lhs, Value rhs,
StringAttr fused_activation_function) {
auto result_type =
OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType());
OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
if (!result_type)
emitError(result.location)
<< "non-broadcastable operands: " << lhs->getType() << " and "
<< rhs->getType();
<< "non-broadcastable operands: " << lhs.getType() << " and "
<< rhs.getType();
result.addOperands({lhs, rhs});
result.addAttribute("fused_activation_function", fused_activation_function);
@ -358,7 +358,7 @@ OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
namespace {
int64_t GetConcatenationOpAxis(ConcatenationOp op) {
auto output_type = op.output()->getType().cast<RankedTensorType>();
auto output_type = op.output().getType().cast<RankedTensorType>();
int64_t axis = op.axis().getSExtValue();
if (axis < 0) axis += output_type.getRank();
return axis;
@ -452,7 +452,7 @@ LogicalResult VerifyConcatenationOpTypes(Operation *op,
}
LogicalResult Verify(ConcatenationOp op) {
auto output_type = op.output()->getType().dyn_cast<RankedTensorType>();
auto output_type = op.output().getType().dyn_cast<RankedTensorType>();
// If the output type is unranked, there is nothing else to be verified.
if (!output_type) return success();
@ -463,7 +463,7 @@ LogicalResult Verify(ConcatenationOp op) {
SmallVector<TensorType, 4> operand_types;
for (Value operand : op.values())
operand_types.push_back(operand->getType().cast<TensorType>());
operand_types.push_back(operand.getType().cast<TensorType>());
return VerifyConcatenationOpTypes(op.getOperation(), output_type,
operand_types, axis);
@ -520,7 +520,7 @@ DenseElementsAttr ConstFoldConcatenateOpDense(ArrayRef<Attribute> operands,
OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
if (fused_activation_function() == "NONE") {
if (auto output_type = output()->getType().dyn_cast<RankedTensorType>()) {
if (auto output_type = output().getType().dyn_cast<RankedTensorType>()) {
const int64_t axis = GetConcatenationOpAxis(*this);
if (IsConcatenationOpConstFoldable(*this, operands, output_type, axis))
return ConstFoldConcatenateOpDense(operands, output_type, axis);
@ -530,7 +530,7 @@ OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
// Remove all empty values.
SmallVector<Value, 4> non_empty_values;
for (Value value : this->values()) {
const auto shaped_type = value->getType().cast<ShapedType>();
const auto shaped_type = value.getType().cast<ShapedType>();
if (shaped_type.hasStaticShape() && shaped_type.getNumElements() == 0) {
continue;
}
@ -559,8 +559,8 @@ OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
LogicalResult Verify(FullyConnectedOp op) {
ShapedType input_type = op.input()->getType().cast<ShapedType>();
ShapedType filter_type = op.filter()->getType().cast<ShapedType>();
ShapedType input_type = op.input().getType().cast<ShapedType>();
ShapedType filter_type = op.filter().getType().cast<ShapedType>();
if (filter_type.hasRank() && filter_type.getRank() != 2) {
return op.emitOpError("expect 2d filter, got ") << filter_type;
}
@ -582,7 +582,7 @@ LogicalResult Verify(FullyConnectedOp op) {
// format.
if (op.weights_format() == "DEFAULT") {
ShapedType output_type =
(*op.output().begin())->getType().cast<ShapedType>();
(*op.output().begin()).getType().cast<ShapedType>();
if (!output_type.hasStaticShape()) {
return mlir::success();
}
@ -610,8 +610,8 @@ LogicalResult Verify(FullyConnectedOp op) {
static void BuildGatherOp(Builder *builder, OperationState &result,
Value params, Value indices, IntegerAttr axis) {
auto params_type = params->getType().cast<TensorType>();
auto indices_type = indices->getType().cast<TensorType>();
auto params_type = params.getType().cast<TensorType>();
auto indices_type = indices.getType().cast<TensorType>();
// If params/indices is unranked, then output is unranked.
if (!params_type.hasRank() || !indices_type.hasRank())
@ -705,7 +705,7 @@ static LogicalResult Verify(PackOp op) {
return op.emitOpError("input count should match 'values_count' attribute");
Value operand0 = op.getOperand(0);
auto input_type = operand0->getType().cast<ShapedType>();
auto input_type = operand0.getType().cast<ShapedType>();
// Check axis bounds.
if (input_type.hasRank()) {
@ -718,7 +718,7 @@ static LogicalResult Verify(PackOp op) {
// Make sure all inputs have the same shape and element type.
// TODO(rahulsp): Simplify once b/135032064 is fixed.
for (Value operand : op.getOperands()) {
auto other_type = operand->getType().cast<ShapedType>();
auto other_type = operand.getType().cast<ShapedType>();
if (input_type != other_type)
return op.emitOpError("operands should be of the same type. got ")
<< input_type << ", " << other_type;
@ -732,9 +732,9 @@ static LogicalResult Verify(PackOp op) {
//===----------------------------------------------------------------------===//
static LogicalResult Verify(PReluOp op) {
auto input_type = op.input()->getType().cast<ShapedType>();
auto alpha_type = op.alpha()->getType().cast<ShapedType>();
auto output_type = op.output()->getType().cast<ShapedType>();
auto input_type = op.input().getType().cast<ShapedType>();
auto alpha_type = op.alpha().getType().cast<ShapedType>();
auto output_type = op.output().getType().cast<ShapedType>();
if (input_type.hasStaticShape() && alpha_type.hasStaticShape()) {
if (input_type.getRank() != alpha_type.getRank() + 1) {
@ -783,13 +783,13 @@ struct RemoveAdjacentReshape : public RewritePattern {
PatternMatchResult match(Operation *op) const override {
auto thisOp = cast<ReshapeOp>(op);
auto prevOp = thisOp.getOperand(0)->getDefiningOp();
auto prevOp = thisOp.getOperand(0).getDefiningOp();
return isa_and_nonnull<ReshapeOp>(prevOp) ? matchSuccess() : matchFailure();
}
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
auto thisOp = cast<ReshapeOp>(op);
auto prevOp = cast<ReshapeOp>(thisOp.getOperand(0)->getDefiningOp());
auto prevOp = cast<ReshapeOp>(thisOp.getOperand(0).getDefiningOp());
// Replace
// %1 = "tfl.reshape"(%0, %shape0)
@ -797,8 +797,7 @@ struct RemoveAdjacentReshape : public RewritePattern {
// With
// %2 = "tfl.reshape"(%0, %shape1)
rewriter.replaceOpWithNewOp<ReshapeOp>(
{prevOp.getResult()}, op, thisOp.getType(), prevOp.getOperand(0),
thisOp.getOperand(1));
op, thisOp.getType(), prevOp.getOperand(0), thisOp.getOperand(1));
}
};
@ -807,7 +806,7 @@ struct RemoveAdjacentReshape : public RewritePattern {
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
// Remove identity reshape with both static result and input shape.
auto result_type = getType().cast<ShapedType>();
auto input_type = getOperand(0)->getType().cast<ShapedType>();
auto input_type = getOperand(0).getType().cast<ShapedType>();
if (result_type.hasStaticShape() && result_type == input_type) {
return getOperand(0);
}
@ -865,7 +864,7 @@ struct RemoveRedundantUnpackPack : public RewritePattern {
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
TFL::PackOp pack_op = cast<TFL::PackOp>(op);
Operation *first_input = pack_op.getOperand(0)->getDefiningOp();
Operation *first_input = pack_op.getOperand(0).getDefiningOp();
if (!first_input) return matchFailure();
auto input_unpack_op = dyn_cast_or_null<TFL::UnpackOp>(first_input);
if (!input_unpack_op) return matchFailure();
@ -905,9 +904,9 @@ void PackOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
//===----------------------------------------------------------------------===//
static LogicalResult Verify(SliceOp op) {
auto input_type = op.input()->getType().cast<ShapedType>();
auto begin_type = op.begin()->getType().cast<ShapedType>();
auto size_type = op.size()->getType().cast<ShapedType>();
auto input_type = op.input().getType().cast<ShapedType>();
auto begin_type = op.begin().getType().cast<ShapedType>();
auto size_type = op.size().getType().cast<ShapedType>();
if (input_type.hasStaticShape() && begin_type.hasStaticShape() &&
size_type.hasStaticShape()) {
if (input_type.getRank() != begin_type.getNumElements()) {
@ -995,7 +994,7 @@ static void BuildTopKOp(Builder *builder, OperationState &result, Value input,
// TODO(jpienaar): This should use a helper function.
const_k = cst.getValue<IntegerAttr>({}).getValue().getSExtValue();
auto val_type = input->getType().cast<TensorType>();
auto val_type = input.getType().cast<TensorType>();
// If value is unranked, then so is results.
if (!val_type.hasRank())
return TFL::TopKV2Op::build(
@ -1035,7 +1034,7 @@ struct DropFakeQuant : public RewritePattern {
// If all the users of this op have valid "minmax" attributes, it is matched
// and can be removed.
auto fakeQuantOp = cast<FakeQuantOp>(op);
for (auto *operand : fakeQuantOp.getResult()->getUsers())
for (auto *operand : fakeQuantOp.getResult().getUsers())
if (!HasValidMinMaxAttribute(operand)) return matchFailure();
return matchSuccess();
@ -1102,7 +1101,7 @@ static LogicalResult VerifySplitOpOutputTypes(
for (int64_t i = 0; i < num_splits; ++i) {
auto expected_output_type = get_expected_output_type(i);
Value output = op->getResult(i);
auto output_type = output->getType().dyn_cast<RankedTensorType>();
auto output_type = output.getType().dyn_cast<RankedTensorType>();
if (!output_type || output_type != expected_output_type)
return op->emitOpError()
<< "output #" << i << " should be " << expected_output_type;
@ -1121,7 +1120,7 @@ static LogicalResult Verify(SplitOp op) {
if (!split_dim_opt) return success();
// If 'input' is not a ranked tensor, there are no other checks.
auto input_type = op.value()->getType().dyn_cast<RankedTensorType>();
auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
if (!input_type) return success();
int64_t split_dim = split_dim_opt.getValue();
@ -1157,7 +1156,7 @@ static LogicalResult Verify(SplitVOp op) {
if (!split_dim_opt) return success();
// If 'input' is not a ranked tensor, there are no other checks.
auto input_type = op.value()->getType().dyn_cast<RankedTensorType>();
auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
if (!input_type) return success();
int64_t split_dim = split_dim_opt.getValue();
@ -1177,8 +1176,7 @@ static LogicalResult Verify(SplitVOp op) {
return success();
if (size_splits_attr.getNumElements() != num_splits) {
auto size_splits_type =
op.size_splits()->getType().cast<RankedTensorType>();
auto size_splits_type = op.size_splits().getType().cast<RankedTensorType>();
RankedTensorType expected_size_splits_type =
RankedTensorType::get({num_splits}, size_splits_type.getElementType());
return op.emitOpError("'size_splits' should be ")
@ -1303,6 +1301,19 @@ OpFoldResult AbsOp::fold(ArrayRef<Attribute> operands) {
return ConstFoldUnaryOp(result_type, operands[0], compute);
}
//===----------------------------------------------------------------------===//
// NegOp
//===----------------------------------------------------------------------===//
OpFoldResult NegOp::fold(ArrayRef<Attribute> operands) {
Type result_type = getType();
// Only constant fold for tensor of f32 is implemented.
if (!IsF32ShapedType(result_type)) return nullptr;
auto compute = [](APFloat value) -> APFloat { return llvm::neg(value); };
return ConstFoldUnaryOp(result_type, operands[0], compute);
}
//===----------------------------------------------------------------------===//
// SinOp
//===----------------------------------------------------------------------===//
@ -1414,7 +1425,7 @@ OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
}
// Also fold if `input` has a known rank.
auto input_type = input()->getType().cast<ShapedType>();
auto input_type = input().getType().cast<ShapedType>();
// Do not fold if rank is zero because the TFLite converter doesn't
// distinguish between unranked input and scalar input due to b/138865275.
// TODO(b/138865275): Remove `input_type.getRank() != 0` in the following
@ -1445,18 +1456,18 @@ OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
static void BuildSelectV2Op(Builder *builder, OperationState &result,
Value cond, Value x, Value y) {
auto operand_type =
OpTrait::util::getBroadcastedType(x->getType(), y->getType());
OpTrait::util::getBroadcastedType(x.getType(), y.getType());
if (!operand_type)
emitError(result.location) << "non-broadcastable operands: " << x->getType()
<< " and " << y->getType();
emitError(result.location) << "non-broadcastable operands: " << x.getType()
<< " and " << y.getType();
bool has_static_cond_shape = false;
bool has_static_operand_shape = false;
ArrayRef<int64_t> cond_shape;
ArrayRef<int64_t> operand_shape;
if (auto shaped_type = cond->getType().dyn_cast<ShapedType>()) {
if (auto shaped_type = cond.getType().dyn_cast<ShapedType>()) {
if (shaped_type.hasStaticShape()) {
has_static_cond_shape = true;
cond_shape = shaped_type.getShape();
@ -1474,12 +1485,12 @@ static void BuildSelectV2Op(Builder *builder, OperationState &result,
!OpTrait::util::getBroadcastedShape(cond_shape, operand_shape,
broadcastedShape)) {
emitError(result.location) << "non-broadcastable operands: " << operand_type
<< " and " << cond->getType();
<< " and " << cond.getType();
}
result.addOperands({cond, x, y});
auto elementType = x->getType().dyn_cast<ShapedType>().getElementType();
auto elementType = x.getType().dyn_cast<ShapedType>().getElementType();
if (has_static_cond_shape && has_static_operand_shape) {
result.types.push_back(
RankedTensorType::get(broadcastedShape, elementType));
@ -1571,9 +1582,8 @@ OpFoldResult RangeOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
static LogicalResult Verify(TransposeConvOp op) {
ShapedType output_type = op.output()->getType().cast<ShapedType>();
ShapedType output_shape_type =
op.output_shape()->getType().cast<ShapedType>();
ShapedType output_type = op.output().getType().cast<ShapedType>();
ShapedType output_shape_type = op.output_shape().getType().cast<ShapedType>();
if (output_type.hasRank() && output_shape_type.hasStaticShape()) {
if (output_type.getRank() != output_shape_type.getDimSize(0)) {
return op.emitOpError(llvm::formatv(
@ -1679,9 +1689,9 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
}
static LogicalResult Verify(TransposeOp op) {
auto input_type = op.x()->getType().cast<ShapedType>();
auto perm_type = op.perm()->getType().cast<ShapedType>();
auto output_type = op.y()->getType().cast<ShapedType>();
auto input_type = op.x().getType().cast<ShapedType>();
auto perm_type = op.perm().getType().cast<ShapedType>();
auto output_type = op.y().getType().cast<ShapedType>();
if (input_type.hasStaticShape() && perm_type.hasStaticShape()) {
if (perm_type.getNumElements() != input_type.getRank()) {
return op.emitOpError(
@ -1726,10 +1736,25 @@ static LogicalResult Verify(TransposeOp op) {
return success();
}
Region &WhileOp::getLoopBody() { return body(); }
bool WhileOp::isDefinedOutsideOfLoop(Value value) {
// TODO(jpienaar): This is to overly conservative and disables anything other
// than constant hoisting initially.
return false;
}
LogicalResult WhileOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *>) {
// TODO(jpienaar): Fail any hoisting until post test case and refining
// isDefinedOutsideOfLoop.
return failure();
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.cc.inc"
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"

View File

@ -27,7 +27,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
#include "mlir/Transforms/LoopLikeInterface.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/lite/schema/schema_generated.h"
@ -44,6 +44,7 @@ class TensorFlowLiteDialect : public Dialect {
Location loc) override;
};
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.h.inc"
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc"

View File

@ -19,6 +19,8 @@ limitations under the License.
#define TFL_OPS
include "mlir/IR/OpBase.td"
include "mlir/Transforms/LoopLikeInterface.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td"
include "tensorflow/compiler/mlir/lite/quantization/quantization.td"
def TFL_Dialect : Dialect {
@ -135,7 +137,7 @@ def TFL_FpOrI32OrI64Tensor : TensorOf<[AnyFloat, TFL_Int32Or64]>;
//===----------------------------------------------------------------------===//
class TFL_OperandIsUnrankedPred<int n> :
CPred<"$_op.getOperand(" # n # ")->getType().isa<UnrankedTensorType>()">;
CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">;
// TODO: Some of these could be generalized and/or moved to more general
// location.
@ -144,38 +146,38 @@ class TFL_OperandHasRank<int n, int m> :
PredOpTrait<"operand " # n # " is " # m # "-D",
Or<[TFL_OperandIsUnrankedPred<n>,
CPred<"$_op.getOperand(" # n #
")->getType().cast<ShapedType>().getRank() == " # m>]>>;
").getType().cast<ShapedType>().getRank() == " # m>]>>;
// Returns true if the n-th operand is ranked and has rank dim.
class TFL_OperandHasKnownRank<int n, int dim> : And<[
CPred<"$_op.getOperand(" # n # ")->getType().isa<RankedTensorType>()">,
CPred<"$_op.getOperand(" # n # ")->getType().cast<ShapedType>().getRank() == "
CPred<"$_op.getOperand(" # n # ").getType().isa<RankedTensorType>()">,
CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>().getRank() == "
# dim>]>;
// True if operand n is ranked and has a rank > dim.
class TFL_OperandIsRankedAndHasDimPred<int n, int dim> : And<[
CPred<"$_op.getOperand(" # n # ")->getType().isa<RankedTensorType>()">,
CPred<"$_op.getOperand(" # n # ")->getType().cast<ShapedType>().getRank() > "
CPred<"$_op.getOperand(" # n # ").getType().isa<RankedTensorType>()">,
CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>().getRank() > "
# dim>]>;
class TFL_OperandDimEquals<int n, int dim, int size> : And<[
TFL_OperandIsRankedAndHasDimPred<n, dim>,
CPred<"$_op.getOperand(" # n # ")->getType().cast<ShapedType>()"
CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>()"
".getShape()[" # dim # " ] == " # size>]>;
// Returns true if the n-th operand has unknown rank or at least rank m.
class TFL_OperandHasAtleastRank<int n, int m> :
PredOpTrait<"operand " # n # " is " # m # "-D",
Or<[CPred<"$_op.getOperand(" # n # ")->getType().isa<UnrankedTensorType>()">,
Or<[CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">,
CPred<"$_op.getOperand(" # n #
")->getType().cast<ShapedType>().getRank() >= " # m>]>>;
").getType().cast<ShapedType>().getRank() >= " # m>]>>;
class TFL_OperandRankEquals1DimOfOperand<int x, int y> :
PredOpTrait<"operand " # x # "'s rank equals operand " # y # "'s size",
CPred<"$_op.getOperand(" # x #
")->getType().cast<ShapedType>().getRank() == "
").getType().cast<ShapedType>().getRank() == "
"$_op.getOperand(" # y #
")->getType().cast<ShapedType>().getShape()[0]">>;
").getType().cast<ShapedType>().getShape()[0]">>;
class TFL_Operand0DOr1ElementTensor<int x> :
PredOpTrait<"operand #" # x # " is an 0-d tensor or 1-d tensor w/ 1 element",
@ -195,7 +197,7 @@ class TFL_OperandHasRankLessThan<int n, int m> :
PredOpTrait<"operand " # n # " is maximum " # m # "-D",
Or<[TFL_OperandIsUnrankedPred<n>,
CPred<"$_op.getOperand(" # n #
")->getType().cast<ShapedType>().getRank() <= " # m>]>>;
").getType().cast<ShapedType>().getRank() <= " # m>]>>;
// This is a quantization-aware version of TCresVTEtIsSameAsOp
class TFL_TCresVTEtIsSameAsOp<int i, int j> : And<[
@ -227,7 +229,7 @@ def TFL_BroadcastableBinaryBuilder : OpBuilder<
"Builder *builder, OperationState &result, Value lhs, Value rhs",
[{
auto resultType =
OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType());
OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
if (!resultType)
mlir::emitError(result.location, "non-broadcastable operands");
result.addOperands({lhs, rhs});
@ -248,16 +250,6 @@ def TFL_ComparisonBinaryBuilder : OpBuilder<
buildComparisonBinOp(builder, result, lhs, rhs);
}]>;
//===----------------------------------------------------------------------===//
// TFL native op trait for stateful operands and channel indices.
class StatefulOperands<list<int> operands>
: ParamNativeOpTrait<"TFL::StatefulOperands", StrJoinInt<operands>.result>;
class ChannelDimIndex<int index>
: ParamNativeOpTrait<"TFL::ChannelDimIndex", !cast<string>(index)>;
//===----------------------------------------------------------------------===//
// TFL op base class.
//===----------------------------------------------------------------------===//
@ -285,7 +277,7 @@ class TFL_Op<string mnemonic, list<OpTrait> traits = []> :
class TFL_ConvOp<string mnemonic, string opSummary, int index> :
TFL_Op<mnemonic, [NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
ChannelDimIndex<index>, AffineOpCoefficient<index, 1>]> {
TFL_ChannelDimIndexInterface, AffineOpCoefficient<index, 1>]> {
let summary = opSummary # " operator";
let description = [{
@ -335,7 +327,7 @@ an output element, this operation computes \\(y = |x|\\).
let hasFolder = 1;
}
def TFL_AddOp : TFL_Op<"add", [Broadcastable, NoSideEffect, Commutative]> {
def TFL_AddOp : TFL_Op<"add", [ResultsBroadcastableShape, NoSideEffect, Commutative]> {
let summary = "Addition operator";
let description = [{
@ -427,6 +419,33 @@ def TFL_TransposeConvOp:
let verifier = [{ return Verify(*this); }];
}
def TFL_Convolution2DTransposeBiasOp :
Op<TFL_Dialect, "convolution_2d_transpose_bias", [NoSideEffect]> {
let summary = " Transpose convolution with bias operator";
let description = [{
Performs transpose convolution operation on inputs,
with the option of adding a bias.
Note this is a custom op that is not supported in the standard runtime.
Inputs:
`inputs[0]`: required: the input activation tensor
`inputs[1]`: required: the filter weight tensor
`inputs[2]`: optional: the bias tensor
}];
let arguments = (
ins AnyTensor:$input,
AnyTensor:$filter,
TFL_TensorOfOrNone<[AnyType]>:$bias,
TFL_PaddingAttr:$padding,
I32Attr:$stride_h,
I32Attr:$stride_w
);
let results = (outs AnyTensor:$output);
}
def TFL_AveragePool2DOp:
TFL_Op<"average_pool_2d", [NoSideEffect, SameOperandsAndResultsScale]> {
let summary = "Average_pool_2d operator";
@ -459,8 +478,7 @@ def TFL_ArgMaxOp : TFL_Op<"arg_max", [NoSideEffect]> {
}];
let arguments = (
// TODO: Add support for uint8.
ins TensorOf<[F32, I32, I8]>:$input,
ins TensorOf<[F32, I32, I8, TFL_Uint8, QI8, QUI8]>:$input,
TFL_I32OrI64Tensor:$dim
);
@ -471,7 +489,7 @@ def TFL_ArgMaxOp : TFL_Op<"arg_max", [NoSideEffect]> {
let hasOptions = 1;
DerivedTFLiteTypeAttr output_type = DerivedTFLiteTypeAttr<[{
return getResult()->getType().cast<TensorType>().getElementType().
return getResult().getType().cast<TensorType>().getElementType().
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
tflite::TensorType_INT32;
}]>;
@ -488,8 +506,7 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> {
}];
let arguments = (
// TODO(pkanwar): Add support for uint8.
ins TensorOf<[F32, I32, I8]>:$input,
ins TensorOf<[F32, I32, I8, TFL_Uint8, QI8, QUI8]>:$input,
TFL_I32OrI64Tensor:$dim
);
@ -500,7 +517,7 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> {
let hasOptions = 1;
DerivedTFLiteTypeAttr output_type = DerivedTFLiteTypeAttr<[{
return getResult()->getType().cast<TensorType>().getElementType().
return getResult().getType().cast<TensorType>().getElementType().
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
tflite::TensorType_INT32;
}]>;
@ -590,7 +607,12 @@ def TFL_ExternalConstOp : Op<TFL_Dialect, "external_const", [NoSideEffect]> {
let results = (outs AnyTensor:$output);
}
def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0>;
def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0> {
let extraClassDeclaration = [{
// StatefulOpInterface:
int GetChannelDimIndex() { return 0; }
}];
}
def TFL_CosOp: TFL_Op<"cos", [
NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
@ -610,6 +632,11 @@ def TFL_CosOp: TFL_Op<"cos", [
def TFL_DepthwiseConv2DOp :
TFL_ConvOp<"depthwise_conv_2d", "Depthwise-separable convolution", 3> {
let arguments = !con(TFL_Conv2DOp.arguments, (ins I32Attr:$depth_multiplier));
let extraClassDeclaration = [{
// StatefulOpInterface:
int GetChannelDimIndex() { return 3; }
}];
}
def TFL_FCWO_Default : StrEnumAttrCase<"DEFAULT">;
@ -623,7 +650,8 @@ def TFL_FullyConnectedOptionsWeightFormatAttr :
// TODO(jpienaar): Update post discussion on semantics of FC OP.
def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
NoSideEffect, AccumulatorUniformScale<2, 0, 1>, ChannelDimIndex<0>,
NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
TFL_ChannelDimIndexInterface,
AffineOpCoefficient<-1, 1>]> {
let summary = "Fully connected op";
@ -645,6 +673,11 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
let verifier = [{ return Verify(*this); }];
let hasOptions = 1;
let extraClassDeclaration = [{
// ChannelDimIndexInterface:
int GetChannelDimIndex() { return 0; }
}];
}
def TFL_GatherOp : TFL_Op<"gather", [
@ -652,7 +685,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
SameOperandsAndResultsScale,
TFL_OperandHasAtleastRank<0, 1>,
PredOpTrait<"params and output must have same element type",
TCresVTEtIsSameAsOp<0, 0>>
TFL_TCresVTEtIsSameAsOp<0, 0>>
]> {
let summary = "Gather operator";
@ -661,7 +694,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
}];
let arguments = (ins
TensorOf<[F32, I1, I8, I32, I64, TFL_Str, QI8, QUI8]>:$params,
TensorOf<[F32, I1, I8, I32, I64, TFL_Str, QI8, QUI8, QI16]>:$params,
TensorOf<[I32, I64]>:$indices,
I32Attr:$axis
);
@ -674,7 +707,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
];
let results = (outs
TensorOf<[F32, I1, I8, I32, I64, TFL_Str, QI8, QUI8]>:$output
TensorOf<[F32, I1, I8, I32, I64, TFL_Str, QI8, QUI8, QI16]>:$output
);
let hasOptions = 1;
@ -697,9 +730,9 @@ def TFL_GatherNdOp : TFL_Op<"gather_nd", [NoSideEffect]> {
);
}
// Same type check of lhs and rhs is handled by the Broadcastable trait.
// Same type check of lhs and rhs is handled by the ResultsBroadcastableShape trait.
def TFL_LessEqualOp : TFL_Op<"less_equal", [
Broadcastable, NoSideEffect, NoQuantizableResult]> {
ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
let summary = "Less_equal operator";
let description = [{
@ -755,7 +788,7 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag
}
def TFL_GreaterEqualOp : TFL_Op<"greater_equal", [
Broadcastable, NoSideEffect, NoQuantizableResult]> {
ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
let summary = "Greater_equal operator";
let description = [{
@ -916,7 +949,7 @@ larger than 0.
}
def TFL_NotEqualOp : TFL_Op<"not_equal", [
Broadcastable, Commutative, NoSideEffect, NoQuantizableResult]> {
ResultsBroadcastableShape, Commutative, NoSideEffect, NoQuantizableResult]> {
let summary = "Not_equal operator";
let description = [{
@ -943,7 +976,7 @@ def TFL_NotEqualOp : TFL_Op<"not_equal", [
let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
}
def TFL_DivOp : TFL_Op<"div", [Broadcastable, NoSideEffect]> {
def TFL_DivOp : TFL_Op<"div", [ResultsBroadcastableShape, NoSideEffect]> {
let summary = "Division operator";
let description = [{
@ -1002,7 +1035,7 @@ def TFL_EmbeddingLookupOp: TFL_Op<"embedding_lookup",
let results = (outs TensorOf<[F32, I8, TFL_Uint8]>:$output);
}
def TFL_EqualOp: TFL_Op<"equal", [Commutative, Broadcastable,
def TFL_EqualOp: TFL_Op<"equal", [Commutative, ResultsBroadcastableShape,
NoQuantizableResult,
PredOpTrait<"Operands have same value type", TCopVTEtIsSameAs<0, 1>>]> {
let summary = "Equal operator";
@ -1036,7 +1069,8 @@ def TFL_ExpOp: TFL_Op<"exp", [NoSideEffect, SameOperandsAndResultType]> {
let hasOptions = 0b1;
}
def TFL_ExpandDimsOp: TFL_Op<"expand_dims", [NoSideEffect]> {
def TFL_ExpandDimsOp: TFL_Op<"expand_dims", [
NoSideEffect, SameOperandsAndResultsScale]> {
let summary = "Inserts a dimension of 1 into a tensor's shape.";
let description = [{
@ -1146,7 +1180,7 @@ def TFL_FloorOp: TFL_Op<"floor", [NoSideEffect, SameOperandsAndResultType]> {
}
def TFL_FloorDivOp : TFL_Op<"floor_div", [
Broadcastable, NoSideEffect, BinaryOpSameElementTypeConstraint]> {
ResultsBroadcastableShape, NoSideEffect, BinaryOpSameElementTypeConstraint]> {
let summary = "Floor div operator";
let description = [{
@ -1165,7 +1199,7 @@ def TFL_FloorDivOp : TFL_Op<"floor_div", [
let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
}
def TFL_FloorModOp : TFL_Op<"floor_mod", [Broadcastable, NoSideEffect]> {
def TFL_FloorModOp : TFL_Op<"floor_mod", [ResultsBroadcastableShape, NoSideEffect]> {
let summary = "Division reminder";
let description = [{
@ -1181,7 +1215,8 @@ def TFL_FloorModOp : TFL_Op<"floor_mod", [Broadcastable, NoSideEffect]> {
let builders = [TFL_BroadcastableBinaryBuilder];
}
def TFL_GreaterOp : TFL_Op<"greater", [NoSideEffect, NoQuantizableResult]> {
def TFL_GreaterOp : TFL_Op<"greater", [
ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
let summary = "Greater operator";
let description = [{
@ -1194,6 +1229,8 @@ def TFL_GreaterOp : TFL_Op<"greater", [NoSideEffect, NoQuantizableResult]> {
let results = (outs AnyTensor:$output);
let builders = [TFL_ComparisonBinaryBuilder];
let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }];
let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
@ -1260,7 +1297,8 @@ def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [NoSideEffect, SameOperandsAndResultTy
let hasOptions = 0b1;
}
def TFL_LessOp : TFL_Op<"less", [NoSideEffect, NoQuantizableResult]> {
def TFL_LessOp : TFL_Op<"less", [
ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
let summary = "Less operator";
let description = [{
@ -1427,8 +1465,65 @@ def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [
let customOption = "Pool2DOptions";
}
def TFL_MaxPoolingWithArgMax2DOp :
Op<TFL_Dialect, "max_pooling_with_argmax_2d", [NoSideEffect]> {
let summary = "Max Pool 2D with argmax op";
let description = [{
Performs max pooling on the input and outputs both max values and indices.
Each index is a flatten index in a sub-array of "filter_w" x "filter_h" size
Note this is a custom op that is not supported in the standard runtime.
Inputs:
`inputs[0]`: required: the input activation tensor
}];
let arguments = (
ins AnyTensor:$input,
TFL_PaddingAttr:$padding,
I32Attr:$stride_w,
I32Attr:$stride_h,
I32Attr:$filter_w,
I32Attr:$filter_h
);
let results = (outs
AnyTensor:$value,
AnyTensor:$indices
);
}
def TFL_MaxUnpooling2DOp :
Op<TFL_Dialect, "max_unpooling_2d", [NoSideEffect]> {
let summary = "Max Unpool 2D";
let description = [{
Performs max unpool operation.
To some extent this is the reverse operation of max pooling:
the elements in the input activation tensor is stored into the position
specified by the input indices.
Note this is a custom op that is not supported in the standard runtime.
Inputs:
`inputs[0]`: required: the input activation tensor
`inputs[1]`: required: the input indices
}];
let arguments = (
ins AnyTensor:$input,
AnyTensor:$indices,
TFL_PaddingAttr:$padding,
I32Attr:$stride_w,
I32Attr:$stride_h,
I32Attr:$filter_w,
I32Attr:$filter_h
);
let results = (outs AnyTensor:$outputs);
}
def TFL_MaximumOp : TFL_Op<"maximum", [
Broadcastable, NoSideEffect, Commutative, SameOperandsAndResultsScale,
ResultsBroadcastableShape, NoSideEffect, Commutative, SameOperandsAndResultsScale,
TFL_OperandHasRankLessThan<0, 4>, TFL_OperandHasRankLessThan<1, 4>]> {
let summary = "Max operator";
let description = [{
@ -1567,7 +1662,8 @@ def TFL_SumOp: TFL_Op<"sum", [NoSideEffect]> {
let customOption = "ReducerOptions";
}
def TFL_ReduceMinOp: TFL_Op<"reduce_min", [NoSideEffect]> {
def TFL_ReduceMinOp: TFL_Op<"reduce_min", [
NoSideEffect, SameOperandsAndResultsScale]> {
let summary = "Min-reduction operator";
let description = [{
@ -1586,7 +1682,8 @@ def TFL_ReduceMinOp: TFL_Op<"reduce_min", [NoSideEffect]> {
let customOption = "ReducerOptions";
}
def TFL_ReduceMaxOp: TFL_Op<"reduce_max", [NoSideEffect]> {
def TFL_ReduceMaxOp: TFL_Op<"reduce_max", [
NoSideEffect, SameOperandsAndResultsScale]> {
let summary = "Max-reduction operator";
let description = [{
@ -1625,7 +1722,7 @@ def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [NoSideEffect]> {
}
def TFL_MinimumOp : TFL_Op<"minimum", [
Broadcastable, NoSideEffect, Commutative, SameOperandsAndResultsScale,
ResultsBroadcastableShape, NoSideEffect, Commutative, SameOperandsAndResultsScale,
TFL_OperandHasRankLessThan<0, 4>, TFL_OperandHasRankLessThan<1, 4>]> {
let summary = "Min operator";
let description = [{
@ -1646,7 +1743,7 @@ def TFL_MinimumOp : TFL_Op<"minimum", [
let hasOptions = 0;
}
def TFL_MulOp : TFL_Op<"mul", [Broadcastable, NoSideEffect, Commutative]> {
def TFL_MulOp : TFL_Op<"mul", [ResultsBroadcastableShape, NoSideEffect, Commutative]> {
let summary = "Multiplication operator";
let description = [{
@ -1683,6 +1780,8 @@ def TFL_NegOp: TFL_Op<"neg", [NoSideEffect, SameOperandsAndResultType]> {
let results = (outs AnyTensor:$y);
let hasOptions = 0b1;
let hasFolder = 1;
}
def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> {
@ -1716,14 +1815,14 @@ def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> {
}];
let arguments = (ins
Variadic<TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8]>>:$values,
Variadic<TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8, QI16]>>:$values,
I32Attr:$values_count,
I32Attr:$axis
);
let results = (outs
TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8]>:$output
TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8, QI16]>:$output
);
let verifier = [{ return Verify(*this); }];
@ -1821,7 +1920,7 @@ def TFL_PadV2Op : TFL_Op<"padv2", [
let hasOptions = 1;
}
def TFL_PowOp : TFL_Op<"pow", [Broadcastable, NoSideEffect, NoQuantizableResult]> {
def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
let summary = "Power operator";
let description = [{
@ -1996,7 +2095,7 @@ def TFL_ShapeOp: TFL_Op<"shape", [NoSideEffect]> {
let results = (outs AnyTensor:$output);
DerivedTypeAttr out_type = DerivedTypeAttr<[{
return getResult()->getType().cast<TensorType>().getElementType();
return getResult().getType().cast<TensorType>().getElementType();
}]>;
let hasOptions = 1;
@ -2039,7 +2138,7 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2",
Args:
tensor: A Tensor. Must be one of the following types:
int16, int32, int64, float32 Up to 8-D.
uint8, int16, int32, int64, float32, bool Up to 8-D.
axis: A Tensor. Must be one of the following types: int32, int64.
with only 1 element which is the axis index.
@ -2048,12 +2147,12 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2",
let arguments = (
ins
TensorOf<[F32, I16, I32, I64]>:$input,
TensorOf<[F32, I16, I32, I64, TFL_Uint8, I1]>:$input,
TensorOf<[I32, I64]>:$axis
);
let results = (outs
TensorOf<[F32, I16, I32, I64, I8]>:$output
TensorOf<[F32, I16, I32, I64, TFL_Uint8, I1]>:$output
);
}
@ -2083,7 +2182,7 @@ def TFL_SelectOp : TFL_Op<"select", [NoSideEffect,
let builders = [OpBuilder<"Builder *builder, OperationState &result, "
"Value condition, Value x, Value y",
[{
auto resultType = x->getType();
auto resultType = x.getType();
result.addOperands({condition, x, y});
result.types.push_back(resultType);
}]>];
@ -2190,7 +2289,7 @@ def TFL_SquareOp: TFL_Op<"square", [
let hasFolder = 1;
}
def TFL_SubOp : TFL_Op<"sub", [Broadcastable, NoSideEffect]> {
def TFL_SubOp : TFL_Op<"sub", [ResultsBroadcastableShape, NoSideEffect]> {
let summary = "Subtraction operator";
let description = [{
@ -2218,7 +2317,7 @@ def TFL_SubOp : TFL_Op<"sub", [Broadcastable, NoSideEffect]> {
// TODO(jpienaar): Expand the kernel implementation to support all types besides
// I32 and F32.
def TFL_SquaredDifferenceOp : TFL_Op<"squared_difference", [
Broadcastable, NoSideEffect, NoQuantizableResult]> {
ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
let summary = "Squared difference operator";
let description = [{
@ -2257,9 +2356,9 @@ def TFL_TanhOp: TFL_Op<"tanh", [
let results = (outs TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$y);
}
def TFL_TileOp: TFL_Op<"tile", [NoSideEffect,
def TFL_TileOp: TFL_Op<"tile", [NoSideEffect, SameOperandsAndResultsScale,
PredOpTrait<"resultant element type needs to match first operand type",
TCresVTEtIsSameAsOp<0,0>>]> {
TFL_TCresVTEtIsSameAsOp<0,0>>]> {
let summary = "Tile operator.";
let description = [{
Constructs a tensor by tiling a given tensor.
@ -2272,10 +2371,11 @@ def TFL_TileOp: TFL_Op<"tile", [NoSideEffect,
}];
let arguments = (ins
TensorOf<[F32, I1, I32, I64, TFL_Uint8]>:$input,
TensorOf<[F32, I1, I32, I64, TFL_Uint8, QUI8]>:$input,
TFL_I32OrI64Tensor:$multiples);
let results = (outs TensorOf<[F32, I1, I32, I64, TFL_Uint8]>:$output);
let results = (outs
TensorOf<[F32, I1, I32, I64, TFL_Uint8, QUI8]>:$output);
let hasOptions = 0;
}
@ -2285,7 +2385,7 @@ def TFL_TileOp: TFL_Op<"tile", [NoSideEffect,
// TODO(jpienaar): Check that k is less or equal the internal dimension
def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>,
PredOpTrait<"result and input element type match",
TCresVTEtIsSameAsOp<0,0>>]> {
TCresVTEtIsSameAsOp<0,0>>, SameOperandsAndResultsScale]> {
let summary = "TopK operator";
let description = [{
@ -2295,11 +2395,11 @@ def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>,
}];
let arguments = (ins
TensorOf<[F32, I8, I32, I64, TFL_Uint8]>:$input,
TensorOf<[F32, I8, I32, I64, TFL_Uint8, QI8, QUI8]>:$input,
I32Tensor:$k);
let results = (outs
AnyTensor:$values,
TensorOf<[F32, I8, I32, I64, TFL_Uint8, QI8, QUI8]>:$values,
I32Tensor:$indices);
let builders = [OpBuilder<"Builder *builder, OperationState &result, "
@ -2338,7 +2438,7 @@ def TFL_TransposeOp : TFL_Op<"transpose",
let hasFolder = 1;
}
def TFL_UnpackOp : TFL_Op<"unpack", [NoSideEffect]> {
def TFL_UnpackOp : TFL_Op<"unpack", [NoSideEffect, SameOperandsAndResultsScale]> {
let summary = "Unpacks a tensor along a dimension into multiple tensors";
let description = [{
@ -2554,7 +2654,9 @@ def TFL_ResizeBilinearOp: TFL_Op<"resize_bilinear", [
// TODO(ycling): Support quantized types.
TensorOf<[F32, I32, QI8, QUI8]>:$input,
TensorOf<[I32]>:$size,
BoolAttr:$align_corners);
BoolAttr:$align_corners,
DefaultValuedAttr<BoolAttr, "false">:$half_pixel_centers
);
let results = (outs
TensorOf<[F32, QI8, QUI8]>:$output
@ -2663,12 +2765,11 @@ def TFL_CastOp : TFL_Op<"cast", [
Casts input from input type to output type.
}];
// TODO(b/135538711): Add complex types here.
let arguments = (ins
TensorOf<[F32, I1, I32, I64, TFL_Quint8, TFL_Uint8]>:$input
TensorOf<[F32, I1, I32, I64, TFL_Quint8, TFL_Uint8, Complex<F<32>>]>:$input
);
let results = (outs TensorOf<[F32, I1, I32, I64]>:$output);
let results = (outs TensorOf<[F32, I1, I32, I64, Complex<F<32>>]>:$output);
// TFLite's cast op does not utilize CastOptions, instead derives types
// from the TfLiteTensors.
@ -2733,7 +2834,7 @@ in the unique output `y`. In other words:
);
DerivedTFLiteTypeAttr idx_out_type = DerivedTFLiteTypeAttr<[{
return getResult(1)->getType().cast<TensorType>().getElementType().
return getResult(1).getType().cast<TensorType>().getElementType().
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
tflite::TensorType_INT32;
}]>;
@ -2768,7 +2869,9 @@ def TFL_FakeQuantOp : TFL_Op<"fake_quant", [NoSideEffect]> {
let arguments = (
ins AnyTensor:$input,
// The expected [min, max] range of values.
MinMaxAttr:$minmax,
F32Attr:$min,
F32Attr:$max,
// The bitwidth of the quantization; between 2 and 16, inclusive.
I32Attr:$num_bits,
// Quantization range starts from 0 or 1; starts from 1 if true.
@ -2777,6 +2880,8 @@ def TFL_FakeQuantOp : TFL_Op<"fake_quant", [NoSideEffect]> {
let results = (outs AnyTensor:$output);
let hasCanonicalizer = 0b1;
let hasOptions = 1;
}
def TFL_QConstOp : Op<TFL_Dialect, "pseudo_qconst", [
@ -2823,6 +2928,20 @@ def TFL_QuantizeOp: TFL_Op<"quantize", [
let results = (outs AnyTensor:$output);
}
def TFL_DensifyOp: TFL_Op<"densify", [NoSideEffect,
SameOperandsAndResultType,
NoQuantizableResult]> {
let summary = "Densify operator";
let description = [{
Converts sparse tensor to dense format.
}];
let arguments = (ins AnyTensor:$input);
let results = (outs AnyTensor:$output);
}
//===----------------------------------------------------------------------===//
// LSTM Ops
//===----------------------------------------------------------------------===//
@ -2912,7 +3031,7 @@ def TFL_LSTMOp :
LstmOptionalPeepholeWeightConstraint,
LstmProjectionWeightBiasConstraint,
LstmResultConstraint,
StatefulOperands<[18, 19]>]> {
TFL_StatefulOp]> {
let summary = "The full lstm operator";
let description = [{
@ -2996,6 +3115,11 @@ Ba et al. “Layer Normalization”
let hasOptions = 1;
let verifier = [{ return Verify(*this); }];
let extraClassDeclaration = [{
// StatefulOpInterface:
std::vector<int> GetStatefulOperands() { return {18, 19}; }
}];
}
// UnidirectionalSequenceLstm op.
@ -3007,7 +3131,7 @@ def TFL_UnidirectionalSequenceLSTMOp :
LstmOptionalPeepholeWeightConstraint,
LstmProjectionWeightBiasConstraint,
LstmResultConstraint,
StatefulOperands<[18, 19]>]> {
TFL_StatefulOp]> {
let summary = "Unidirectional sequence lstm operator";
let description = [{
@ -3076,6 +3200,11 @@ def TFL_UnidirectionalSequenceLSTMOp :
let hasOptions = 1;
let verifier = [{ return Verify(*this); }];
let extraClassDeclaration = [{
// StatefulOpInterface:
std::vector<int> GetStatefulOperands() { return {18, 19}; }
}];
}
def RnnResultConstraint : PredOpTrait<
@ -3085,7 +3214,7 @@ def RnnResultConstraint : PredOpTrait<
// UnidirectionalSequenceRNN op.
def TFL_UnidirectionalSequenceRNNOp :
TFL_Op<"unidirectional_sequence_rnn",
[RnnResultConstraint, StatefulOperands<[4]>]> {
[RnnResultConstraint, TFL_StatefulOp]> {
let summary = "Unidirectional sequence rnn operator";
@ -3129,6 +3258,11 @@ def TFL_UnidirectionalSequenceRNNOp :
let customOption = "SequenceRNNOptions";
let verifier = [{ return Verify(*this); }];
let extraClassDeclaration = [{
// StatefulOpInterface:
std::vector<int> GetStatefulOperands() { return {4}; }
}];
}
def TFL_WhereOp : TFL_Op<"where", [NoSideEffect]> {
@ -3180,7 +3314,7 @@ def SVDFResultConstraint: PredOpTrait<
// SVDF op.
def TFL_SVDFOp :
TFL_Op<"svdf",
[SVDFResultConstraint, StatefulOperands<[4]>]> {
[SVDFResultConstraint, TFL_StatefulOp]> {
let summary = "Single value decomposition filter operator";
@ -3216,6 +3350,67 @@ def TFL_SVDFOp :
let hasOptions = 1;
let verifier = [{ return Verify(*this); }];
let extraClassDeclaration = [{
// StatefulOpInterface:
std::vector<int> GetStatefulOperands() { return {4}; }
}];
}
def TFL_SegmentSumOp: TFL_Op<"segment_sum", [NoSideEffect]> {
let summary = "SegmentSum operator";
let description = [{
Computes the sum along segments of a tensor.
}];
let arguments = (ins
TensorOf<[F32, I32]>:$data,
I32Tensor:$segment_ids
);
let results = (outs TensorOf<[F32, I32]>:$output);
}
def TFL_YieldOp : Op<TFL_Dialect, "yield", [Terminator]> {
let summary = "Yield operation";
let description = [{
The "yield" operation represents a return operation within the conditional
and body of structured control flow (e.g., while). The operation takes
variable number of operands and produces no results. The operand number and
types must match the signature of the region that contains the operation.
}];
let arguments = (ins Variadic<AnyType>:$operands);
}
def TFL_WhileOp : Op<TFL_Dialect, "while", [
DeclareOpInterfaceMethods<LoopLikeOpInterface>,
SingleBlockImplicitTerminator<"YieldOp">,
// Make isolated from above to force values through operands to simplify
// exporting to subgraphs.
IsolatedFromAbove]> {
let summary = [{While loop}];
let description = [{
output = input; while (cond(output)) { output = body(output) }
input: A list of input tensors whose types are T.
output: A list of output tensors whose types are T.
cond: A region takes 'input' and returns a boolean scalar tensor.
body: A region that takes a list of tensors and returns another
list of tensors. Both lists have the same types.
}];
let arguments = (ins
Variadic<AnyTensor>:$input,
// Used to map StatelessWhile and While op defined in TensorFlow to a common
// op.
DefaultValuedAttr<BoolAttr, "false">:$is_stateless
);
let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
let results = (outs Variadic<AnyTensor>:$output);
}
#endif // TFL_OPS

View File

@ -1,67 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines the op traits used in the MLIR TensorFlow Lite dialect.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h" // TF:llvm-project
namespace mlir {
namespace OpTrait {
namespace TFL {
// The trait to specify that the specified operands of the TFL op are stateful.
// This is used as a trait like this:
//
// class LSTMOp
// : public Op<LSTMOp, OpTrait::TFL::StatefulOperands<18, 19>::Impl> {
//
template <int... Operands>
class StatefulOperands {
public:
template <typename ConcreteType>
class Impl
: public TraitBase<ConcreteType, StatefulOperands<Operands...>::Impl> {
public:
static std::vector<int> GetStatefulOperands() {
return std::vector<int>({Operands...});
}
};
};
// The trait to specify the channel dimension index of the input (first operand)
// of an affine TFL op (Conv2D, DepthwiseConv2D, FullyConnected).
//
// class Conv2DOp
// : public Op<Conv2DOp, OpTrait::TFL::ChannelDimIndex<0>::Impl> {
//
template <int Index>
class ChannelDimIndex {
public:
template <typename ConcreteType>
class Impl : public TraitBase<ConcreteType, ChannelDimIndex<Index>::Impl> {
public:
static int GetChannelDimIndex() { return Index; }
};
};
} // namespace TFL
} // namespace OpTrait
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_

View File

@ -122,7 +122,7 @@ static void EmitOptionBuilders(const RecordKeeper &record_keeper,
os << formatv(
" auto {0} = Convert{1}ForOptionWriter(op.{0}(), fbb);\n",
val.getName(), record->getClasses()[0]->getName());
options.push_back(val.getName());
options.push_back(std::string(val.getName()));
}
}
}

View File

@ -32,6 +32,6 @@ cc_library(
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:ViewOpGraph",
"@llvm-project//mlir:Transforms",
],
)

View File

@ -107,9 +107,6 @@ void WarningUnusedFlags(const toco::ModelFlags& model_flags,
if (toco_flags.output_format()) {
LOG(WARNING) << "Ignored output_format.";
}
if (toco_flags.default_ranges_min() || toco_flags.default_ranges_max()) {
LOG(WARNING) << "Ignored default_ranges_stats.";
}
if (toco_flags.drop_control_dependency()) {
LOG(WARNING) << "Ignored drop_control_dependency.";
}
@ -242,6 +239,13 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
tensorflow::ParseOutputArrayInfo(output_arrays, &specs.outputs));
// Other flags.
if (toco_flags.has_default_ranges_min()) {
quant_specs.default_ranges.first = toco_flags.default_ranges_min();
}
if (toco_flags.has_default_ranges_max()) {
quant_specs.default_ranges.second = toco_flags.default_ranges_max();
}
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
bool emit_custom_ops = toco_flags.allow_custom_ops();

View File

@ -71,18 +71,17 @@ cc_library(
"quantization_utils.cc",
],
hdrs = [
"quantization_traits.h",
"quantization_utils.h",
],
deps = [
"//tensorflow/core:lib_proto_parsing",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
# TODO(fengliuai): remove this dependence.
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/core:lib_proto_parsing",
],
)

View File

@ -78,8 +78,8 @@ class ImportQuantStatsPass : public FunctionPass<ImportQuantStatsPass> {
bool IsQuantizableResult(Operation *op, int index) {
if (index < 0 || index >= op->getNumResults()) return false;
Value res = op->getResult(index);
return res->getType().isa<ShapedType>() &&
res->getType().cast<ShapedType>().getElementType().isa<FloatType>();
return res.getType().isa<ShapedType>() &&
res.getType().cast<ShapedType>().getElementType().isa<FloatType>();
}
// A method to retrieve the name for the given op.
@ -123,7 +123,7 @@ void ImportQuantStatsPass::InsertStatsOpAtResult(OpBuilder b, Value res,
IntegerAttr axis) {
auto stats_op = b.create<quant::StatisticsOp>(b.getUnknownLoc(), res,
layer_stats, axis_stats, axis);
res->replaceAllUsesWith(stats_op);
res.replaceAllUsesWith(stats_op);
stats_op.getOperation()->replaceUsesOfWith(stats_op, res);
}
@ -206,10 +206,17 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
std::unique_ptr<OpPassBase<FuncOp>>
CreateImportQuantStatsPassForTFControlDialect(const std::string &stats_str) {
auto get_name_func = [](Operation *op) {
if (auto name = op->getAttrOfType<StringAttr>("name"))
return name.getValue();
else
return llvm::StringRef("");
Location loc = op->getLoc();
if (auto name = loc.dyn_cast<NameLoc>()) {
return name.getName().strref();
} else if (auto fused_name = loc.dyn_cast<FusedLoc>()) {
for (auto sub_loc : fused_name.getLocations()) {
if (auto named_sub_loc = sub_loc.dyn_cast<NameLoc>()) {
return named_sub_loc.getName().strref();
}
}
}
return llvm::StringRef("");
};
return CreateImportQuantStatsPass(get_name_func, stats_str);

View File

@ -12,6 +12,7 @@ package_group(
includes = ["//third_party/mlir:subpackages"],
packages = [
"//learning/brain/experimental/mlir/...",
"//tensorflow/compiler/mlir/lite/...",
"//tensorflow/lite/...",
],
)
@ -23,7 +24,6 @@ cc_library(
],
hdrs = [
"quantize_model.h",
"//tensorflow/compiler/mlir/lite:transforms/passes.h",
],
deps = [
"//tensorflow/compiler/mlir/lite:common",
@ -42,6 +42,24 @@ cc_library(
],
)
cc_library(
name = "tfl_to_std",
srcs = [
"tfl_to_std.cc",
],
hdrs = [
"tfl_to_std.h",
],
deps = [
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
],
)
# Binary to apply quantization on the annotated files.
tf_cc_binary(
name = "tfl_quantizer",

View File

@ -73,19 +73,19 @@ TfLiteStatus QuantizeModel(
// Apply quantization passes
PassManager pm(module->getContext());
TFL::QuantizationSpecs pass_config;
pass_config.inference_type = tensorflow::DT_QINT8;
pass_config.post_training_quantization = true;
TFL::QuantizationSpecs quant_specs;
quant_specs.inference_type = tensorflow::DT_QINT8;
quant_specs.post_training_quantization = true;
bool emit_adaptor = false;
auto input_tf_type = tflite::TflTypeToTfType(input_type);
if (input_tf_type == tensorflow::DT_FLOAT) {
emit_adaptor = true;
} else if (input_tf_type == tensorflow::DT_UINT8) {
pass_config.inference_type = tensorflow::DT_QUINT8;
quant_specs.inference_type = tensorflow::DT_QUINT8;
}
pm.addPass(TFL::CreatePrepareQuantizePass(pass_config));
pm.addPass(TFL::CreatePrepareQuantizePass(quant_specs));
pm.addPass(TFL::CreateQuantizePass());
pm.addPass(TFL::CreatePostQuantizePass(emit_adaptor));

Some files were not shown because too many files have changed in this diff Show More