Merge branch 'master' into google_upstream_training_ops

commit 9760afc119
Author: ekuznetsov139 (committed by GitHub)
Date: 2020-02-04 03:37:53 -08:00
2616 changed files with 118944 additions and 39057 deletions

View File

@ -69,6 +69,7 @@
# rbe_linux_py3: Linux Python 3 RBE config
#
# rbe_win_py37: Windows Python 3.7 RBE config
# rbe_win_py38: Windows Python 3.8 RBE config
#
# tensorflow_testing_rbe_linux: RBE options to use RBE with tensorflow-testing project on linux
# tensorflow_testing_rbe_win: RBE options to use RBE with tensorflow-testing project on windows
@ -279,7 +280,6 @@ build:windows --host_linkopt=/OPT:REF
build:windows --linkopt=/OPT:ICF
build:windows --host_linkopt=/OPT:ICF
build:windows --experimental_strict_action_env=true
build:windows --incompatible_windows_native_test_wrapper
# Verbose failure logs when something goes wrong
build:windows --verbose_failures
@ -344,6 +344,7 @@ build:rbe_linux --config=avx_linux
build:rbe_linux --config=short_logs
# TODO(gunan): Check why we need this specified in rbe, but not in other builds.
build:rbe_linux --linkopt=-lrt
build:rbe_linux --linkopt=-lm
build:rbe_cpu_linux --config=rbe_linux
build:rbe_cpu_linux --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain"
@ -392,6 +393,7 @@ build:rbe_win --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe
# TODO(gunan): Remove once we use MSVC 2019 with latest patches.
build:rbe_win --define=override_eigen_strong_inline=true
build:rbe_win --jobs=500
build:rbe_win_py37 --config=rbe
build:rbe_win_py37 --repo_env=PYTHON_BIN_PATH=C:\\Python37\\python.exe
@ -399,6 +401,12 @@ build:rbe_win_py37 --repo_env=PYTHON_LIB_PATH=C:\\Python37\\lib\\site-packages
build:rbe_win_py37 --repo_env=TF_PYTHON_CONFIG_REPO=@org_tensorflow//third_party/toolchains/preconfig/win_1803/py37
build:rbe_win_py37 --python_path=C:\\Python37\\python.exe
build:rbe_win_py38 --config=rbe
build:rbe_win_py38 --repo_env=PYTHON_BIN_PATH=C:\\Python38\\python.exe
build:rbe_win_py38 --repo_env=PYTHON_LIB_PATH=C:\\Python38\\lib\\site-packages
build:rbe_win_py38 --repo_env=TF_PYTHON_CONFIG_REPO=@org_tensorflow//third_party/toolchains/preconfig/win_1803/py38
build:rbe_win_py38 --python_path=C:\\Python38\\python.exe
# These you may need to change for your own GCP project.
build:tensorflow_testing_rbe --project_id=tensorflow-testing
common:tensorflow_testing_rbe_linux --remote_instance_name=projects/tensorflow-testing/instances/default_instance

File diff suppressed because one or more lines are too long

View File

@ -245,4 +245,4 @@ v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc=
### Known Vulnerabilities
For a list of known vulnerabilities and security advisories for TensorFlow,
[click here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/index.md).
[click here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/README.md).

View File

@ -1,11 +1,13 @@
workspace(name = "org_tensorflow")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_file")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("//third_party:repo.bzl", "tf_http_archive")
http_archive(
tf_http_archive(
name = "io_bazel_rules_closure",
sha256 = "5b00383d08dd71f28503736db0500b6fb4dda47489ff5fc6bed42557c07c6ba9",
strip_prefix = "rules_closure-308b05b2419edb5c8ee0471b67a40403df940149",
patch_file = "@org_tensorflow//third_party:rules_closure.patch",
urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz",
"https://github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz", # 2019-06-13
@ -48,38 +50,6 @@ load("//third_party/toolchains/preconfig/generate:workspace.bzl",
remote_config_workspace()
# Apple and Swift rules.
http_archive(
name = "build_bazel_rules_apple",
sha256 = "a045a436b642c70fb0c10ca84ff0fd2dcbd59cc89100d597a61e8374afafb366",
urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.18.0/rules_apple.0.18.0.tar.gz"],
) # https://github.com/bazelbuild/rules_apple/releases
http_archive(
name = "build_bazel_rules_swift",
sha256 = "18cd4df4e410b0439a4935f9ca035bd979993d42372ba79e7f2d4fafe9596ef0",
urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.12.1/rules_swift.0.12.1.tar.gz"],
) # https://github.com/bazelbuild/rules_swift/releases
http_archive(
name = "build_bazel_apple_support",
sha256 = "122ebf7fe7d1c8e938af6aeaee0efe788a3a2449ece5a8d6a428cb18d6f88033",
urls = ["https://github.com/bazelbuild/apple_support/releases/download/0.7.1/apple_support.0.7.1.tar.gz"],
) # https://github.com/bazelbuild/apple_support/releases
http_archive(
name = "bazel_skylib",
sha256 = "1dde365491125a3db70731e25658dfdd3bc5dbdfd11b840b3e987ecf043c7ca0",
urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/0.9.0/bazel-skylib.0.9.0.tar.gz"],
) # https://github.com/bazelbuild/bazel-skylib/releases
http_archive(
name = "com_github_apple_swift_swift_protobuf",
type = "zip",
strip_prefix = "swift-protobuf-1.6.0/",
urls = ["https://github.com/apple/swift-protobuf/archive/1.6.0.zip"],
) # https://github.com/apple/swift-protobuf/releases
http_file(
name = "xctestrunner",
executable = 1,
urls = ["https://github.com/google/xctestrunner/releases/download/0.2.9/ios_test_runner.par"],
) # https://github.com/google/xctestrunner/releases
# Use `swift_rules_dependencies` to fetch the toolchains. With the
# `git_repository` rules above, the following call will skip redefining them.
load("@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies")

View File

@ -1221,7 +1221,7 @@ def is_reduced_optimize_huge_functions_available(environ_cp):
only, as of 2019-11-19). TensorFlow needs this flag to massively reduce
compile times, but until 16.4 is officially released, we can't depend on it.
See also https://groups.google.com/a/tensorflow.org/g/build/c/SsW98Eo7l3o
See also https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
Because it's very annoying to check this manually (to check the MSVC installed
versions, you need to use the registry, and it's not clear if Bazel will be

View File

@ -2,6 +2,7 @@
# TensorFlow is a computational framework, primarily for use in machine
# learning applications.
load("@bazel_skylib//lib:selects.bzl", "selects")
load("//tensorflow:tensorflow.bzl", "VERSION", "tf_cc_shared_object", "tf_custom_op_library_additional_deps_impl", "tf_native_cc_binary")
load(
"//tensorflow/core/platform:build_config.bzl",
@ -478,6 +479,7 @@ bzl_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core/platform:build_config_root_bzl",
"//tensorflow/core/platform:rules_cc_bzl",
"//tensorflow/core/platform/default:cuda_build_defs_bzl",
"//third_party/mkl:build_defs_bzl",
"//third_party/mkl_dnn:build_defs_bzl",

View File

@ -23,10 +23,6 @@ from __future__ import print_function
# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
from tensorflow.python.util.lazy_loader import LazyLoader
contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')
del LazyLoader
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
from tensorflow.python.platform import app # pylint: disable=g-import-not-at-top
app.flags = flags

View File

@ -54,9 +54,10 @@ filegroup(
)
filegroup(
name = "pywrap_eager_hdrs",
name = "pywrap_required_hdrs",
srcs = [
"c_api_internal.h",
"python_api.h",
"tf_status_helper.h",
"tf_status_internal.h",
"tf_tensor_internal.h",
@ -98,6 +99,17 @@ tf_cuda_library(
],
)
filegroup(
name = "pywrap_tf_session_hdrs",
srcs = [
"python_api.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)
cc_library(
name = "tf_attrtype",
hdrs = ["tf_attrtype.h"],
@ -302,6 +314,7 @@ tf_cuda_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:attr_builder",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/platform",
"@com_google_absl//absl/strings",
@ -639,7 +652,7 @@ tf_cuda_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/kernels:ops_testutil",
"//third_party/eigen3",
"@com_google_absl//absl/container:inlined_vector",
],
)

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
@ -549,7 +550,7 @@ TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(TFE_Op* op,
TF_Status* status) {
TFE_ExecuteOpNotification* n = new TFE_ExecuteOpNotification;
n->thread.reset(op->operation.EagerContext()->TFEnv()->StartThread(
n->thread.reset(op->operation.EagerContext().TFEnv()->StartThread(
tensorflow::ThreadOptions(), "ExecuteOpThread",
[op, retvals, num_retvals, n]() {
TFE_Execute(op, retvals, num_retvals, n->status.get());
@ -767,8 +768,9 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
} while (0);
// New server created for new server_def. Unused if updating server_def.
tensorflow::EagerContext* context = ctx->context;
tensorflow::GrpcServer* grpc_server =
dynamic_cast<tensorflow::GrpcServer*>(ctx->context->GetServer());
dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
if (grpc_server == nullptr) {
std::unique_ptr<tensorflow::ServerInterface> new_server;
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
@ -779,12 +781,12 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
}
LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
LOG_AND_RETURN_IF_ERROR(ctx->context->StoreCollectiveOpsServer(
LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
std::move(new_server), grpc_server->worker_env()->device_mgr,
grpc_server->worker_env()->collective_executor_mgr));
} else {
LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
LOG_AND_RETURN_IF_ERROR(ctx->context->StoreCollectiveOpsServer(
LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
/*new_server=*/nullptr, grpc_server->worker_env()->device_mgr,
grpc_server->worker_env()->collective_executor_mgr));
}

View File

@ -1260,11 +1260,10 @@ TEST_F(CApiFunctionTest, GraphToFunctionDefWithPlaceholderAttr) {
NodeWithPlaceholderAttrHelper(func_graph.get(), s.get(), "node3", "v2",
&node3);
TF_Output inputs[] = {};
TF_Output outputs[] = {{node1, 0}, {node2, 0}, {node3, 0}};
func_ = TF_GraphToFunction(
func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
/*opers=*/nullptr, 0, inputs, 3, outputs,
/*opers=*/nullptr, 0, nullptr, 3, outputs,
/*output_names=*/nullptr,
/*opts=*/nullptr, /*description=*/nullptr, s.get());
ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
@ -1300,10 +1299,9 @@ TEST_F(CApiFunctionTest, GraphToFunctionDefWithArgAttr) {
&node);
TF_Output inputs[] = {{node, 0}};
TF_Output outputs[] = {};
func_ = TF_GraphToFunction(
func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
/*opers=*/nullptr, 1, inputs, 0, outputs,
/*opers=*/nullptr, 1, inputs, 0, nullptr,
/*output_names=*/nullptr,
/*opts=*/nullptr, /*description=*/nullptr, s.get());
ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
@ -1603,11 +1601,10 @@ void DefineStatefulFunction(const char* name, TF_Function** func) {
TF_Operation* random =
RandomUniform(shape, TF_FLOAT, func_graph.get(), s.get());
TF_Output inputs[] = {};
TF_Output outputs[] = {{random, 0}};
*func = TF_GraphToFunction(func_graph.get(), name,
/*append_hash_to_fn_name=*/false, -1,
/*opers=*/nullptr, 0, inputs, 1, outputs,
/*opers=*/nullptr, 0, nullptr, 1, outputs,
/*output_names=*/nullptr,
/*opts=*/nullptr, "", s.get());
ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());

View File

@ -17,7 +17,7 @@ limitations under the License.
#include <memory.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/time.h>
#include <time.h>
#include <unistd.h>
#include "tensorflow/c/c_api.h"
@ -58,12 +58,8 @@ int main(int argc, char** argv) {
}
char file_name[100];
struct timeval t;
if (gettimeofday(&t, NULL)) {
perror("gettimeofday failed");
return 1;
}
snprintf(file_name, sizeof(file_name), "test-%d-%ld.txt", getpid(), t.tv_sec);
time_t t = time(NULL);
snprintf(file_name, sizeof(file_name), "test-%d-%ld.txt", getpid(), t);
size_t length = 2 + strlen(path) + strlen(file_name);
char* full_path = malloc(length);

View File

@ -26,8 +26,8 @@ tf_cuda_library(
"c_api.cc",
"c_api_debug.cc",
"c_api_experimental.h",
"c_api_internal.cc",
"c_api_internal.h",
"tensor_handle_interface.h",
],
hdrs = ["c_api.h"],
copts = tf_copts() + tfe_xla_copts(),
@ -89,10 +89,11 @@ tf_cuda_library(
)
filegroup(
name = "pywrap_eager_hdrs",
name = "pywrap_required_hdrs",
srcs = [
"c_api_experimental.h",
"c_api_internal.h",
"tensor_handle_interface.h",
],
visibility = [
"//tensorflow/core:__pkg__",
@ -102,7 +103,10 @@ filegroup(
tf_cuda_library(
name = "c_api_internal",
srcs = ["c_api_experimental.h"],
srcs = [
"c_api_experimental.h",
"tensor_handle_interface.h",
],
hdrs = ["c_api_internal.h"],
visibility = [
"//learning/deepmind/courier:__subpackages__",
@ -125,18 +129,6 @@ tf_cuda_library(
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/distributed_runtime:remote_device",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime/eager:eager_client",
"//tensorflow/core/distributed_runtime/eager:remote_tensor_handle",
"//tensorflow/core/distributed_runtime/rpc:grpc_channel",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
"//tensorflow/core/profiler/lib:profiler_lib",
"//tensorflow/core/profiler/lib:profiler_session",
],
)

View File

@ -31,6 +31,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
@ -43,6 +44,7 @@ limitations under the License.
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/platform.h" // NOLINT
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/device_filters.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@ -81,6 +83,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
@ -93,10 +96,8 @@ using tensorflow::string;
namespace {
const tensorflow::OpDef* GetOpDef(TFE_Op* op, TF_Status* status) {
if (op->inference_ctx) {
return op->inference_ctx->op_def;
}
const tensorflow::OpDef* op_def;
const tensorflow::OpDef* op_def = op->operation.OpDef();
if (op_def) return op_def;
status->status =
tensorflow::OpDefForOp(op->operation.Name().c_str(), &op_def);
return op_def;
@ -265,9 +266,9 @@ tensorflow::Status GetReplacedFromExistingWorkers(
}
tensorflow::Status CreateRemoteContexts(
const std::vector<string>& remote_workers, tensorflow::uint64 context_id,
tensorflow::uint64 context_view_id, int keep_alive_secs,
const tensorflow::ServerDef& server_def,
TFE_Context* ctx, const std::vector<string>& remote_workers,
tensorflow::uint64 context_id, tensorflow::uint64 context_view_id,
int keep_alive_secs, const tensorflow::ServerDef& server_def,
tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
const bool lazy_copy_remote_function_inputs,
const tensorflow::eager::CreateContextRequest& base_request) {
@ -296,7 +297,7 @@ tensorflow::Status CreateRemoteContexts(
continue;
}
tensorflow::eager::CreateContextRequest request(base_request);
tensorflow::eager::CreateContextRequest request;
tensorflow::eager::CreateContextResponse* response =
new tensorflow::eager::CreateContextResponse();
request.set_context_id(context_id);
@ -304,6 +305,21 @@ tensorflow::Status CreateRemoteContexts(
*request.mutable_server_def() = server_def;
request.mutable_server_def()->set_job_name(parsed_name.job);
request.mutable_server_def()->set_task_index(parsed_name.task);
request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
server_def.default_session_config());
std::vector<bool> filtered_device_mask;
ctx->context->FilterDevicesForRemoteWorkers(
remote_worker, base_request.cluster_device_attributes(),
&filtered_device_mask);
DCHECK_EQ(filtered_device_mask.size(),
base_request.cluster_device_attributes_size());
for (int i = 0; i < filtered_device_mask.size(); i++) {
if (filtered_device_mask[i]) {
const auto& da = base_request.cluster_device_attributes(i);
*request.add_cluster_device_attributes() = da;
}
}
request.set_async(async);
request.set_keep_alive_secs(keep_alive_secs);
request.set_lazy_copy_remote_function_inputs(
@ -325,13 +341,34 @@ tensorflow::Status CreateRemoteContexts(
}
tensorflow::Status UpdateRemoteContexts(
const std::vector<string>& remote_workers, tensorflow::uint64 context_id,
TFE_Context* ctx, const std::vector<string>& remote_workers,
const std::vector<string>& added_workers,
const std::vector<string>& removed_workers, tensorflow::uint64 context_id,
tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
tensorflow::eager::EagerClientCache* remote_eager_workers,
const tensorflow::eager::CreateContextRequest& base_request) {
int num_remote_workers = remote_workers.size();
tensorflow::BlockingCounter counter(num_remote_workers);
std::vector<tensorflow::Status> statuses(num_remote_workers);
int cluster_device_count = base_request.cluster_device_attributes_size();
std::unordered_set<string> added_or_removed(added_workers.begin(),
added_workers.end());
std::copy(removed_workers.begin(), removed_workers.end(),
std::inserter(added_or_removed, added_or_removed.end()));
// Whether each device is in the updated (added or removed) workers
std::vector<bool> device_added_or_removed(cluster_device_count);
for (int i = 0; i < base_request.cluster_device_attributes_size(); i++) {
const auto& da = base_request.cluster_device_attributes().at(i);
tensorflow::DeviceNameUtils::ParsedName pn;
tensorflow::DeviceNameUtils::ParseFullName(da.name(), &pn);
string task_name;
tensorflow::DeviceNameUtils::GetTaskName(pn, &task_name);
if (added_or_removed.find(task_name) != added_or_removed.end()) {
device_added_or_removed[i] = true;
}
}
for (int i = 0; i < num_remote_workers; i++) {
const string& remote_worker = remote_workers[i];
tensorflow::DeviceNameUtils::ParsedName parsed_name;
@ -354,17 +391,42 @@ tensorflow::Status UpdateRemoteContexts(
continue;
}
std::vector<bool> filtered_device_mask;
ctx->context->FilterDevicesForRemoteWorkers(
remote_worker, base_request.cluster_device_attributes(),
&filtered_device_mask);
DCHECK_EQ(filtered_device_mask.size(), cluster_device_count);
// If any of the devices that match the device filters are in the set of
// added or removed workers, we must send a complete UpdateContextRequest.
// Otherwise, only send a simple request to increment context view ID.
std::vector<bool> added_or_removed_filtered_devices(cluster_device_count);
std::transform(device_added_or_removed.begin(),
device_added_or_removed.end(), filtered_device_mask.begin(),
added_or_removed_filtered_devices.begin(),
std::logical_and<bool>());
const bool full_update_request =
std::accumulate(added_or_removed_filtered_devices.begin(),
added_or_removed_filtered_devices.end(), false,
std::logical_or<bool>());
tensorflow::eager::UpdateContextRequest request;
auto* response = new tensorflow::eager::UpdateContextResponse();
*request.mutable_server_def() = server_def;
request.mutable_server_def()->set_job_name(parsed_name.job);
request.mutable_server_def()->set_task_index(parsed_name.task);
for (const auto& da : base_request.cluster_device_attributes()) {
*request.add_cluster_device_attributes() = da;
}
request.set_context_id(context_id);
request.set_context_view_id(context_view_id);
if (full_update_request) {
*request.mutable_server_def() = server_def;
request.mutable_server_def()->set_job_name(parsed_name.job);
request.mutable_server_def()->set_task_index(parsed_name.task);
request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
server_def.default_session_config());
for (int i = 0; i < cluster_device_count; i++) {
if (filtered_device_mask[i]) {
const auto& da = base_request.cluster_device_attributes(i);
*request.add_cluster_device_attributes() = da;
}
}
}
eager_client->UpdateContextAsync(
&request, response,
@ -409,6 +471,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
// New server created for new server_def. Unused if updating server_def.
std::unique_ptr<tensorflow::ServerInterface> new_server;
tensorflow::EagerContext* context = ctx->context;
tensorflow::GrpcServer* grpc_server;
if (reset_context) {
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
@ -416,26 +479,25 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
LOG_AND_RETURN_IF_ERROR(
ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
} else {
LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(
ctx->context->GetServer(), worker_name, &curr_remote_workers));
LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name,
&curr_remote_workers));
// No need to check the cast here, since `ListRemoteWorkers` already checks
// if the server is a GRPC server or not.
grpc_server =
dynamic_cast<tensorflow::GrpcServer*>(ctx->context->GetServer());
grpc_server = dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
LOG_AND_RETURN_IF_ERROR(
ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
}
tensorflow::uint64 context_id = ctx->context->GetContextId();
tensorflow::uint64 context_view_id = ctx->context->GetContextViewId();
tensorflow::uint64 context_id = context->GetContextId();
tensorflow::uint64 context_view_id = context->GetContextViewId();
if (reset_context) {
context_id = tensorflow::EagerContext::NewContextId();
context_view_id = 0;
// Make master eager context accessible by local eager service, which might
// receive send tensor requests from remote workers.
LOG_AND_RETURN_IF_ERROR(grpc_server->AddMasterEagerContextToEagerService(
context_id, ctx->context));
LOG_AND_RETURN_IF_ERROR(
grpc_server->AddMasterEagerContextToEagerService(context_id, context));
}
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
@ -464,11 +526,11 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
&new_remote_device_mgr));
remote_device_mgr = new_remote_device_mgr.get();
} else {
ctx->context->ClearCachesAndDefaultExecutor();
context->ClearCachesAndDefaultExecutor();
// TODO(b/143914772): Potential memory leak if rendezvous has pending
// tensors for removed / replaced workers.
remote_device_mgr = ctx->context->GetOwnedRemoteDeviceMgr();
remote_device_mgr = context->GetOwnedRemoteDeviceMgr();
if (remote_device_mgr == nullptr) {
LOG_AND_RETURN_IF_ERROR(tensorflow::errors::InvalidArgument(
"Updating context with an invalid set of remote devices."));
@ -479,8 +541,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
&added_workers, &removed_workers,
&existing_workers);
LOG_AND_RETURN_IF_ERROR(GetReplacedFromExistingWorkers(
&existing_workers, context_id, ctx->context->GetContextViewId(),
server_def, remote_eager_workers.get(), &replaced_workers));
&existing_workers, context_id, context->GetContextViewId(), server_def,
remote_eager_workers.get(), &replaced_workers));
if (VLOG_IS_ON(1)) {
VLOG(1) << "Updating cluster with following changes";
for (const string& w : added_workers) VLOG(1) << " Added worker " << w;
@ -516,7 +578,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
grpc_server->worker_env()->device_mgr->ListDeviceAttributes(
&local_device_attributes);
// This request make sure that we can create Rendevzous properly between
// This request makes sure that we can create Rendezvous properly between
// Local and Remote context.
tensorflow::eager::CreateContextRequest base_request;
for (const auto& da : cluster_device_attributes) {
@ -525,18 +587,14 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
for (const auto& da : local_device_attributes) {
*base_request.add_cluster_device_attributes() = da;
}
base_request.mutable_server_def()
->mutable_default_session_config()
->MergeFrom(server_def.default_session_config());
// Initialize remote eager workers.
// TODO(b/138847548) Create remote eager contexts in async mode by default.
if (reset_context) {
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
remote_workers, context_id, context_view_id, keep_alive_secs,
server_def, remote_eager_workers.get(),
ctx->context->Executor().Async(),
ctx->context->LazyCopyFunctionRemoteInputs(), base_request));
ctx, remote_workers, context_id, context_view_id, keep_alive_secs,
server_def, remote_eager_workers.get(), context->Executor().Async(),
context->LazyCopyFunctionRemoteInputs(), base_request));
} else {
// The master's context_view_id will be incremented by one
// in the UpdateRemoteMaster call later. We want all new workers and
@ -544,10 +602,9 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
// we must set their context_view_id to the existing master's
// context_view_id + 1.
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
added_workers, context_id, context_view_id + 1, keep_alive_secs,
server_def, remote_eager_workers.get(),
ctx->context->Executor().Async(),
ctx->context->LazyCopyFunctionRemoteInputs(), base_request));
ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs,
server_def, remote_eager_workers.get(), context->Executor().Async(),
context->LazyCopyFunctionRemoteInputs(), base_request));
if (!existing_workers.empty()) {
if (VLOG_IS_ON(1)) {
for (const string& w : existing_workers) {
@ -555,8 +612,9 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
}
}
LOG_AND_RETURN_IF_ERROR(UpdateRemoteContexts(
existing_workers, context_id, context_view_id + 1, server_def,
remote_eager_workers.get(), base_request));
ctx, existing_workers, added_workers, removed_workers, context_id,
context_view_id + 1, server_def, remote_eager_workers.get(),
base_request));
}
}
@ -578,12 +636,12 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
tensorflow::eager::CreateClusterFLR(context_id, ctx->context,
tensorflow::eager::CreateClusterFLR(context_id, context,
worker_session.get());
auto remote_mgr = absl::make_unique<tensorflow::eager::RemoteMgr>(
/*is_master=*/true, ctx->context);
/*is_master=*/true, context);
LOG_AND_RETURN_IF_ERROR(ctx->context->InitializeRemoteMaster(
LOG_AND_RETURN_IF_ERROR(context->InitializeRemoteMaster(
std::move(new_server), grpc_server->worker_env(), worker_session,
std::move(remote_eager_workers), std::move(new_remote_device_mgr),
remote_workers, context_id, r, device_mgr, keep_alive_secs, cluster_flr,
@ -601,9 +659,9 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
session_name, &worker_session));
tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
tensorflow::eager::CreateClusterFLR(context_id, ctx->context,
tensorflow::eager::CreateClusterFLR(context_id, context,
worker_session.get());
LOG_AND_RETURN_IF_ERROR(ctx->context->UpdateRemoteMaster(
LOG_AND_RETURN_IF_ERROR(context->UpdateRemoteMaster(
grpc_server->worker_env(), std::move(remote_eager_workers),
added_workers, removed_workers, context_id, r, device_mgr,
keep_alive_secs, cluster_flr));
@ -614,76 +672,6 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
}
#endif // !IS_MOBILE_PLATFORM
tensorflow::Status OpInferSingleInputAttrs(TFE_Op* op,
TFE_TensorHandle* input) {
TFE_OpInferenceContext* ictx = op->inference_ctx.get();
const auto& input_def = ictx->op_def->input_arg(ictx->input_arg_idx++);
if (!input_def.number_attr().empty() || !input_def.type_list_attr().empty()) {
// Some clients that are still setting their input attributes manually are
// adding input list to their op by calling `TFE_OpAddInput` for each of
// its elements instead of calling `TFE_OpAddInputList`. When this happens,
// we cannot detect the end of such list, thus lose track of the input
// arguments in the op definition. To guarantee backward compatibility with
// those clients, disable automatic inference in this case.
op->inference_ctx.reset(nullptr);
return tensorflow::Status::OK();
}
const std::string& type_attr = input_def.type_attr();
if (!type_attr.empty() && ictx->attrs.find(type_attr) == ictx->attrs.end()) {
op->operation.MutableAttrs()->Set(type_attr, input->handle->dtype);
ictx->attrs.insert(type_attr);
}
return tensorflow::Status::OK();
}
void OpInferSingleTypeInputListAttrs(TFE_Op* op,
const tensorflow::OpDef::ArgDef& input_def,
const tensorflow::DataType dtype,
int num_inputs) {
TFE_OpInferenceContext* ictx = op->inference_ctx.get();
if (ictx->attrs.find(input_def.number_attr()) == ictx->attrs.end()) {
op->operation.MutableAttrs()->Set(input_def.number_attr(), num_inputs);
ictx->attrs.insert(input_def.number_attr());
}
if (ictx->attrs.find(input_def.type_attr()) == ictx->attrs.end()) {
op->operation.MutableAttrs()->Set(input_def.type_attr(), dtype);
ictx->attrs.insert(input_def.type_attr());
}
}
void OpInferMixedTypeInputListAttrs(
TFE_Op* op, const tensorflow::OpDef::ArgDef& input_def,
const std::vector<tensorflow::DataType>& dtypes) {
TFE_OpInferenceContext* ictx = op->inference_ctx.get();
if (ictx->attrs.find(input_def.type_list_attr()) == ictx->attrs.end()) {
op->operation.MutableAttrs()->Set(
input_def.type_list_attr(),
tensorflow::gtl::ArraySlice<const tensorflow::DataType>(dtypes.data(),
dtypes.size()));
ictx->attrs.insert(input_def.type_list_attr());
}
}
tensorflow::Status OpInferInputListAttrs(TFE_Op* op, TFE_TensorHandle** inputs,
int num_inputs) {
TFE_OpInferenceContext* ictx = op->inference_ctx.get();
const auto& input_def = ictx->op_def->input_arg(ictx->input_arg_idx++);
if (!input_def.type_list_attr().empty()) {
std::vector<tensorflow::DataType> dtypes(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
dtypes[i] = inputs[i]->handle->dtype;
}
OpInferMixedTypeInputListAttrs(op, input_def, dtypes);
} else if (!input_def.type_attr().empty() &&
!input_def.number_attr().empty()) {
OpInferSingleTypeInputListAttrs(op, input_def, inputs[0]->handle->dtype,
num_inputs);
} else {
return tensorflow::errors::InvalidArgument("Invalid input list definition");
}
return tensorflow::Status::OK();
}
} // namespace
extern "C" {
@ -719,12 +707,14 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr.get());
return new TFE_Context(opts->session_options.options,
opts->device_placement_policy, opts->mirroring_policy,
opts->async, opts->lazy_remote_inputs_copy,
device_mgr.release(),
/*device_mgr_owned*/ true, r,
tensorflow::GetDefaultCustomKernelCreator());
return new TFE_Context{new tensorflow::EagerContext(
opts->session_options.options,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
opts->device_placement_policy),
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
/*device_mgr_owned*/ true, r,
tensorflow::GetDefaultCustomKernelCreator())};
}
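For reference, nothing changes at the public C API surface from this internal rewrite; creating and destroying a context still looks like the following minimal sketch (error checks elided; not part of this commit):

  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);  // now wraps a heap-allocated EagerContext
  TFE_DeleteContextOptions(opts);
  // ... build and execute ops ...
  TFE_DeleteContext(ctx);  // Unref()s the owned EagerContext, per the change below
  TF_DeleteStatus(status);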
TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
@ -735,22 +725,28 @@ TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr);
return new TFE_Context(opts->session_options.options,
opts->device_placement_policy, opts->mirroring_policy,
opts->async, opts->lazy_remote_inputs_copy, device_mgr,
/*device_mgr_owned*/ false, r,
tensorflow::GetDefaultCustomKernelCreator());
return new TFE_Context{new tensorflow::EagerContext(
opts->session_options.options,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
opts->device_placement_policy),
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
opts->async, opts->lazy_remote_inputs_copy, device_mgr,
/*device_mgr_owned*/ false, r,
tensorflow::GetDefaultCustomKernelCreator())};
}
void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; }
void TFE_DeleteContext(TFE_Context* ctx) {
// context->RefCountIsOne() should be true here.
// TODO(iga): Remove EagerContext refcounting.
ctx->context->Unref();
delete ctx;
}
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
TF_DeviceList* list = new TF_DeviceList;
ctx->context->local_device_mgr()->ListDeviceAttributes(&list->response);
if (ctx->context->remote_device_mgr()) {
ctx->context->remote_device_mgr()->ListDeviceAttributes(&list->response);
}
return list;
TF_DeviceList* l = new TF_DeviceList;
ctx->context->ListDevices(&l->response);
return l;
}
void TFE_ContextClearCaches(TFE_Context* ctx) {
@ -773,6 +769,22 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
"Invalid tensorflow.ServerDef protocol buffer");
return;
}
if (server_def.has_cluster_device_filters()) {
const auto& cdf = server_def.cluster_device_filters();
for (const auto& jdf : cdf.jobs()) {
const string& remote_prefix = "/job:" + jdf.name() + "/task:";
for (const auto& tdf : jdf.tasks()) {
const int32_t task_index = tdf.first;
std::vector<string> device_filters(tdf.second.device_filters_size());
for (int i = 0; i < tdf.second.device_filters_size(); i++) {
device_filters[i] = tdf.second.device_filters(i);
}
const string remote_worker = remote_prefix + std::to_string(task_index);
status->status =
ctx->context->SetRemoteDeviceFilters(remote_worker, device_filters);
}
}
}
status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
ctx, /*reset_context=*/true);
#endif // !IS_MOBILE_PLATFORM
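For illustration, a hedged sketch of how a client could populate the per-task device filters that the block above parses. Field names follow the device_filters proto already included in this file; the "worker"/"ps" job names and the MakeServerDef() helper are illustrative assumptions, not part of this commit:

  tensorflow::ServerDef server_def = MakeServerDef();  // hypothetical helper
  auto* job_filters = server_def.mutable_cluster_device_filters()->add_jobs();
  job_filters->set_name("worker");
  // Task 0 of /job:worker will only see devices matching /job:ps.
  (*job_filters->mutable_tasks())[0].add_device_filters("/job:ps");
  std::string serialized = server_def.SerializeAsString();
  TFE_ContextSetServerDef(ctx, /*keep_alive_secs=*/0, serialized.data(),
                          serialized.size(), status);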
@ -797,6 +809,11 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
status->status = tensorflow::errors::InvalidArgument(
"Trying to update a context with invalid context id.");
}
if (server_def.has_cluster_device_filters()) {
LOG(WARNING) << "Device filters can only be specified when initializing "
"the cluster. Any changes in device filters are ignored "
"when updating the server def.";
}
// TODO(haoyuzhang): Check server_def compatibility before the update
status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
ctx, /*reset_context=*/false);
@ -811,8 +828,9 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
"TFE_ContextSetServerDef not supported on mobile");
return false;
#else // !defined(IS_MOBILE_PLATFORM)
tensorflow::EagerContext* context = ctx->context;
tensorflow::GrpcServer* grpc_server =
static_cast<tensorflow::GrpcServer*>(ctx->context->GetServer());
static_cast<tensorflow::GrpcServer*>(context->GetServer());
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
status->status = grpc_server->master_env()->worker_cache->GetEagerClientCache(
@ -831,7 +849,7 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
// Send a rpc request to the worker to check aliveness.
tensorflow::eager::KeepAliveRequest request;
request.set_context_id(ctx->context->GetContextId());
request.set_context_id(context->GetContextId());
tensorflow::eager::KeepAliveResponse response;
tensorflow::Status keep_alive_status;
@ -886,138 +904,212 @@ void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
if (h == nullptr) return;
tensorflow::profiler::TraceMe activity(
"TFE_DeleteTensorHandle", tensorflow::profiler::TraceMeLevel::kInfo);
VLOG(1) << "Deleting tensor handle " << h << " with internal handle "
<< h->handle;
if (h->handle) {
h->handle->Unref();
}
delete h;
}
tensorflow::TensorHandleInterface::~TensorHandleInterface() {
VLOG(1) << "Deleting tensor handle " << this << " with internal handle "
<< handle_;
if (handle_) {
handle_->Unref();
}
}
bool tensorflow::TensorHandleInterface::IsValid(Status* status) const {
if (handle_ == nullptr) {
*status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return false;
}
return true;
}
TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
return static_cast<TF_DataType>(h->handle->dtype);
return h->handle->DataType();
}
TF_DataType tensorflow::TensorHandleInterface::DataType() const {
return static_cast<TF_DataType>(handle_->dtype);
}
int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return -1;
}
return h->handle->NumDims(&status->status);
}
int tensorflow::TensorHandleInterface::NumDims(Status* status) const {
if (!IsValid(status)) {
return -1;
}
int result;
status->status = h->handle->NumDims(&result);
*status = handle_->NumDims(&result);
return result;
}
int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return -1;
}
return h->handle->NumElements(&status->status);
}
int64_t tensorflow::TensorHandleInterface::NumElements(Status* status) const {
if (!IsValid(status)) {
return -1;
}
tensorflow::int64 result;
status->status = h->handle->NumElements(&result);
*status = handle_->NumElements(&result);
return result;
}
int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return -1;
}
return h->handle->Dim(dim_index, &status->status);
}
int64_t tensorflow::TensorHandleInterface::Dim(int dim_index,
Status* status) const {
if (!IsValid(status)) {
return -1;
}
tensorflow::int64 result;
status->status = h->handle->Dim(dim_index, &result);
*status = handle_->Dim(dim_index, &result);
return result;
}
const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
tensorflow::Device* d = h->handle->op_device();
return h->handle->DeviceName(&status->status);
}
const char* tensorflow::TensorHandleInterface::DeviceName(
Status* status) const {
if (!IsValid(status)) {
return nullptr;
}
tensorflow::Device* d = handle_->op_device();
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
: d->name().c_str();
}
const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
tensorflow::Device* d = h->handle->device();
return h->handle->BackingDeviceName(&status->status);
}
const char* tensorflow::TensorHandleInterface::BackingDeviceName(
Status* status) const {
if (!IsValid(status)) {
return nullptr;
}
tensorflow::Device* d = handle_->device();
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
: d->name().c_str();
}
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr || !h->handle->IsValid(&status->status)) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
h->handle->Ref();
return new TFE_TensorHandle{
std::unique_ptr<AbstractTensorHandleInterface>(h->handle->Copy())};
}
return new TFE_TensorHandle(h->handle);
AbstractTensorHandleInterface* tensorflow::TensorHandleInterface::Copy() {
handle_->Ref();
return new TensorHandleInterface(handle_);
}
TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
tensorflow::TensorHandle* handle = h->handle;
return h->handle->Resolve(&status->status);
}
TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
if (!IsValid(status)) {
return nullptr;
}
// TODO(agarwal): move this implementation inside TFE_TensorHandle.
if (handle->IsRemote()) {
if (handle_->IsRemote()) {
const tensorflow::Tensor* t = nullptr;
tensorflow::TensorHandle* h_cpu = nullptr;
status->status = EagerCopyToDevice(
handle, handle->Context(), &handle->Context()->Executor(),
handle->Context()->HostCPU(), false, &h_cpu);
if (!status->status.ok()) {
*status = EagerCopyToDevice(handle_, handle_->Context(),
&handle_->Context()->Executor(),
handle_->Context()->HostCPU(), false, &h_cpu);
if (!status->ok()) {
return nullptr;
}
status->status = h_cpu->Tensor(&t);
if (!status->status.ok()) {
*status = h_cpu->Tensor(&t);
if (!status->ok()) {
h_cpu->Unref();
return nullptr;
}
TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, &status->status);
TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, status);
h_cpu->Unref();
return retval;
} else {
tensorflow::Tensor tensor;
if (IsCPU(handle->device())) {
if (IsCPU(handle_->device())) {
const tensorflow::Tensor* src = nullptr;
status->status = handle->Tensor(&src);
if (!status->status.ok()) return nullptr;
*status = handle_->Tensor(&src);
if (!status->ok()) return nullptr;
tensor = *src;
} else {
tensorflow::EagerContext* ctx = handle->Context();
tensorflow::EagerContext* ctx = handle_->Context();
CHECK_NE(ctx, nullptr);
status->status = h->handle->CopyToDevice(ctx, ctx->HostCPU(), &tensor);
if (!status->status.ok()) return nullptr;
*status = handle_->CopyToDevice(*ctx, ctx->HostCPU(), &tensor);
if (!status->ok()) return nullptr;
}
return tensorflow::TF_TensorFromTensor(tensor, &status->status);
return tensorflow::TF_TensorFromTensor(tensor, status);
}
}
void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr || !h->handle->IsValid(&status->status)) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
tensorflow::TensorHandle* handle = h->handle;
tensorflow::TensorHandle* handle =
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle();
if (handle->IsRemote()) {
status->status = tensorflow::errors::InvalidArgument(
@ -1046,7 +1138,8 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
void (*deallocator)(void* data, size_t len, void* arg),
void* deallocator_arg, TF_Status* status) {
tensorflow::Device* device;
status->status = ctx->context->FindDeviceFromName(device_name, &device);
tensorflow::EagerContext* context = ctx->context;
status->status = context->FindDeviceFromName(device_name, &device);
if (!status->status.ok()) {
deallocator(data, len, deallocator_arg);
return nullptr;
@ -1074,11 +1167,12 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
buf->Unref();
tensorflow::TensorHandle* ret_handle;
status->status = tensorflow::TensorHandle::CreateLocalHandle(
t, device, ctx->context, &ret_handle);
t, device, context, &ret_handle);
if (!status->status.ok()) {
return nullptr;
}
return new TFE_TensorHandle(ret_handle);
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(ret_handle)};
}
// This function will block till the operation that produces `h` has
@ -1086,12 +1180,14 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
// bytes of the memory pointed to by the device pointer returned above.
size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr || !h->handle->IsValid(&status->status)) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return 0;
}
tensorflow::TensorHandle* handle = h->handle;
tensorflow::TensorHandle* handle =
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle();
if (handle->IsRemote()) {
status->status = tensorflow::errors::InvalidArgument(
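As an aside (not in this diff), a minimal sketch of the accessor pair above; per the comment, the size query blocks until the op producing `h` has completed, and both calls assume a valid local (non-remote) handle:

  void* ptr = TFE_TensorHandleDevicePointer(h, status);
  if (TF_GetCode(status) != TF_OK) return;
  size_t nbytes = TFE_TensorHandleDeviceMemorySize(h, status);
  if (TF_GetCode(status) != TF_OK) return;
  printf("device buffer %p holds %zu bytes\n", ptr, nbytes);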
@ -1109,8 +1205,14 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status) {
return NewOrResetOp(ctx, op_or_function_name, nullptr, status,
/* op_to_reset= */ nullptr);
std::unique_ptr<TFE_Op> new_op(
new TFE_Op{tensorflow::EagerOperation(ctx->context)});
status->status =
new_op->operation.Reset(op_or_function_name, nullptr, false, nullptr);
if (!status->status.ok()) {
new_op.reset();
}
return new_op.release();
}
void TFE_DeleteOp(TFE_Op* op) { delete op; }
@ -1121,7 +1223,7 @@ void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
tensorflow::Device* device = (op->operation.Device() == nullptr)
? op->operation.EagerContext()->HostCPU()
? op->operation.EagerContext().HostCPU()
: op->operation.Device();
return device->name().c_str();
}
@ -1135,20 +1237,23 @@ void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
}
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
op->operation.AddInput(input->handle);
if (op->inference_ctx) {
status->status = OpInferSingleInputAttrs(op, input);
}
tensorflow::TensorHandle* h =
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
input->handle.get())
->Handle();
op->operation.AddInput(h);
status->status = op->operation.MaybeInferSingleInputAttrs(h);
}
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
TF_Status* status) {
for (int i = 0; i < num_inputs; ++i) {
op->operation.AddInput(inputs[i]->handle);
}
if (op->inference_ctx) {
status->status = OpInferInputListAttrs(op, inputs, num_inputs);
op->operation.AddInput(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
inputs[i]->handle.get())
->Handle());
}
status->status = op->operation.InferInputListAttrs(num_inputs);
}
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
@ -1381,15 +1486,16 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
VLOG(1) << "Calling TFE_Execute() on op " << op;
absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
VLOG(1) << "Calling TFE_Execute() on op " << op;
status->status = tensorflow::EagerExecute(&op->operation,
handle_retvals.data(), num_retvals);
if (!status->status.ok()) {
return;
}
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = new TFE_TensorHandle(handle_retvals[i]);
retvals[i] = new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle_retvals[i])};
}
}
@ -1399,15 +1505,18 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
TF_Status* status) {
tensorflow::TensorHandle* handle = nullptr;
tensorflow::Device* device;
status->status = ctx->context->FindDeviceFromName(device_name, &device);
tensorflow::EagerContext* context = ctx->context;
status->status = context->FindDeviceFromName(device_name, &device);
if (!status->status.ok()) {
return nullptr;
}
status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context,
&ctx->context->Executor(),
device, false, &handle);
status->status = tensorflow::EagerCopyToDevice(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle(),
context, &context->Executor(), device, false, &handle);
if (status->status.ok()) {
return new TFE_TensorHandle(handle);
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
}
return nullptr;
}
@ -1455,11 +1564,12 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
TF_Status* status) {
status->status = ctx->context->Executor().WaitForAllPendingNodes();
tensorflow::EagerContext* context = ctx->context;
status->status = context->Executor().WaitForAllPendingNodes();
if (!status->status.ok()) return;
tensorflow::mutex_lock ml(*ctx->context->MetadataMu());
status->status = MessageToBuffer(*ctx->context->RunMetadataProto(), buf);
ctx->context->ClearRunMetadata();
tensorflow::mutex_lock ml(*context->MetadataMu());
status->status = MessageToBuffer(*context->RunMetadataProto(), buf);
context->ClearRunMetadata();
}
namespace {

View File

@ -213,7 +213,7 @@ TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo(
TFE_TensorDebugInfo* debug_info);
// Returns the number of dimensions used to represent the tensor on its device.
// The number of dimensions used to reprensent the tensor on device can be
// The number of dimensions used to represent the tensor on device can be
// different from the number returned by TFE_TensorHandleNumDims.
// The return value was current at the time of TFE_TensorDebugInfo creation.
TF_CAPI_EXPORT extern int TFE_TensorDebugInfoOnDeviceNumDims(
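As an aside (not part of the diff), typical use of this debug-info API, assuming a TFE_TensorHandle* h and TF_Status* status:

  TFE_TensorDebugInfo* info = TFE_TensorHandleTensorDebugInfo(h, status);
  if (TF_GetCode(status) == TF_OK) {
    int rank = TFE_TensorDebugInfoOnDeviceNumDims(info);  // on-device rank
    for (int i = 0; i < rank; ++i) {
      // Padded on-device size of dimension i.
      int64_t dim = TFE_TensorDebugInfoOnDeviceDim(info, i);
      printf("dim %d: %lld\n", i, static_cast<long long>(dim));
    }
    TFE_DeleteTensorDebugInfo(info);
  }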

View File

@ -28,19 +28,22 @@ using tensorflow::string;
namespace {
std::vector<int64> TensorShapeAsVector(TFE_TensorHandle* handle,
TF_Status* status) {
std::vector<int64> TensorShapeAsVector(const tensorflow::TensorHandle& handle,
tensorflow::Status* status) {
std::vector<int64> shape;
int rank = TFE_TensorHandleNumDims(handle, status);
if (TF_GetCode(status) != TF_OK) {
int rank = -1;
*status = handle.NumDims(&rank);
if (!status->ok()) {
return shape;
}
shape.reserve(rank);
for (int i = 0; i < rank; ++i) {
shape.push_back(TFE_TensorHandleDim(handle, i, status));
if (TF_GetCode(status) != TF_OK) {
tensorflow::int64 dim;
*status = handle.Dim(i, &dim);
if (!status->ok()) {
return shape;
}
shape.push_back(dim);
}
return shape;
}
@ -51,14 +54,19 @@ extern "C" {
TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
TFE_TensorHandle* h, TF_Status* status) {
return h->handle->TensorDebugInfo(&status->status);
}
TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
Status* status) {
const tensorflow::Tensor* tensor;
status->status = h->handle->Tensor(&tensor);
if (TF_GetCode(status) != TF_OK) {
*status = handle_->Tensor(&tensor);
if (!status->ok()) {
return nullptr;
}
#ifdef TENSORFLOW_EAGER_USE_XLA
tensorflow::Device* device = h->handle->device();
tensorflow::Device* device = handle_->device();
// If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
tensorflow::XlaDevice* xla_device =
@ -67,15 +75,15 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
tensorflow::XlaDevice::PaddedShapeFn shape_fn =
xla_device->metadata().padded_shape_fn();
xla::Shape padded_shape;
status->status = shape_fn(*tensor, &padded_shape);
if (!status->status.ok()) {
*status = shape_fn(*tensor, &padded_shape);
if (!status->ok()) {
return nullptr;
}
if (VLOG_IS_ON(3)) {
std::vector<int64> shape_to_log = TensorShapeAsVector(h, status);
if (!status->status.ok()) {
std::vector<int64> shape_to_log = TensorShapeAsVector(*handle_, status);
if (!status->ok()) {
// Ignore the status here as we are simply logging.
status->status = tensorflow::Status::OK();
*status = tensorflow::Status::OK();
} else {
VLOG(3) << "Fully padded shape of ["
<< absl::StrJoin(shape_to_log, ", ") << "] is "
@ -88,7 +96,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
// Currently, the only case of XlaTensor containing a tuple shape is to
// represent 64 bit ints, doubles, and complex numbers (we don't support
// 64bit complex numbers).
status->status = tensorflow::errors::InvalidArgument(
*status = tensorflow::errors::InvalidArgument(
"XlaTensors should only contain tuples of size 2. Shape: ",
padded_shape.DebugString());
return nullptr;
@ -100,13 +108,13 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
const xla::Shape& shape1 =
xla::ShapeUtil::GetTupleElementShape(padded_shape, 1);
if (shape0.IsTuple() || shape1.IsTuple()) {
status->status = tensorflow::errors::InvalidArgument(
*status = tensorflow::errors::InvalidArgument(
"XlaTensors should not contain nested tuples. Shape: ",
padded_shape.DebugString());
return nullptr;
}
if (!xla::ShapeUtil::Equal(shape0, shape1)) {
status->status = tensorflow::errors::InvalidArgument(
*status = tensorflow::errors::InvalidArgument(
"Subshapes of XlaTensors should be the same. Shape: ",
padded_shape.DebugString());
return nullptr;
@ -131,15 +139,15 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
dev_dims.push_back(padded_shape.dimensions(dim_index));
}
}
status->status = tensorflow::Status::OK();
*status = tensorflow::Status::OK();
return new TFE_TensorDebugInfo(dev_dims);
}
#endif // TENSORFLOW_EAGER_USE_XLA
// If the tensor is not an XLA tensor, the device shape is
// the same as regular tensor shape.
std::vector<int64> dev_dims = TensorShapeAsVector(h, status);
if (TF_GetCode(status) != TF_OK) {
std::vector<int64> dev_dims = TensorShapeAsVector(*handle_, status);
if (!status->ok()) {
return nullptr;
}
return new TFE_TensorDebugInfo(dev_dims);

View File

@ -18,22 +18,23 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/profiler/rpc/client/capture_profile.h"
#include "tensorflow/core/profiler/rpc/profiler_server.h"
using tensorflow::string;
void TFE_OpReset(TFE_Context* ctx, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status,
TFE_Op* op_to_reset) {
void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status) {
if (op_to_reset) {
NewOrResetOp(ctx, op_or_function_name, raw_device_name, status,
op_to_reset);
status->status = op_to_reset->operation.Reset(
op_or_function_name, raw_device_name, false, nullptr);
} else {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"op_to_reset should not be nullptr");
@ -41,7 +42,9 @@ void TFE_OpReset(TFE_Context* ctx, const char* op_or_function_name,
}
void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
op->operation.ConsumeInput(h->handle);
op->operation.ConsumeInput(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle());
}
TFE_Profiler* TFE_NewProfiler() { return new TFE_Profiler(); }
@ -85,14 +88,14 @@ bool TFE_ProfilerClientStartTracing(const char* service_addr,
int num_tracing_attempts,
TF_Status* status) {
tensorflow::Status s =
tensorflow::profiler::client::ValidateHostPortPair(service_addr);
tensorflow::profiler::ValidateHostPortPair(service_addr);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return false;
}
s = tensorflow::profiler::client::StartTracing(
service_addr, logdir, worker_list, include_dataset_ops, duration_ms,
num_tracing_attempts);
s = tensorflow::profiler::Trace(service_addr, logdir, worker_list,
include_dataset_ops, duration_ms,
num_tracing_attempts);
tensorflow::Set_TF_Status_from_Status(status, s);
return s.ok();
}
@ -101,14 +104,14 @@ void TFE_ProfilerClientMonitor(const char* service_addr, int duration_ms,
int monitoring_level, bool display_timestamp,
TF_Buffer* result, TF_Status* status) {
tensorflow::Status s =
tensorflow::profiler::client::ValidateHostPortPair(service_addr);
tensorflow::profiler::ValidateHostPortPair(service_addr);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return;
}
string content;
s = tensorflow::profiler::client::Monitor(
service_addr, duration_ms, monitoring_level, display_timestamp, &content);
s = tensorflow::profiler::Monitor(service_addr, duration_ms, monitoring_level,
display_timestamp, &content);
void* data = tensorflow::port::Malloc(content.length());
content.copy(static_cast<char*>(data), content.length(), 0);
result->data = data;
@ -616,3 +619,16 @@ void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
return new TFE_Executor(&ctx->context->Executor());
}
void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
ctx->context->HostCPU()->parsed_name());
auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
void* data = tensorflow::port::Malloc(str.length());
str.copy(static_cast<char*>(data), str.length(), 0);
buf->data = data;
buf->length = str.length();
buf->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
}
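A short usage sketch for the new call; `TF_DeleteBuffer` frees the payload through the `data_deallocator` installed above:
TF_Buffer* buf = TF_NewBuffer();
TFE_HostAddressSpace(ctx, buf);
// buf->data now holds something like "/job:localhost/replica:0/task:0";
// it is not NUL-terminated, so honor buf->length.
fwrite(buf->data, 1, buf->length, stdout);
TF_DeleteBuffer(buf);  // invokes data_deallocator on buf->data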

View File

@ -29,10 +29,10 @@ extern "C" {
// and set the device name. It's effectively `TFE_OpSetDevice`, but it is faster
// than separately calling it because if the existing op has the same
// `raw_device_name`, it skips parsing and just leaves it as is.
TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Context* ctx,
TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Op* op_to_reset,
const char* op_or_function_name,
const char* raw_device_name,
TF_Status* status, TFE_Op* op_to_reset);
TF_Status* status);
TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,
TF_Status* status);
@ -458,6 +458,11 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
void (*deallocator)(void* data, size_t len, void* arg),
void* deallocator_arg, TF_Status* status);
// Retrieves the address space (i.e. job, replica, task) of the local host and
// saves it in the buffer.
TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
TF_Buffer* buf);
#ifdef __cplusplus
} /* end extern "C" */
#endif

View File

@ -1,66 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/host_info.h"
TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status,
TFE_Op* op_to_reset) {
const char* name = op_or_function_name; // Shorthand
const tensorflow::AttrTypeMap* types;
bool is_function = false;
status->status = tensorflow::AttrTypeMapForOp(name, &types, &is_function);
if (!status->status.ok()) {
return nullptr;
}
if (op_to_reset && op_to_reset->ctx != ctx) {
status->status = tensorflow::errors::Internal(
"Cannot reset a TFE_Op from another TFE_Context");
return nullptr;
}
std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
if (!is_function) {
const tensorflow::OpDef* op_def;
status->status = tensorflow::OpDefForOp(op_or_function_name, &op_def);
if (!status->status.ok()) {
return nullptr;
}
inference_ctx.reset(new TFE_OpInferenceContext(op_def));
} else if (!ctx->context->FindFunctionByName(name)) {
status->status = tensorflow::errors::NotFound(
"'", name,
"' is neither a type of a primitive operation nor a name "
"of a function registered in binary running on ",
tensorflow::port::Hostname(),
". Make sure the operation or function is "
"registered in the binary running in this process.");
return nullptr;
}
if (op_to_reset) {
status->status = op_to_reset->Reset(
name, is_function, types, raw_device_name, std::move(inference_ctx));
return op_to_reset;
}
TFE_Op* new_op =
new TFE_Op(ctx, name, is_function, types, std::move(inference_ctx));
status->status = new_op->operation.SetDeviceName(raw_device_name);
return new_op;
}

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
@ -62,36 +63,10 @@ struct TFE_ContextOptions {
};
struct TFE_Context {
TFE_Context(const tensorflow::SessionOptions& opts,
TFE_ContextDevicePlacementPolicy default_device_placement_policy,
TFE_ContextMirroringPolicy default_mirroring_policy, bool async,
const bool lazy_remote_inputs_copy,
const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned,
tensorflow::Rendezvous* rendezvous,
const tensorflow::CustomKernelCreator* custom_kernel_creator)
: context(new tensorflow::EagerContext(
opts,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
default_device_placement_policy),
static_cast<tensorflow::ContextMirroringPolicy>(
default_mirroring_policy),
async, lazy_remote_inputs_copy, device_mgr, device_mgr_owned,
rendezvous, custom_kernel_creator)) {}
~TFE_Context() {
// TODO(iga): Add a separate API method to shutdown TFE_Context so that we
// don't send RPCs and block in destructor.
context->WaitForAndCloseRemoteContexts();
// context->RefCountIsOne() should be true here.
// TODO(iga): Remove EagerContext refcounting.
context->Unref();
}
tensorflow::EagerContext* context;
};
struct TFE_TensorHandle {
explicit TFE_TensorHandle(tensorflow::TensorHandle* h) : handle(h) {}
static TFE_TensorHandle* CreateLocalHandle(const class tensorflow::Tensor& t,
TF_Status* s) {
tensorflow::TensorHandle* handle;
@ -99,10 +74,11 @@ struct TFE_TensorHandle {
if (!s->status.ok()) {
return nullptr;
}
return new TFE_TensorHandle(handle);
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
}
tensorflow::TensorHandle* handle;
std::unique_ptr<AbstractTensorHandleInterface> handle;
};
struct TFE_TensorDebugInfo {
@ -113,46 +89,10 @@ struct TFE_TensorDebugInfo {
std::vector<tensorflow::int64> dev_dims;
};
struct TFE_OpInferenceContext {
explicit TFE_OpInferenceContext(const tensorflow::OpDef* op_def)
: op_def(op_def) {}
const tensorflow::OpDef* op_def; // op definition from protobuf
int input_arg_idx = 0; // arg definition index for the next input to be added
tensorflow::gtl::FlatSet<std::string> attrs; // attributes inferred so far
};
struct TFE_Op {
TFE_Op(TFE_Context* ctx, const char* op, bool is_function,
const tensorflow::AttrTypeMap* t,
std::unique_ptr<TFE_OpInferenceContext> inference_ctx)
: ctx(ctx),
operation(ctx->context, op, is_function, t),
inference_ctx(std::move(inference_ctx)) {}
void Clear() {
operation.Clear();
inference_ctx.reset();
}
tensorflow::Status Reset(const char* op, bool is_function,
const tensorflow::AttrTypeMap* t,
const char* raw_device_name,
std::unique_ptr<TFE_OpInferenceContext> infer_ctx) {
inference_ctx = std::move(infer_ctx);
return operation.Reset(ctx->context, op, is_function, t, raw_device_name,
nullptr);
}
TFE_Context* ctx;
tensorflow::EagerOperation operation;
std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
};
TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status,
TFE_Op* op_to_reset = nullptr);
struct TFE_Profiler {
explicit TFE_Profiler() { profiler = tensorflow::ProfilerSession::Create(); }

View File

@ -1362,10 +1362,11 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) {
TFE_TensorHandle* inputs[] = {input1, input2};
TFE_OpAddInput(concatOp, dim, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
CHECK(concatOp->inference_ctx);
CHECK(concatOp->operation.OpDef());
TFE_OpAddInput(concatOp, inputs[0], status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_FALSE(concatOp->inference_ctx) << "Inference context is still present";
EXPECT_FALSE(concatOp->operation.OpDef())
<< "Inference context is still present";
TFE_OpAddInput(concatOp, inputs[1], status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

View File

@ -284,7 +284,7 @@ class ForwardAccumulator {
// Temporarily push or pop transient state for this accumulator.
//
// Allows an accumulator which is currently processing an operation to
// temporarily reset its state. Without pushing and poping, accumulators
// temporarily reset its state. Without pushing and popping, accumulators
// ignore operations executed as a direct result of their own jvp
// computations.
void PushState() { call_state_.emplace(nullptr, false); }

View File

@ -0,0 +1,90 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
#define TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
// Abstract interface to a TensorHandle.
//
// A TensorHandle is a management class around a Tensor which may track
// metadata and synchronization.
//
// This allows us to hide concrete implementations of TensorHandle from header
// files. The interface lists the common functionality that must be provided by
// any concrete implementation. However, in cases where the true concrete class
// is needed, a static_cast can be applied.
class AbstractTensorHandleInterface {
public:
virtual ~AbstractTensorHandleInterface() {}
// Check if the handle is in a valid initialized state.
virtual bool IsValid(tensorflow::Status* status) const = 0;
// Returns tensor dtype.
virtual TF_DataType DataType() const = 0;
// Returns number of dimensions.
virtual int NumDims(tensorflow::Status* status) const = 0;
// Returns number of elements across all dimensions.
virtual int64_t NumElements(tensorflow::Status* status) const = 0;
// Returns the size of the specified dimension.
virtual int64_t Dim(int dim_index, tensorflow::Status* status) const = 0;
// Returns the device which created the handle.
virtual const char* DeviceName(tensorflow::Status* status) const = 0;
// Returns the device where the tensor was placed.
virtual const char* BackingDeviceName(tensorflow::Status* status) const = 0;
// Returns a tensor for the handle. If tensor is remote, it will be copied.
virtual TF_Tensor* Resolve(tensorflow::Status* status) = 0;
// Returns debug information about the tensor.
virtual TFE_TensorDebugInfo* TensorDebugInfo(tensorflow::Status* status) = 0;
// Return a copy of the handle.
virtual AbstractTensorHandleInterface* Copy() = 0;
};
namespace tensorflow {
class TensorHandleInterface : public AbstractTensorHandleInterface {
public:
explicit TensorHandleInterface(TensorHandle* h) : handle_(h) {}
~TensorHandleInterface() override;
bool IsValid(Status* status) const override;
TF_DataType DataType() const override;
int NumDims(Status* status) const override;
int64_t NumElements(Status* status) const override;
int64_t Dim(int dim_index, Status* status) const override;
const char* DeviceName(Status* status) const override;
const char* BackingDeviceName(Status* status) const override;
TF_Tensor* Resolve(Status* status) override;
TFE_TensorDebugInfo* TensorDebugInfo(Status* status) override;
AbstractTensorHandleInterface* Copy() override;
// TODO(gjn): This is not a very generic interface, but is needed for specific
// use cases.
TensorHandle* Handle() { return handle_; }
private:
TensorHandle* handle_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
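Since the C structs now hold the abstract interface, code that needs the concrete `tensorflow::TensorHandle` (as `TFE_OpConsumeInput` above does) recovers it with a `down_cast`. A sketch of that unwrapping pattern:
// Valid only while every handle minted by this C API is backed by
// tensorflow::TensorHandleInterface, which holds in this change.
tensorflow::TensorHandle* UnwrapHandle(TFE_TensorHandle* h) {
  return tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
             h->handle.get())
      ->Handle();
}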

View File

@ -18,37 +18,23 @@ cc_library(
],
)
# Core TensorFlow depends on this, this will be included in main library
cc_library(
name = "filesystem_interface_impl",
srcs = ["filesystem_interface.cc"],
hdrs = ["filesystem_interface.h"],
deps = [
":modular_filesystem",
"//tensorflow/c:tf_file_statistics",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_internal",
"//tensorflow/core:ptr_util",
"//tensorflow/core/platform:env",
"//tensorflow/core/platform:logging",
"//tensorflow/core/platform:strcat",
"//tensorflow/core/platform:stringpiece",
],
alwayslink = 1,
)
# Core TensorFlow depends on this, will be included in main library
cc_library(
name = "modular_filesystem",
srcs = ["modular_filesystem.cc"],
srcs = [
"modular_filesystem.cc",
"modular_filesystem_registration.cc",
"modular_filesystem_registration.h",
],
hdrs = ["modular_filesystem.h"],
deps = [
":filesystem_interface",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core:lib",
"//tensorflow/c:tf_status_internal",
"//tensorflow/core:ptr_util",
"//tensorflow/core/platform:env",
"//tensorflow/core/platform:strcat",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:status",
],
)
@ -63,16 +49,12 @@ tf_cc_test(
"notap", # b/139060984, requires implementing modular support for Google filesystem
],
deps = [
":filesystem_interface_impl",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_internal",
":modular_filesystem",
"//tensorflow/core:framework_internal",
"//tensorflow/core/lib/io:path",
"//tensorflow/core/platform:env",
"//tensorflow/core/platform:error",
"//tensorflow/core/platform:stacktrace_handler",
"//tensorflow/core/platform:str_util",
"//tensorflow/core/platform:strcat",
"//tensorflow/core/platform:test",
],
)

View File

@ -1,366 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/util/ptr_util.h"
/// This translation unit is linked in core TensorFlow and provides the
/// functionality needed for plugin registration to check ABI/API compatibility,
/// to ensure required methods are present, to ensure plugins are not allowed to
/// change functionality after being loaded and to register the filesystems
/// provided by a plugin. Consult the header file for more information about
/// how this is achieved.
namespace tensorflow {
namespace {
// Checks if the plugin and core ABI numbers match, filling in `status`.
//
// If the numbers don't match, plugin cannot be loaded.
static bool CheckABIHelper(int pluginABI, int coreABI, StringPiece where,
TF_Status* status) {
if (pluginABI != coreABI) {
TF_SetStatus(
status, TF_FAILED_PRECONDITION,
strings::StrCat("Plugin ABI (", pluginABI, ") for ", where,
" operations doesn't match expected core ABI (",
coreABI, "). Plugin cannot be loaded.")
.c_str());
return false;
}
return true;
}
// Checks if the plugin and core ABI numbers match, for all operations.
//
// If the numbers don't match, plugin cannot be loaded.
//
// Uses the simpler `CheckABIHelper(int, int, StringPiece, TF_Status*)`
static bool CheckABI(
int plugin_filesystem_ops_ABI,
const TF_RandomAccessFileOps* plugin_random_access_file_ops,
int plugin_random_access_file_ops_ABI,
const TF_WritableFileOps* plugin_writable_file_ops,
int plugin_writable_file_ops_ABI,
const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
int plugin_read_only_memory_region_ops_ABI, TF_Status* status) {
if (!CheckABIHelper(plugin_filesystem_ops_ABI, TF_FILESYSTEM_OPS_ABI,
"filesystem", status))
return false;
if (plugin_random_access_file_ops != nullptr &&
!CheckABIHelper(plugin_random_access_file_ops_ABI,
TF_RANDOM_ACCESS_FILE_OPS_ABI, "random access file",
status))
return false;
if (plugin_writable_file_ops != nullptr &&
!CheckABIHelper(plugin_writable_file_ops_ABI, TF_WRITABLE_FILE_OPS_ABI,
"writable file", status))
return false;
if (plugin_read_only_memory_region_ops != nullptr &&
!CheckABIHelper(plugin_read_only_memory_region_ops_ABI,
TF_READ_ONLY_MEMORY_REGION_OPS_ABI,
"read only memory region", status))
return false;
return true;
}
// Checks if the plugin and core API numbers match, logging mismatches.
static void CheckAPIHelper(int plugin_API, int core_API, StringPiece where) {
if (plugin_API != core_API) {
VLOG(0) << "Plugin API (" << plugin_API << ") for " << where
<< " operations doesn't match expected core API (" << core_API
<< "). Plugin will be loaded but functionality might be missing.";
}
}
// Checks if the plugin and core API numbers match, for all operations.
//
// Uses the simpler `CheckAPIHelper(int, int, StringPiece)`.
static void CheckAPI(
int plugin_filesystem_ops_API,
const TF_RandomAccessFileOps* plugin_random_access_file_ops,
int plugin_random_access_file_ops_API,
const TF_WritableFileOps* plugin_writable_file_ops,
int plugin_writable_file_ops_API,
const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
int plugin_read_only_memory_region_ops_API) {
CheckAPIHelper(plugin_filesystem_ops_API, TF_FILESYSTEM_OPS_API,
"filesystem");
if (plugin_random_access_file_ops != nullptr)
CheckAPIHelper(plugin_random_access_file_ops_API,
TF_RANDOM_ACCESS_FILE_OPS_API, "random access file");
if (plugin_writable_file_ops != nullptr)
CheckAPIHelper(plugin_writable_file_ops_API, TF_WRITABLE_FILE_OPS_API,
"writable file");
if (plugin_read_only_memory_region_ops != nullptr)
CheckAPIHelper(plugin_read_only_memory_region_ops_API,
TF_READ_ONLY_MEMORY_REGION_OPS_API,
"read only memory region");
}
// Validates the filesystem operations supplied by the plugin.
static bool ValidateHelper(const TF_FilesystemOps* ops, TF_Status* status) {
if (ops == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without operations");
return false;
}
if (ops->init == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `init` operation");
return false;
}
if (ops->cleanup == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `cleanup` operation");
return false;
}
return true;
}
// Validates the random access file operations supplied by the plugin.
static bool ValidateHelper(const TF_RandomAccessFileOps* ops,
TF_Status* status) {
if (ops == nullptr) {
// We allow filesystems where files can only be written to (from TF code)
return true;
}
if (ops->cleanup == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `cleanup` operation on "
"random access files");
return false;
}
return true;
}
// Validates the writable file operations supplied by the plugin.
static bool ValidateHelper(const TF_WritableFileOps* ops, TF_Status* status) {
if (ops == nullptr) {
// We allow read-only filesystems
return true;
}
if (ops->cleanup == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `cleanup` operation on "
"writable files");
return false;
}
return true;
}
// Validates the read only memory region operations given by the plugin.
static bool ValidateHelper(const TF_ReadOnlyMemoryRegionOps* ops,
TF_Status* status) {
if (ops == nullptr) {
// read only memory region support is always optional
return true;
}
if (ops->cleanup == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `cleanup` operation on "
"read only memory regions");
return false;
}
if (ops->data == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `data` operation on "
"read only memory regions");
return false;
}
if (ops->length == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `length` operation on "
"read only memory regions");
return false;
}
return true;
}
// Validates the operations supplied by the plugin.
//
// Uses the 4 simpler `ValidateHelper(const TF_..., TF_Status*)` to validate
// each individual function table and then checks that the function table for a
// specific file type exists if the plugin offers support for creating that
// type of files.
static bool Validate(
const TF_FilesystemOps* plugin_filesystem_ops,
const TF_RandomAccessFileOps* plugin_random_access_file_ops,
const TF_WritableFileOps* plugin_writable_file_ops,
const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
TF_Status* status) {
if (!ValidateHelper(plugin_filesystem_ops, status)) return false;
if (!ValidateHelper(plugin_random_access_file_ops, status)) return false;
if (!ValidateHelper(plugin_writable_file_ops, status)) return false;
if (!ValidateHelper(plugin_read_only_memory_region_ops, status)) return false;
if (plugin_filesystem_ops->new_random_access_file != nullptr &&
plugin_random_access_file_ops == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Filesystem allows creation of random access files but no "
"operations on them have been supplied.");
return false;
}
if ((plugin_filesystem_ops->new_writable_file != nullptr ||
plugin_filesystem_ops->new_appendable_file != nullptr) &&
plugin_writable_file_ops == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Filesystem allows creation of writable files but no "
"operations on them have been supplied.");
return false;
}
if (plugin_filesystem_ops->new_read_only_memory_region_from_file != nullptr &&
plugin_read_only_memory_region_ops == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Filesystem allows creation of readonly memory regions but no "
"operations on them have been supplied.");
return false;
}
return true;
}
// Copies a function table from plugin memory space to core memory space.
//
// This has three benefits:
// * allows having newer plugins than the current core TensorFlow: the
// additional entries in the plugin's table are just discarded;
// * allows having older plugins than the current core TensorFlow (though
// we are still warning users): the entries that core TensorFlow expects
// but plugins didn't provide will be set to `nullptr` values and core
// TensorFlow will know to not call these on behalf of users;
// * increased security as plugins will not be able to alter function table
// after loading up. Thus, malicious plugins can't alter functionality to
// probe for gadgets inside core TensorFlow. We can even protect the area
// of memory where the copies reside to not allow any more writes to it
// after all copies are created.
template <typename T>
static std::unique_ptr<const T> CopyToCore(const T* plugin_ops,
size_t plugin_size) {
if (plugin_ops == nullptr) return nullptr;
size_t copy_size = sizeof(T);
if (plugin_size < copy_size) {
copy_size = plugin_size;
}
auto core_ops = tensorflow::MakeUnique<T>();
memcpy(const_cast<T*>(core_ops.get()), plugin_ops, copy_size);
return core_ops;
}
} // namespace
} // namespace tensorflow
void RegisterFilesystemPlugin(
int plugin_filesystem_ops_ABI, int plugin_filesystem_ops_API,
size_t plugin_filesystem_ops_size, int plugin_random_access_file_ops_ABI,
int plugin_random_access_file_ops_API,
size_t plugin_random_access_file_ops_size, int plugin_writable_file_ops_ABI,
int plugin_writable_file_ops_API, size_t plugin_writable_file_ops_size,
int plugin_read_only_memory_region_ops_ABI,
int plugin_read_only_memory_region_ops_API,
size_t plugin_read_only_memory_region_ops_size, const char* scheme,
const TF_FilesystemOps* plugin_filesystem_ops,
const TF_RandomAccessFileOps* plugin_random_access_file_ops,
const TF_WritableFileOps* plugin_writable_file_ops,
const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
TF_Status* status) {
if (scheme == nullptr) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"`scheme` argument must not be `nullptr`.");
return;
}
// ABI numbers must match exactly for plugin to be loaded
if (!tensorflow::CheckABI(
plugin_filesystem_ops_ABI, plugin_random_access_file_ops,
plugin_random_access_file_ops_ABI, plugin_writable_file_ops,
plugin_writable_file_ops_ABI, plugin_read_only_memory_region_ops,
plugin_read_only_memory_region_ops_ABI, status)) {
return;
}
// API numbers should match but mismatch doesn't block plugin load
tensorflow::CheckAPI(plugin_filesystem_ops_API, plugin_random_access_file_ops,
plugin_random_access_file_ops_API,
plugin_writable_file_ops, plugin_writable_file_ops_API,
plugin_read_only_memory_region_ops,
plugin_read_only_memory_region_ops_API);
// Plugin can only be loaded if all supplied ops are valid
if (!tensorflow::Validate(plugin_filesystem_ops,
plugin_random_access_file_ops,
plugin_writable_file_ops,
plugin_read_only_memory_region_ops, status)) {
return;
}
// Copy all the function tables to core TensorFlow memory space
auto core_filesystem_ops = tensorflow::CopyToCore<TF_FilesystemOps>(
plugin_filesystem_ops, plugin_filesystem_ops_size);
auto core_random_access_file_ops =
tensorflow::CopyToCore<TF_RandomAccessFileOps>(
plugin_random_access_file_ops, plugin_random_access_file_ops_size);
auto core_writable_file_ops = tensorflow::CopyToCore<TF_WritableFileOps>(
plugin_writable_file_ops, plugin_writable_file_ops_size);
auto core_read_only_memory_region_ops =
tensorflow::CopyToCore<TF_ReadOnlyMemoryRegionOps>(
plugin_read_only_memory_region_ops,
plugin_read_only_memory_region_ops_size);
// Initialize the opaque filesystem structure
auto filesystem = tensorflow::MakeUnique<TF_Filesystem>();
core_filesystem_ops->init(filesystem.get(), status);
if (!status->status.ok()) {
core_filesystem_ops->cleanup(filesystem.get());
return;
}
// Register new filesystem
status->status = tensorflow::Env::Default()->RegisterFileSystem(
scheme, tensorflow::MakeUnique<tensorflow::ModularFileSystem>(
std::move(filesystem), std::move(core_filesystem_ops),
std::move(core_random_access_file_ops),
std::move(core_writable_file_ops),
std::move(core_read_only_memory_region_ops)));
}

View File

@ -56,7 +56,7 @@ extern "C" {
/// Lifetime: The wrapper data structures are owned by core TensorFlow. The data
/// pointed to by the `void*` members is always owned by the plugin. The plugin
/// will provide functions to call to allocate and deallocate this data (see
/// next section) and core TensorFlow ensures to call these at the proper time.
/// next sections) and core TensorFlow ensures to call these at the proper time.
///
/// Plugins will never receive a `TF_*` pointer that is `nullptr`. Core
/// TensorFlow will never touch the `void*` wrapped by these structures, except
@ -529,7 +529,7 @@ typedef struct TF_FilesystemOps {
/// If `statuses` is not null, plugins must fill each element with detailed
/// status for each file, as if calling `path_exists` on each one. Core
/// TensorFlow initializes the `statuses` array and plugins must use
/// `TF_SetStatus` to set each element instead of dirrectly assigning.
/// `TF_SetStatus` to set each element instead of directly assigning.
///
/// DEFAULT IMPLEMENTATION: Checks existence of every file. Needs
/// `path_exists`.
@ -601,6 +601,10 @@ typedef struct TF_FilesystemOps {
///
/// Plugins must not return `nullptr`. Returning empty strings is allowed.
///
/// The allocation and freeing of memory must happen via the functions sent to
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
/// structure in Section 4).
///
/// This function will be called by core TensorFlow to clean up all path
/// arguments for all other methods in the filesystem API.
///
@ -618,6 +622,10 @@ typedef struct TF_FilesystemOps {
/// In case of error, plugins must set `status` to a value different than
/// `TF_OK`, free memory allocated for `entries` and return -1.
///
/// The allocation and freeing of memory must happen via the functions sent to
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
/// structure in Section 4).
///
/// Plugins:
/// * Must set `status` to `TF_OK` if all children were returned.
/// * Must set `status` to `TF_NOT_FOUND` if `path` doesn't point to a
@ -654,6 +662,10 @@ typedef struct TF_FilesystemOps {
/// different than `TF_OK`, free any memory that might have been allocated for
/// `entries` and return -1.
///
/// The allocation and freeing of memory must happen via the functions sent to
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
/// structure in Section 4).
///
/// Plugins:
/// * Must set `status` to `TF_OK` if all matches were returned.
/// * Might use any other error value for `status` to signal other errors.
@ -736,95 +748,132 @@ constexpr size_t TF_FILESYSTEM_OPS_SIZE = sizeof(TF_FilesystemOps);
/// SECTION 4. Plugin registration and initialization
/// ----------------------------------------------------------------------------
///
/// In this section we define two functions:
/// * `TF_InitPlugin`: must be present in the plugin shared object as it will
/// be called by core TensorFlow when the filesystem plugin is loaded;
/// * `RegisterFilesystemPlugin`: it is implemented by core TensorFlow but
/// plugins must call it in their `TF_InitPlugin`, usually using the macro
/// `TF_REGISTER_FILESYSTEM_PLUGIN`.
/// In this section we define the API used by core TensorFlow to initialize a
/// filesystem provided by a plugin. That is, we define the following:
/// * `TF_InitPlugin` function: must be present in the plugin shared object as
/// it will be called by core TensorFlow when the filesystem plugin is
/// loaded;
/// * `TF_FilesystemPluginOps` struct: used to transfer information between
/// plugins and core TensorFlow about the operations provided and metadata;
/// * `TF_FilesystemPluginInfo` struct: similar to the above structure, but
/// collects information about all the file schemes that the plugin provides
/// support for, as well as about the plugin's memory handling routines;
/// * `TF_SetFilesystemVersionMetadata` function: must be called by plugins in
/// their `TF_InitPlugin` to record the versioning information the plugins
/// are compiled against.
///
/// The `TF_InitPlugin` function is used by plugins to set up the data
/// structures that implement this interface, as presented in Section 2.
///
/// The `RegisterFilesystemPlugin` is used by core TensorFlow to check that
/// plugins satisfy the requirements expected by core TensorFlow, as follows:
/// 1. If ABI numbers don't match we don't load the plugin, else we continue.
/// 2. If the API numbers are mismatched, we warn the user and continue
/// loading the plugin.
/// 3. If any required operation is missing, we stop loading the plugin.
///
/// If all these checks succeed, we copy the plugin operations to a different
/// memory location so that core TensorFlow has the guarantee that they won't be
/// changed by plugins at a later time. Finally, we initialize the opaque
/// pointer of `TF_Filesystem` by calling the required `init` function of
/// `TF_FilesystemOps` and if that succeeds we register the filesystem.
/// structures that implement this interface, as presented in Section 2. In
/// order to not have plugin shared objects call back symbols defined in core
/// TensorFlow, `TF_InitPlugin` has a `TF_FilesystemPluginInfo` argument which
/// the plugin must fill (using `TF_SetFilesystemVersionMetadata` for the
/// metadata and setting up all the supported operations and URI schemes).
// Initializes a TensorFlow plugin.
//
// Must be implemented by the plugin DSO. It is called by TensorFlow runtime.
//
// Filesystem plugins can be loaded on demand by users via
// `Env::LoadLibrary` or during TensorFlow's startup if they are on certain
// paths (although this has a security risk if two plugins register for the
// same filesystem and the malicious one loads before the legitimate one -
// but we consider this to be something that users should care about and
// manage themselves). In both of these cases, core TensorFlow looks for
// the `TF_InitPlugin` symbol and calls that function.
//
// A plugin is loaded only if this `status` is `TF_OK` after the call.
TF_CAPI_EXPORT extern void TF_InitPlugin(TF_Status* status);
/// This structure incorporates the operations defined in Section 2 and the
/// metadata defined in section 3, allowing plugins to define different ops
/// for different URI schemes.
///
/// Every URI scheme is of the form "fs" for URIs of form "fs:///path/to/file".
/// For local filesystems (i.e., when the URI is "/path/to/file"), the scheme
/// must be "". The scheme must never be `nullptr`.
///
/// Every plugin fills this in `TF_InitPlugin`, using the allocator passed as
/// argument to allocate memory. After `TF_InitPlugin` finishes, core
/// TensorFlow uses the information present in this to initialize filesystems
/// for the URI schemes that the plugin requests.
///
/// All pointers defined in this structure point to memory allocated by the DSO
/// using an allocator provided by core TensorFlow when calling `TF_InitPlugin`.
///
/// IMPORTANT: To maintain binary compatibility, the layout of this structure
/// must not change! In the unlikely case that a new type of file needs to be
/// supported, add the new ops and metadata at the end of the structure.
typedef struct TF_FilesystemPluginOps {
char* scheme;
int filesystem_ops_abi;
int filesystem_ops_api;
size_t filesystem_ops_size;
TF_FilesystemOps* filesystem_ops;
int random_access_file_ops_abi;
int random_access_file_ops_api;
size_t random_access_file_ops_size;
TF_RandomAccessFileOps* random_access_file_ops;
int writable_file_ops_abi;
int writable_file_ops_api;
size_t writable_file_ops_size;
TF_WritableFileOps* writable_file_ops;
int read_only_memory_region_ops_abi;
int read_only_memory_region_ops_api;
size_t read_only_memory_region_ops_size;
TF_ReadOnlyMemoryRegionOps* read_only_memory_region_ops;
} TF_FilesystemPluginOps;
/// Registers a filesystem plugin so that core TensorFlow can use it.
/// This structure gathers together all the operations provided by the plugin.
///
/// Must be called by the plugin during `TF_InitPlugin`, usually by using the
/// convenience `TF_REGISTER_FILESYSTEM_PLUGIN` macro.
/// Plugins must provide exactly `num_schemes` elements in the `ops` array.
///
/// Arguments (grouped by category):
/// * `..ABI`: ABI compatibility numbers (see Section 3.).
/// * `..API`: API compatibility numbers (see Section 3.).
/// * `..Size`: Sizes of the operation tables (see Section 3.).
/// * `scheme`: The URI scheme that plugin is registering filesystems for.
/// Must be of the form "fs" for URIs of form "fs:///path/to/file". For
/// local filesystems (i.e., when the URI is "/path/to/file"), `scheme`
/// must be "". Must never be `nullptr`.
/// * `..Ops`: The function tables provided by the plugin. Owned by the
/// plugin, but core TensorFlow makes a copy of these.
/// * `status`: The output variable for representing success/failure.
/// Since memory that is allocated by the DSO gets transferred to core
/// TensorFlow, we need to provide a way for the allocation and deallocation to
/// match. This is why this structure also defines `plugin_memory_allocate` and
/// `plugin_memory_free` members.
///
/// Sets `status` to `TF_OK` if plugin was registered and filesystem operations
/// can be invoked from anywhere during TensorFlow's runtime. Any other value of
/// `status` means that plugin failed to load properly and as such the
/// operations it provides cannot be used at all (i.e., core TensorFlow will
/// never run them, returning early with `TF_UNIMPLEMENTED` or similar error
/// values).
TF_CAPI_EXPORT extern void RegisterFilesystemPlugin(
int pluginFilesystemOpsABI, int pluginFilesystemOpsAPI,
size_t pluginFilesystemOpsSize, int pluginRandomAccessFileOpsABI,
int pluginRandomAccessFileOpsAPI, size_t pluginRandomAccessFileOpsSize,
int pluginWritableFileOpsABI, int pluginWritableFileOpsAPI,
size_t pluginWritableFileOpsSize, int pluginReadOnlyMemoryRegionOpsABI,
int pluginReadOnlyMemoryRegionOpsAPI,
size_t pluginReadOnlyMemoryRegionOpsSize, const char* scheme,
const TF_FilesystemOps* pluginFilesystemOps,
const TF_RandomAccessFileOps* pluginRandomAccessFileOps,
const TF_WritableFileOps* pluginWritableFileOps,
const TF_ReadOnlyMemoryRegionOps* pluginReadOnlyMemoryRegionOps,
TF_Status* status);
/// All memory allocated by the plugin that will be owned by core TensorFlow
/// must be allocated using the allocator in this structure. Core TensorFlow
/// will use the deallocator to free this memory once it no longer needs it.
///
/// IMPORTANT: To maintain binary compatibility, the layout of this structure
/// must not change! In the unlikely case that new global operations must be
/// provided, add them at the end of the structure.
typedef struct TF_FilesystemPluginInfo {
size_t num_schemes;
TF_FilesystemPluginOps* ops;
void* (*plugin_memory_allocate)(size_t size);
void (*plugin_memory_free)(void* ptr);
} TF_FilesystemPluginInfo;
/// This macro is just a convenience wrapper around `RegisterFilesystemPlugin`.
/// Plugins should prefer using this macro instead of a direct call.
#define TF_REGISTER_FILESYSTEM_PLUGIN( \
scheme, pluginFilesystemOps, pluginRandomAccessFileOps, \
pluginWritableFileOps, pluginReadOnlyMemoryRegionOps, status) \
RegisterFilesystemPlugin( \
TF_FILESYSTEM_OPS_ABI, TF_FILESYSTEM_OPS_API, TF_FILESYSTEM_OPS_SIZE, \
TF_RANDOM_ACCESS_FILE_OPS_ABI, TF_RANDOM_ACCESS_FILE_OPS_API, \
TF_RANDOM_ACCESS_FILE_OPS_SIZE, TF_WRITABLE_FILE_OPS_ABI, \
TF_WRITABLE_FILE_OPS_API, TF_WRITABLE_FILE_OPS_SIZE, \
TF_READ_ONLY_MEMORY_REGION_OPS_ABI, TF_READ_ONLY_MEMORY_REGION_OPS_API, \
TF_READ_ONLY_MEMORY_REGION_OPS_SIZE, scheme, pluginFilesystemOps, \
pluginRandomAccessFileOps, pluginWritableFileOps, \
pluginReadOnlyMemoryRegionOps, status)
/// Convenience function for setting the versioning metadata.
///
/// The argument is guaranteed to not be `nullptr`.
///
/// We want this to be defined in the plugin's memory space and we guarantee
/// that core TensorFlow will never call this.
static inline void TF_SetFilesystemVersionMetadata(
TF_FilesystemPluginOps* ops) {
ops->filesystem_ops_abi = TF_FILESYSTEM_OPS_ABI;
ops->filesystem_ops_api = TF_FILESYSTEM_OPS_API;
ops->filesystem_ops_size = TF_FILESYSTEM_OPS_SIZE;
ops->random_access_file_ops_abi = TF_RANDOM_ACCESS_FILE_OPS_ABI;
ops->random_access_file_ops_api = TF_RANDOM_ACCESS_FILE_OPS_API;
ops->random_access_file_ops_size = TF_RANDOM_ACCESS_FILE_OPS_SIZE;
ops->writable_file_ops_abi = TF_WRITABLE_FILE_OPS_ABI;
ops->writable_file_ops_api = TF_WRITABLE_FILE_OPS_API;
ops->writable_file_ops_size = TF_WRITABLE_FILE_OPS_SIZE;
ops->read_only_memory_region_ops_abi = TF_READ_ONLY_MEMORY_REGION_OPS_ABI;
ops->read_only_memory_region_ops_api = TF_READ_ONLY_MEMORY_REGION_OPS_API;
ops->read_only_memory_region_ops_size = TF_READ_ONLY_MEMORY_REGION_OPS_SIZE;
}
/// Initializes a TensorFlow plugin.
///
/// Must be implemented by the plugin DSO. It is called by TensorFlow runtime.
///
/// Filesystem plugins can be loaded on demand by users via
/// `Env::LoadLibrary` or during TensorFlow's startup if they are on certain
/// paths (although this has a security risk if two plugins register for the
/// same filesystem and the malicious one loads before the legitimate one -
/// but we consider this to be something that users should care about and
/// manage themselves). In both of these cases, core TensorFlow looks for
/// the `TF_InitPlugin` symbol and calls this function.
///
/// For every filesystem URI scheme that this plugin supports, the plugin must
/// add one `TF_FilesystemPluginOps` entry in `plugin_info->ops` and call
/// `TF_SetFilesystemVersionMetadata` for that entry.
///
/// Plugins must also initialize `plugin_info->plugin_memory_allocate` and
/// `plugin_info->plugin_memory_free` to ensure memory allocated by the plugin
/// is freed in a compatible way.
TF_CAPI_EXPORT extern void TF_InitPlugin(TF_FilesystemPluginInfo* plugin_info);
#ifdef __cplusplus
} // end extern "C"
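To make the new flow concrete, here is a hedged sketch of a single-scheme plugin's `TF_InitPlugin`. The `myfs` scheme and the `MyInit`/`MyCleanup` callbacks are invented for illustration; the struct fields and `TF_SetFilesystemVersionMetadata` come from the header above. A real plugin would also fill the file-type tables (random access, writable, read-only memory region).
#include <stdlib.h>
#include <string.h>

static void MyInit(TF_Filesystem* fs, TF_Status* status) { /* hypothetical */ }
static void MyCleanup(TF_Filesystem* fs) { /* hypothetical */ }

void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
  // Everything below is allocated from the same allocator family that
  // plugin_memory_free can release (malloc/calloc paired with free).
  info->plugin_memory_allocate = malloc;
  info->plugin_memory_free = free;
  info->num_schemes = 1;
  info->ops = (TF_FilesystemPluginOps*)calloc(
      info->num_schemes, sizeof(TF_FilesystemPluginOps));

  TF_FilesystemPluginOps* ops = &info->ops[0];
  ops->scheme = (char*)malloc(sizeof("myfs"));
  memcpy(ops->scheme, "myfs", sizeof("myfs"));
  ops->filesystem_ops = (TF_FilesystemOps*)calloc(1, TF_FILESYSTEM_OPS_SIZE);
  ops->filesystem_ops->init = MyInit;        // required by validation
  ops->filesystem_ops->cleanup = MyCleanup;  // required by validation
  TF_SetFilesystemVersionMetadata(ops);      // stamps the ABI/API/size fields
}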

View File

@ -18,11 +18,10 @@ limitations under the License.
#include <string>
#include <utility>
#include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/file_system_helper.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/util/ptr_util.h"
// TODO(mihaimaruseac): After all filesystems are converted, all calls to
@ -165,16 +164,18 @@ Status ModularFileSystem::GetChildren(const std::string& dir,
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
std::string translated_name = TranslateName(dir);
char** children;
// Note that `children` is allocated by the plugin and freed by core
// TensorFlow, so we need to use `plugin_memory_free_` here.
char** children = nullptr;
const int num_children =
ops_->get_children(filesystem_.get(), translated_name.c_str(), &children,
plugin_status.get());
if (num_children >= 0) {
for (int i = 0; i < num_children; i++) {
result->push_back(std::string(children[i]));
free(children[i]);
plugin_memory_free_(children[i]);
}
free(children);
plugin_memory_free_(children);
}
return StatusFromTF_Status(plugin_status.get());
@ -186,15 +187,17 @@ Status ModularFileSystem::GetMatchingPaths(const std::string& pattern,
return internal::GetMatchingPaths(this, Env::Default(), pattern, result);
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
char** matches;
// Note that `matches` is allocated by the plugin and freed by core
// TensorFlow, so we need to use `plugin_memory_free_` here.
char** matches = nullptr;
const int num_matches = ops_->get_matching_paths(
filesystem_.get(), pattern.c_str(), &matches, plugin_status.get());
if (num_matches >= 0) {
for (int i = 0; i < num_matches; i++) {
result->push_back(std::string(matches[i]));
free(matches[i]);
plugin_memory_free_(matches[i]);
}
free(matches);
plugin_memory_free_(matches);
}
return StatusFromTF_Status(plugin_status.get());
@ -358,7 +361,8 @@ std::string ModularFileSystem::TranslateName(const std::string& name) const {
CHECK(p != nullptr) << "TranslateName(" << name << ") returned nullptr";
std::string ret(p);
free(p);
// Since `p` is allocated by plugin, free it using plugin's method.
plugin_memory_free_(p);
return ret;
}
@ -435,4 +439,8 @@ Status ModularWritableFile::Tell(int64* position) {
return StatusFromTF_Status(plugin_status.get());
}
Status RegisterFilesystemPlugin(const std::string& dso_path) {
return filesystem_registration::RegisterFilesystemPluginImpl(dso_path);
}
} // namespace tensorflow

View File

@ -32,7 +32,7 @@ namespace tensorflow {
// TODO(b/143949615): After all filesystems are converted, this file will be
// moved to core/platform, and this class can become a singleton and replace the
// need for `Env::Default()`. At that time, we might decide to remove the need
// for `Env::Default()` altoghether, but that's a different project, not in
// for `Env::Default()` altogether, but that's a different project, not in
// scope for now. I'm just mentioning this here as that transition will mean
// removal of the registration part from `Env` and adding it here instead: we
// will need tables to hold for each scheme the function tables that implement
@ -46,12 +46,16 @@ class ModularFileSystem final : public FileSystem {
std::unique_ptr<const TF_RandomAccessFileOps> random_access_file_ops,
std::unique_ptr<const TF_WritableFileOps> writable_file_ops,
std::unique_ptr<const TF_ReadOnlyMemoryRegionOps>
read_only_memory_region_ops)
read_only_memory_region_ops,
std::function<void*(size_t)> plugin_memory_allocate,
std::function<void(void*)> plugin_memory_free)
: filesystem_(std::move(filesystem)),
ops_(std::move(filesystem_ops)),
random_access_file_ops_(std::move(random_access_file_ops)),
writable_file_ops_(std::move(writable_file_ops)),
read_only_memory_region_ops_(std::move(read_only_memory_region_ops)) {}
read_only_memory_region_ops_(std::move(read_only_memory_region_ops)),
plugin_memory_allocate_(std::move(plugin_memory_allocate)),
plugin_memory_free_(std::move(plugin_memory_free)) {}
~ModularFileSystem() override { ops_->cleanup(filesystem_.get()); }
@ -93,6 +97,8 @@ class ModularFileSystem final : public FileSystem {
std::unique_ptr<const TF_WritableFileOps> writable_file_ops_;
std::unique_ptr<const TF_ReadOnlyMemoryRegionOps>
read_only_memory_region_ops_;
std::function<void*(size_t)> plugin_memory_allocate_;
std::function<void(void*)> plugin_memory_free_;
TF_DISALLOW_COPY_AND_ASSIGN(ModularFileSystem);
};
@ -156,6 +162,9 @@ class ModularReadOnlyMemoryRegion final : public ReadOnlyMemoryRegion {
TF_DISALLOW_COPY_AND_ASSIGN(ModularReadOnlyMemoryRegion);
};
// Registers a filesystem plugin so that core TensorFlow can use it.
Status RegisterFilesystemPlugin(const std::string& dso_path);
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_H_

View File

@ -0,0 +1,346 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
// Checks that all schemes provided by a plugin are valid.
// TODO(mihaimaruseac): More validation could be done here, based on supported
// charset, maximum length, etc. Punting it for later.
static Status ValidateScheme(const char* scheme) {
if (scheme == nullptr)
return errors::InvalidArgument(
"Attempted to register filesystem with `nullptr` URI scheme");
return Status::OK();
}
// Checks if the plugin and core ABI numbers match.
//
// If the numbers don't match, plugin cannot be loaded.
static Status CheckABI(int pluginABI, int coreABI, StringPiece where) {
if (pluginABI != coreABI)
return errors::FailedPrecondition(
strings::StrCat("Plugin ABI (", pluginABI, ") for ", where,
" operations doesn't match expected core ABI (",
coreABI, "). Plugin cannot be loaded."));
return Status::OK();
}
// Checks if the plugin and core ABI numbers match, for all operations.
//
// If the numbers don't match, plugin cannot be loaded.
//
// Uses the simpler `CheckABI(int, int, StringPiece)`.
static Status ValidateABI(const TF_FilesystemPluginOps* ops) {
TF_RETURN_IF_ERROR(
CheckABI(ops->filesystem_ops_abi, TF_FILESYSTEM_OPS_ABI, "filesystem"));
if (ops->random_access_file_ops != nullptr)
TF_RETURN_IF_ERROR(CheckABI(ops->random_access_file_ops_abi,
TF_RANDOM_ACCESS_FILE_OPS_ABI,
"random access file"));
if (ops->writable_file_ops != nullptr)
TF_RETURN_IF_ERROR(CheckABI(ops->writable_file_ops_abi,
TF_WRITABLE_FILE_OPS_ABI, "writable file"));
if (ops->read_only_memory_region_ops != nullptr)
TF_RETURN_IF_ERROR(CheckABI(ops->read_only_memory_region_ops_abi,
TF_READ_ONLY_MEMORY_REGION_OPS_ABI,
"read only memory region"));
return Status::OK();
}
// Checks if the plugin and core API numbers match, logging mismatches.
static void CheckAPI(int plugin_API, int core_API, StringPiece where) {
if (plugin_API != core_API) {
VLOG(0) << "Plugin API (" << plugin_API << ") for " << where
<< " operations doesn't match expected core API (" << core_API
<< "). Plugin will be loaded but functionality might be missing.";
}
}
// Checks if the plugin and core API numbers match, for all operations.
//
// Uses the simpler `CheckAPI(int, int, StringPiece)`.
static void ValidateAPI(const TF_FilesystemPluginOps* ops) {
CheckAPI(ops->filesystem_ops_api, TF_FILESYSTEM_OPS_API, "filesystem");
if (ops->random_access_file_ops != nullptr)
CheckAPI(ops->random_access_file_ops_api, TF_RANDOM_ACCESS_FILE_OPS_API,
"random access file");
if (ops->writable_file_ops != nullptr)
CheckAPI(ops->writable_file_ops_api, TF_WRITABLE_FILE_OPS_API,
"writable file");
if (ops->read_only_memory_region_ops != nullptr)
CheckAPI(ops->read_only_memory_region_ops_api,
TF_READ_ONLY_MEMORY_REGION_OPS_API, "read only memory region");
}
// Validates the filesystem operations supplied by the plugin.
static Status ValidateHelper(const TF_FilesystemOps* ops) {
if (ops == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without operations");
if (ops->init == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `init` operation");
if (ops->cleanup == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `cleanup` operation");
return Status::OK();
}
// Validates the random access file operations supplied by the plugin.
static Status ValidateHelper(const TF_RandomAccessFileOps* ops) {
if (ops == nullptr) {
// We allow filesystems where files can only be written to (from TF code)
return Status::OK();
}
if (ops->cleanup == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `cleanup` operation on random "
"access files");
return Status::OK();
}
// Validates the writable file operations supplied by the plugin.
static Status ValidateHelper(const TF_WritableFileOps* ops) {
if (ops == nullptr) {
// We allow read-only filesystems
return Status::OK();
}
if (ops->cleanup == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `cleanup` operation on writable "
"files");
return Status::OK();
}
// Validates the read only memory region operations given by the plugin.
static Status ValidateHelper(const TF_ReadOnlyMemoryRegionOps* ops) {
if (ops == nullptr) {
// read only memory region support is always optional
return Status::OK();
}
if (ops->cleanup == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `cleanup` operation on read "
"only memory regions");
if (ops->data == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `data` operation on read only "
"memory regions");
if (ops->length == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `length` operation on read only "
"memory regions");
return Status::OK();
}
// Validates the operations supplied by the plugin.
//
// Uses the 4 simpler `ValidateHelper(const TF_...*)` to validate each
// individual function table and then checks that the function table for a
// specific file type exists if the plugin offers support for creating that
// type of files.
static Status ValidateOperations(const TF_FilesystemPluginOps* ops) {
TF_RETURN_IF_ERROR(ValidateHelper(ops->filesystem_ops));
TF_RETURN_IF_ERROR(ValidateHelper(ops->random_access_file_ops));
TF_RETURN_IF_ERROR(ValidateHelper(ops->writable_file_ops));
TF_RETURN_IF_ERROR(ValidateHelper(ops->read_only_memory_region_ops));
if (ops->filesystem_ops->new_random_access_file != nullptr &&
ops->random_access_file_ops == nullptr)
return errors::FailedPrecondition(
"Filesystem allows creation of random access files but no "
"operations on them have been supplied.");
if ((ops->filesystem_ops->new_writable_file != nullptr ||
ops->filesystem_ops->new_appendable_file != nullptr) &&
ops->writable_file_ops == nullptr)
return errors::FailedPrecondition(
"Filesystem allows creation of writable files but no "
"operations on them have been supplied.");
if (ops->filesystem_ops->new_read_only_memory_region_from_file != nullptr &&
ops->read_only_memory_region_ops == nullptr)
return errors::FailedPrecondition(
"Filesystem allows creation of readonly memory regions but no "
"operations on them have been supplied.");
return Status::OK();
}
// Copies a function table from plugin memory space to core memory space.
//
// This has three benefits:
// * allows having newer plugins than the current core TensorFlow: the
// additional entries in the plugin's table are just discarded;
// * allows having older plugins than the current core TensorFlow (though
// we are still warning users): the entries that core TensorFlow expects
// but plugins didn't provide will be set to `nullptr` values and core
// TensorFlow will know to not call these on behalf of users;
// * increased security as plugins will not be able to alter function table
// after loading up. Thus, malicious plugins can't alter functionality to
// probe for gadgets inside core TensorFlow. We can even protect the area
// of memory where the copies reside to not allow any more writes to it
// after all copies are created.
template <typename T>
static std::unique_ptr<const T> CopyToCore(const T* plugin_ops,
size_t plugin_size) {
if (plugin_ops == nullptr) return nullptr;
size_t copy_size = std::min(plugin_size, sizeof(T));
auto core_ops = tensorflow::MakeUnique<T>();
memset(core_ops.get(), 0, sizeof(T));
memcpy(core_ops.get(), plugin_ops, copy_size);
return core_ops;
}
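A standalone illustration of the truncation and zero-fill semantics, with a made-up two-entry table: an older plugin that only knows the first entry ends up with a `nullptr` second entry instead of garbage.
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <memory>

struct OpsV2 {         // pretend core now expects two entries
  void (*init)();
  void (*extra)();     // added after the plugin was built
};

template <typename T>
std::unique_ptr<const T> CopyToCore(const T* plugin_ops, size_t plugin_size) {
  if (plugin_ops == nullptr) return nullptr;
  size_t copy_size = std::min(plugin_size, sizeof(T));
  auto core_ops = std::make_unique<T>();
  memset(core_ops.get(), 0, sizeof(T));
  memcpy(core_ops.get(), plugin_ops, copy_size);
  return core_ops;
}

static void PluginInit() {}

int main() {
  OpsV2 plugin_table = {PluginInit, nullptr};
  // The plugin reports the size of its older, one-entry struct.
  auto core = CopyToCore(&plugin_table, sizeof(void (*)()));
  std::printf("init copied: %d, extra zeroed: %d\n",
              core->init == PluginInit, core->extra == nullptr);
}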
// Registers one filesystem from the plugin.
//
// Must be called only with `index` a valid index in `info->ops`.
static Status RegisterFileSystem(const TF_FilesystemPluginInfo* info,
int index) {
// Step 1: Copy all the function tables to core TensorFlow memory space
auto core_filesystem_ops = CopyToCore<TF_FilesystemOps>(
info->ops[index].filesystem_ops, info->ops[index].filesystem_ops_size);
auto core_random_access_file_ops = CopyToCore<TF_RandomAccessFileOps>(
info->ops[index].random_access_file_ops,
info->ops[index].random_access_file_ops_size);
auto core_writable_file_ops =
CopyToCore<TF_WritableFileOps>(info->ops[index].writable_file_ops,
info->ops[index].writable_file_ops_size);
auto core_read_only_memory_region_ops =
CopyToCore<TF_ReadOnlyMemoryRegionOps>(
info->ops[index].read_only_memory_region_ops,
info->ops[index].read_only_memory_region_ops_size);
// Step 2: Initialize the opaque filesystem structure
auto filesystem = tensorflow::MakeUnique<TF_Filesystem>();
TF_Status* c_status = TF_NewStatus();
Status status = Status::OK();
core_filesystem_ops->init(filesystem.get(), c_status);
status = Status(c_status->status);
TF_DeleteStatus(c_status);
if (!status.ok()) return status;
// Step 3: Actual registration
return Env::Default()->RegisterFileSystem(
info->ops[index].scheme,
tensorflow::MakeUnique<tensorflow::ModularFileSystem>(
std::move(filesystem), std::move(core_filesystem_ops),
std::move(core_random_access_file_ops),
std::move(core_writable_file_ops),
std::move(core_read_only_memory_region_ops),
info->plugin_memory_allocate, info->plugin_memory_free));
}
// Registers filesystem at `index`, if plugin is providing valid information.
//
// Extracted to a separate function so that pointers inside `info` are freed
// by the caller regardless of whether validation/registration failed or not.
//
// Must be called only with `index` a valid index in `info->ops`.
static Status ValidateAndRegisterFilesystems(
const TF_FilesystemPluginInfo* info, int index) {
TF_RETURN_IF_ERROR(ValidateScheme(info->ops[index].scheme));
TF_RETURN_IF_ERROR(ValidateABI(&info->ops[index]));
ValidateAPI(&info->ops[index]);  // we only warn on API number mismatch
TF_RETURN_IF_ERROR(ValidateOperations(&info->ops[index]));
TF_RETURN_IF_ERROR(RegisterFileSystem(info, index));
return Status::OK();
}
// Ensures that the plugin provides the required memory management operations.
static Status ValidatePluginMemoryRoutines(
const TF_FilesystemPluginInfo* info) {
if (info->plugin_memory_allocate == nullptr)
return errors::FailedPrecondition(
"Cannot load filesystem plugin which does not provide "
"`plugin_memory_allocate`");
if (info->plugin_memory_free == nullptr)
return errors::FailedPrecondition(
"Cannot load filesystem plugin which does not provide "
"`plugin_memory_free`");
return Status::OK();
}
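A plugin satisfies this check by exporting a matched allocator/deallocator
pair; the POSIX plugin later in this change does exactly that:

static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); }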
namespace filesystem_registration {
Status RegisterFilesystemPluginImpl(const std::string& dso_path) {
// Step 1: Load plugin
Env* env = Env::Default();
void* dso_handle;
TF_RETURN_IF_ERROR(env->LoadLibrary(dso_path.c_str(), &dso_handle));
// Step 2: Load symbol for `TF_InitPlugin`
void* dso_symbol;
TF_RETURN_IF_ERROR(
env->GetSymbolFromLibrary(dso_handle, "TF_InitPlugin", &dso_symbol));
// Step 3: Call `TF_InitPlugin`
TF_FilesystemPluginInfo info;
memset(&info, 0, sizeof(info));
auto TF_InitPlugin =
reinterpret_cast<int (*)(TF_FilesystemPluginInfo*)>(dso_symbol);
TF_InitPlugin(&info);
// Step 4: Ensure the plugin provides the memory management functions.
TF_RETURN_IF_ERROR(ValidatePluginMemoryRoutines(&info));
// Step 5: Validate and register all filesystems
// Try to register as many filesystems as possible.
// Free memory once we no longer need it.
Status status;
for (int i = 0; i < info.num_schemes; i++) {
status.Update(ValidateAndRegisterFilesystems(&info, i));
info.plugin_memory_free(info.ops[i].scheme);
info.plugin_memory_free(info.ops[i].filesystem_ops);
info.plugin_memory_free(info.ops[i].random_access_file_ops);
info.plugin_memory_free(info.ops[i].writable_file_ops);
info.plugin_memory_free(info.ops[i].read_only_memory_region_ops);
}
info.plugin_memory_free(info.ops);
return status;
}
} // namespace filesystem_registration
} // namespace tensorflow
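For orientation, a minimal sketch of driving this entry point, assuming only
the header declared below; `LoadMyFilesystemPlugin` and the DSO path are
hypothetical:

#include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h"

tensorflow::Status LoadMyFilesystemPlugin() {
  // Placeholder path; point this at a real plugin shared object.
  return tensorflow::filesystem_registration::RegisterFilesystemPluginImpl(
      "/usr/lib/tensorflow-plugins/libmy_filesystem.so");
}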

View File

@ -0,0 +1,28 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
namespace filesystem_registration {
Status RegisterFilesystemPluginImpl(const std::string& dso_path);
} // namespace filesystem_registration
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_

View File

@ -24,8 +24,6 @@ limitations under the License.
#include <sys/stat.h>
#include <unistd.h>
#include <vector>
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem_helper.h"
#include "tensorflow/c/tf_status.h"
@ -33,6 +31,9 @@ limitations under the License.
// Implementation of a filesystem for POSIX environments.
// This filesystem will support `file://` and empty (local) URI schemes.
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); }
// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
@ -45,7 +46,9 @@ typedef struct PosixFile {
static void Cleanup(TF_RandomAccessFile* file) {
auto posix_file = static_cast<PosixFile*>(file->plugin_file);
close(posix_file->fd);
free(const_cast<char*>(posix_file->filename));
// This would be safe to free with `free` directly, since `filename` is just
// an opaque character buffer. However, it is better to use the plugin's
// allocator consistently everywhere.
plugin_memory_free(const_cast<char*>(posix_file->filename));
delete posix_file;
}
@ -100,7 +103,7 @@ typedef struct PosixFile {
static void Cleanup(TF_WritableFile* file) {
auto posix_file = static_cast<PosixFile*>(file->plugin_file);
free(const_cast<char*>(posix_file->filename));
plugin_memory_free(const_cast<char*>(posix_file->filename));
delete posix_file;
}
@ -383,12 +386,13 @@ static int GetChildren(const TF_Filesystem* filesystem, const char* path,
if (num_entries < 0) {
TF_SetStatusFromIOError(status, errno, path);
} else {
*entries = static_cast<char**>(calloc(num_entries, sizeof((*entries)[0])));
*entries = static_cast<char**>(
plugin_memory_allocate(num_entries * sizeof((*entries)[0])));
for (int i = 0; i < num_entries; i++) {
(*entries)[i] = strdup(dir_entries[i]->d_name);
free(dir_entries[i]);
plugin_memory_free(dir_entries[i]);
}
free(dir_entries);
plugin_memory_free(dir_entries);
}
return num_entries;
@ -396,48 +400,59 @@ static int GetChildren(const TF_Filesystem* filesystem, const char* path,
} // namespace tf_posix_filesystem
void TF_InitPlugin(TF_Status* status) {
TF_RandomAccessFileOps random_access_file_ops = {
tf_random_access_file::Cleanup,
tf_random_access_file::Read,
};
TF_WritableFileOps writable_file_ops = {
tf_writable_file::Cleanup, tf_writable_file::Append,
tf_writable_file::Tell, tf_writable_file::Flush,
tf_writable_file::Sync, tf_writable_file::Close,
};
TF_ReadOnlyMemoryRegionOps read_only_memory_region_ops = {
tf_read_only_memory_region::Cleanup,
tf_read_only_memory_region::Data,
tf_read_only_memory_region::Length,
};
TF_FilesystemOps filesystem_ops = {
tf_posix_filesystem::Init,
tf_posix_filesystem::Cleanup,
tf_posix_filesystem::NewRandomAccessFile,
tf_posix_filesystem::NewWritableFile,
tf_posix_filesystem::NewAppendableFile,
tf_posix_filesystem::NewReadOnlyMemoryRegionFromFile,
tf_posix_filesystem::CreateDir,
/*recursively_create_dir=*/nullptr,
tf_posix_filesystem::DeleteFile,
tf_posix_filesystem::DeleteDir,
/*delete_recursively=*/nullptr,
tf_posix_filesystem::RenameFile,
tf_posix_filesystem::CopyFile,
tf_posix_filesystem::PathExists,
/*paths_exist=*/nullptr,
tf_posix_filesystem::Stat,
/*is_directory=*/nullptr,
/*get_file_size=*/nullptr,
/*translate_name=*/nullptr,
tf_posix_filesystem::GetChildren,
/*get_matching_paths=*/nullptr,
/*flush_caches=*/nullptr,
};
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
const char* uri) {
TF_SetFilesystemVersionMetadata(ops);
ops->scheme = strdup(uri);
for (const char* scheme : {"", "file"})
TF_REGISTER_FILESYSTEM_PLUGIN(scheme, &filesystem_ops,
&random_access_file_ops, &writable_file_ops,
&read_only_memory_region_ops, status);
ops->random_access_file_ops = static_cast<TF_RandomAccessFileOps*>(
plugin_memory_allocate(TF_RANDOM_ACCESS_FILE_OPS_SIZE));
ops->random_access_file_ops->cleanup = tf_random_access_file::Cleanup;
ops->random_access_file_ops->read = tf_random_access_file::Read;
ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
ops->writable_file_ops->append = tf_writable_file::Append;
ops->writable_file_ops->tell = tf_writable_file::Tell;
ops->writable_file_ops->flush = tf_writable_file::Flush;
ops->writable_file_ops->sync = tf_writable_file::Sync;
ops->writable_file_ops->close = tf_writable_file::Close;
ops->read_only_memory_region_ops = static_cast<TF_ReadOnlyMemoryRegionOps*>(
plugin_memory_allocate(TF_READ_ONLY_MEMORY_REGION_OPS_SIZE));
ops->read_only_memory_region_ops->cleanup =
tf_read_only_memory_region::Cleanup;
ops->read_only_memory_region_ops->data = tf_read_only_memory_region::Data;
ops->read_only_memory_region_ops->length = tf_read_only_memory_region::Length;
ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
ops->filesystem_ops->init = tf_posix_filesystem::Init;
ops->filesystem_ops->cleanup = tf_posix_filesystem::Cleanup;
ops->filesystem_ops->new_random_access_file =
tf_posix_filesystem::NewRandomAccessFile;
ops->filesystem_ops->new_writable_file = tf_posix_filesystem::NewWritableFile;
ops->filesystem_ops->new_appendable_file =
tf_posix_filesystem::NewAppendableFile;
ops->filesystem_ops->new_read_only_memory_region_from_file =
tf_posix_filesystem::NewReadOnlyMemoryRegionFromFile;
ops->filesystem_ops->create_dir = tf_posix_filesystem::CreateDir;
ops->filesystem_ops->delete_file = tf_posix_filesystem::DeleteFile;
ops->filesystem_ops->delete_dir = tf_posix_filesystem::DeleteDir;
ops->filesystem_ops->rename_file = tf_posix_filesystem::RenameFile;
ops->filesystem_ops->copy_file = tf_posix_filesystem::CopyFile;
ops->filesystem_ops->path_exists = tf_posix_filesystem::PathExists;
ops->filesystem_ops->stat = tf_posix_filesystem::Stat;
ops->filesystem_ops->get_children = tf_posix_filesystem::GetChildren;
}
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
info->plugin_memory_allocate = plugin_memory_allocate;
info->plugin_memory_free = plugin_memory_free;
info->num_schemes = 2;
info->ops = static_cast<TF_FilesystemPluginOps*>(
plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
ProvideFilesystemSupportFor(&info->ops[0], "");
ProvideFilesystemSupportFor(&info->ops[1], "file");
}
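The helper above scales to additional URI schemes. A hypothetical variant
(not part of this change) that also serves a `local://` scheme would only
grow the table:

void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
  info->plugin_memory_allocate = plugin_memory_allocate;
  info->plugin_memory_free = plugin_memory_free;
  info->num_schemes = 3;  // hypothetical third scheme
  info->ops = static_cast<TF_FilesystemPluginOps*>(
      plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
  ProvideFilesystemSupportFor(&info->ops[0], "");
  ProvideFilesystemSupportFor(&info->ops[1], "file");
  ProvideFilesystemSupportFor(&info->ops[2], "local");  // hypothetical
}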

View File

@ -44,7 +44,7 @@ int TransferFileContents(const char* src, const char* dst, mode_t mode,
}
// Both files have been opened, do the transfer.
// Since errno would be overriden by `close` below, save it here.
// Since errno would be overridden by `close` below, save it here.
int error_code = 0;
if (CopyFileContents(dst_fd, src_fd, size) < 0) error_code = errno;

View File

@ -0,0 +1,36 @@
# Experimental Windows filesystem plugin.
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object")
package(
licenses = ["notice"], # Apache 2.0
)
# Filesystem implementation for Windows environments
tf_cc_shared_object(
name = "windows_filesystem.dll",
framework_so = [],
linkstatic = False,
tags = [
"manual",
"nobuilder",
"notap",
],
visibility = ["//visibility:public"],
deps = [":windows_filesystem_impl"],
)
# The real implementation of the filesystem.
cc_library(
name = "windows_filesystem_impl",
srcs = ["windows_filesystem.cc"],
copts = get_win_copts(),
tags = [
"manual",
"nobuilder",
"notap",
],
deps = [
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
],
)

View File

@ -0,0 +1,73 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdlib.h>
#include <string.h>
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h"
// Implementation of a filesystem for Windows environments.
// This filesystem will support `file://` and empty (local) URI schemes.
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); }
// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
// TODO(mihaimaruseac): Implement later
} // namespace tf_random_access_file
// SECTION 2. Implementation for `TF_WritableFile`
// ----------------------------------------------------------------------------
namespace tf_writable_file {
// TODO(mihaimaruseac): Implement later
} // namespace tf_writable_file
// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion`
// ----------------------------------------------------------------------------
namespace tf_read_only_memory_region {
// TODO(mihaimaruseac): Implement later
} // namespace tf_read_only_memory_region
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
// ----------------------------------------------------------------------------
namespace tf_windows_filesystem {
// TODO(mihaimaruseac): Implement later
} // namespace tf_windows_filesystem
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
const char* uri) {
TF_SetFilesystemVersionMetadata(ops);
ops->scheme = strdup(uri);
}
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
info->plugin_memory_allocate = plugin_memory_allocate;
info->plugin_memory_free = plugin_memory_free;
info->num_schemes = 2;
info->ops = static_cast<TF_FilesystemPluginOps*>(
plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
ProvideFilesystemSupportFor(&info->ops[0], "");
ProvideFilesystemSupportFor(&info->ops[1], "file");
}

View File

@ -18,19 +18,36 @@ limitations under the License.
#include "tensorflow/c/kernels.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#include <memory>
#include <string>
#include "absl/container/inlined_vector.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
struct MyCustomKernel {
bool created;

View File

@ -133,7 +133,7 @@ TEST(OpsTest, TestShapeInference_VectorizeFunction) {
TEST(OpsTest, AttributeAccessors) {
TF_OpDefinitionBuilder* builder =
TF_NewOpDefinitionBuilder("AttributeAccesorsOp");
TF_NewOpDefinitionBuilder("AttributeAccessorsOp");
TF_OpDefinitionBuilderAddAttr(builder, "foo1: int >= 2");
TF_OpDefinitionBuilderAddAttr(builder, "foo2: string=\"my string\"");
TF_OpDefinitionBuilderSetIsCommutative(builder, true);
@ -151,7 +151,7 @@ TEST(OpsTest, AttributeAccessors) {
op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length);
bool found = false;
for (const auto& op : op_list.op()) {
if (op.name() == "AttributeAccesorsOp") {
if (op.name() == "AttributeAccessorsOp") {
ASSERT_TRUE(op.is_commutative());
ASSERT_TRUE(op.is_aggregate());
ASSERT_TRUE(op.allows_uninitialized_input());

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/c/tf_tensor.h"
#include <memory>
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor_internal.h"
@ -103,49 +105,35 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
}
TF_Tensor* ret =
new TF_Tensor{Tensor(static_cast<tensorflow::DataType>(dtype),
tensorflow::TensorShape(dimvec), buf)};
// TODO(gjn): Make the choice of interface a compile-time configuration.
tensorflow::TensorInterface ret(
Tensor(static_cast<tensorflow::DataType>(dtype),
tensorflow::TensorShape(dimvec), buf));
buf->Unref();
size_t elem_size = TF_DataTypeSize(dtype);
if (elem_size > 0 && len < (elem_size * ret->tensor.NumElements())) {
delete ret;
if (elem_size > 0 && len < (elem_size * ret.NumElements())) {
return nullptr;
}
return ret;
return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(ret)};
}
TF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor) {
// It is safe to move the Tensor if and only if we own the unique reference to
// it. In that case, we might as well not delete and reallocate, but a future
// implementation might need to do so.
TensorBuffer* buf = tensorflow::TensorCApi::Buffer(tensor->tensor);
if (buf->RefCountIsOne() && buf->root_buffer()->RefCountIsOne() &&
buf->OwnsMemory()) {
return tensor;
}
return nullptr;
TF_Tensor* TF_TensorMaybeMove(TF_Tensor* t) {
return t->tensor->CanMove() ? t : nullptr;
}
void TF_DeleteTensor(TF_Tensor* t) { delete t; }
TF_DataType TF_TensorType(const TF_Tensor* t) {
return static_cast<TF_DataType>(t->tensor.dtype());
}
TF_DataType TF_TensorType(const TF_Tensor* t) { return t->tensor->Type(); }
int TF_NumDims(const TF_Tensor* t) { return t->tensor.dims(); }
int TF_NumDims(const TF_Tensor* t) { return t->tensor->NumDims(); }
int64_t TF_Dim(const TF_Tensor* t, int dim_index) {
return static_cast<int64_t>(t->tensor.dim_size(dim_index));
return t->tensor->Dim(dim_index);
}
size_t TF_TensorByteSize(const TF_Tensor* t) {
return tensorflow::TensorCApi::Buffer(t->tensor)->size();
}
size_t TF_TensorByteSize(const TF_Tensor* t) { return t->tensor->ByteSize(); }
void* TF_TensorData(const TF_Tensor* t) {
return tensorflow::TensorCApi::Buffer(t->tensor)->data();
}
void* TF_TensorData(const TF_Tensor* t) { return t->tensor->Data(); }
int64_t TF_TensorElementCount(const TF_Tensor* t) {
int64_t result = 1;
@ -160,15 +148,63 @@ void TF_TensorBitcastFrom(const TF_Tensor* from, TF_DataType type,
TF_Tensor* to, const int64_t* new_dims,
int num_new_dims, TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
Status cc_status(
static_cast<tensorflow::TensorInterface*>(to->tensor.get())
->BitcastFrom(*static_cast<const tensorflow::TensorInterface*>(
from->tensor.get()),
type, new_dims, num_new_dims));
Set_TF_Status_from_Status(status, cc_status);
}
namespace tensorflow {
bool TensorInterface::CanMove() const {
// It is safe to move the Tensor if and only if we own the unique reference to
// it. In that case, we might as well not delete and reallocate, but a future
// implementation might need to do so.
TensorBuffer* buf = tensorflow::TensorCApi::Buffer(tensor_);
if (buf->RefCountIsOne() && buf->root_buffer()->RefCountIsOne() &&
buf->OwnsMemory()) {
return true;
}
return false;
}
TF_DataType TensorInterface::Type() const {
return static_cast<TF_DataType>(tensor_.dtype());
}
int TensorInterface::NumDims() const { return tensor_.dims(); }
int64_t TensorInterface::Dim(int dim_index) const {
return static_cast<int64_t>(tensor_.dim_size(dim_index));
}
int64_t TensorInterface::NumElements() const {
return static_cast<int64_t>(tensor_.NumElements());
}
size_t TensorInterface::ByteSize() const {
return tensorflow::TensorCApi::Buffer(tensor_)->size();
}
void* TensorInterface::Data() const {
return tensorflow::TensorCApi::Buffer(tensor_)->data();
}
Status TensorInterface::BitcastFrom(const TensorInterface& from,
TF_DataType type, const int64_t* new_dims,
int num_new_dims) {
tensorflow::TensorShape s;
for (int i = 0; i < num_new_dims; ++i) {
s.AddDim(new_dims[i]);
}
Status cc_status(to->tensor.BitcastFrom(
from->tensor, static_cast<tensorflow::DataType>(type), s));
Set_TF_Status_from_Status(status, cc_status);
return tensor_.BitcastFrom(from.tensor_,
static_cast<tensorflow::DataType>(type), s);
}
} // namespace tensorflow
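A usage sketch of the C entry point wrapped above, using standard C API
allocation calls (error handling elided; `BitcastDemo` is an illustrative
name):

#include <stdint.h>

#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"

void BitcastDemo() {
  const int64_t dims2x2[2] = {2, 2};
  TF_Tensor* from = TF_AllocateTensor(TF_FLOAT, dims2x2, 2, 4 * sizeof(float));
  TF_Tensor* to = TF_AllocateTensor(TF_FLOAT, dims2x2, 2, 4 * sizeof(float));
  TF_Status* status = TF_NewStatus();

  // Reinterpret the same 4-element buffer as a flat [4] tensor.
  const int64_t dims4[1] = {4};
  TF_TensorBitcastFrom(from, TF_FLOAT, to, dims4, 1, status);
  if (TF_GetCode(status) == TF_OK) {
    // `to` now aliases `from`'s buffer with the new shape.
  }

  TF_DeleteStatus(status);
  TF_DeleteTensor(to);
  TF_DeleteTensor(from);
}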
// --------------------------------------------------------------------------
void StringEncode(const char* src, size_t src_len, char* dst) {
dst = tensorflow::core::EncodeVarint64(dst, src_len);
@ -277,12 +313,11 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status) {
return t;
}
if (src.dtype() != tensorflow::DT_STRING) {
auto* result = new TF_Tensor();
if (!result->tensor.CopyFrom(src, src.shape())) {
delete result;
Tensor tensor;
if (!tensor.CopyFrom(src, src.shape())) {
return nullptr;
}
return result;
return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(tensor)};
}
// DT_STRING tensors require copying since TF_Tensor.buffer expects a flatly
// encoded sequence of strings.
@ -332,31 +367,35 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status) {
}
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
if (src->tensor.dtype() == DT_RESOURCE) {
if (src->tensor.dims() != 0) {
return static_cast<const tensorflow::TensorInterface*>(src->tensor.get())
->ToTensor(dst);
}
Status TensorInterface::ToTensor(Tensor* dst) const {
if (tensor_.dtype() == DT_RESOURCE) {
if (tensor_.dims() != 0) {
return InvalidArgument(
"Malformed TF_RESOURCE tensor: expected a scalar, got a tensor with "
"shape ",
src->tensor.shape().DebugString());
tensor_.shape().DebugString());
}
*dst = Tensor(tensorflow::DT_RESOURCE, src->tensor.shape());
*dst = Tensor(tensorflow::DT_RESOURCE, tensor_.shape());
if (!dst->scalar<tensorflow::ResourceHandle>()().ParseFromString(
string(static_cast<const char*>(TF_TensorData(src)),
TF_TensorByteSize(src)))) {
string(static_cast<const char*>(Data()), ByteSize()))) {
return InvalidArgument(
"Malformed TF_RESOUCE tensor: unable to parse resource handle");
"Malformed TF_RESOURCE tensor: unable to parse resource handle");
}
return Status::OK();
}
if (src->tensor.dtype() != DT_STRING) {
*dst = src->tensor;
if (tensor_.dtype() != DT_STRING) {
*dst = tensor_;
return Status::OK();
}
// TF_STRING tensors require copying since the Tensor class expects a
// sequence of string objects.
const tensorflow::int64 num_elements = src->tensor.NumElements();
const char* input = reinterpret_cast<const char*>(TF_TensorData(src));
const size_t src_size = TF_TensorByteSize(src);
const tensorflow::int64 num_elements = tensor_.NumElements();
const char* input = reinterpret_cast<const char*>(Data());
const size_t src_size = ByteSize();
if (static_cast<tensorflow::int64>(src_size / sizeof(tensorflow::uint64)) <
num_elements) {
return InvalidArgument(
@ -365,7 +404,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
const char* limit = input + src_size;
*dst = Tensor(src->tensor.dtype(), src->tensor.shape());
*dst = Tensor(tensor_.dtype(), tensor_.shape());
auto dstarray = dst->flat<tstring>();
for (tensorflow::int64 i = 0; i < num_elements; ++i) {
tensorflow::uint64 offset =
@ -384,8 +423,8 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
return Status::OK();
}
bool TensorInterface::IsAligned() const { return tensor_.IsAligned(); }
} // namespace tensorflow
bool TF_TensorIsAligned(const TF_Tensor* tensor) {
return tensor->tensor.IsAligned();
}
bool TF_TensorIsAligned(const TF_Tensor* t) { return t->tensor->IsAligned(); }
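A small sketch exercising the accessors that now route through the tensor
interface; all calls are standard C API functions, and `InspectTensor` is an
illustrative name:

#include <stdio.h>

void InspectTensor(const TF_Tensor* t) {
  printf("dtype=%d dims=%d bytes=%zu aligned=%d\n",
         static_cast<int>(TF_TensorType(t)), TF_NumDims(t),
         TF_TensorByteSize(t), static_cast<int>(TF_TensorIsAligned(t)));
  for (int i = 0; i < TF_NumDims(t); ++i) {
    printf("  dim[%d]=%lld\n", i, static_cast<long long>(TF_Dim(t, i)));
  }
}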

View File

@ -16,9 +16,12 @@ limitations under the License.
#ifndef TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
#define TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
#include <memory>
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_interface.h"
#include "tensorflow/core/framework/tensor_shape.h"
// Internal structures used by the C API. These are likely to change and should
@ -28,7 +31,7 @@ limitations under the License.
// passed to or returned from C functions *by pointer*. Otherwise, changes to
// its internal structure will break the C API's binary interface.
typedef struct TF_Tensor {
::tensorflow::Tensor tensor;
std::unique_ptr<AbstractTensorInterface> tensor;
} TF_Tensor;
class TF_ManagedBuffer : public tensorflow::TensorBuffer {
@ -83,4 +86,5 @@ void* allocate_tensor(const char* operation, size_t len, Allocator* allocator);
// a different Allocator as `arg`.
void deallocate_buffer(void* data, size_t len, void* arg);
} // namespace tensorflow
#endif // TENSORFLOW_C_TF_TENSOR_INTERNAL_H_

View File

@ -41,6 +41,16 @@ filegroup(
],
)
filegroup(
name = "pywrap_required_hdrs",
srcs = [
"training/coordinator.h",
],
visibility = [
"//tensorflow/python:__pkg__",
],
)
cc_library(
name = "gradients",
srcs = [

View File

@ -96,7 +96,7 @@ class SymbolicGradientBuilder {
// Used to identify nodes at which to stop backprop.
std::unordered_set<int> GetStopBackpropNodes(
const std::vector<bool>& reachable_nodes,
const std::unordered_set<int>& output_nodes);
const std::unordered_set<int>& output_nodes) const;
const Scope& scope_;
const ops::GradOpRegistry* registry_;
@ -190,7 +190,7 @@ std::vector<bool> SymbolicGradientBuilder::GetReachableNodes() {
std::unordered_set<int> SymbolicGradientBuilder::GetStopBackpropNodes(
const std::vector<bool>& reachable_nodes,
const std::unordered_set<int>& output_nodes) {
const std::unordered_set<int>& output_nodes) const {
// Output nodes that get transitively consumed by other `outputs_` are stored
// in `internal_outputs`.
std::unordered_set<int> internal_outputs;
@ -346,8 +346,8 @@ Status SymbolicGradientBuilder::SumGradients(const Output& src, Output* grad) {
"Unable to find backprop list for node.id ", src.node()->name());
}
const auto& grads = iter->second;
// Filter any backproped 'NoGradient' Outputs from 'grads' (if needed).
// Return any valid backproped gradients that remain after filtering,
// Filter any backpropped 'NoGradient' Outputs from 'grads' (if needed).
// Return any valid backpropped gradients that remain after filtering,
// or 'NoGradient' otherwise.
std::vector<Output> grads_to_keep;
for (const Output& o : grads) {
@ -519,7 +519,7 @@ Status SymbolicGradientBuilder::AddGradients() {
// Backprop along the in edges.
// TODO(andydavis) Find cleaner way to map each grad output returned by
// gradient function to the src node/output to which it should be
// backproped. Maybe grad functions can return a vector of Output pairs to
// backpropped. Maybe grad functions can return a vector of Output pairs to
// make this association explicit.
size_t dx_index = 0;
for (const Edge* e : n->in_edges()) {

View File

@ -64,7 +64,7 @@ bool IsZero(const Scope& scope, const Output& grad) {
// Multiply after broadcasting vec to match dimensions of mat.
// Args:
// vec: A 1-D tensor of dimension [D0]
// mat: A 2-D tensor of dimesnion [D0, D1]
// mat: A 2-D tensor of dimension [D0, D1]
//
// Returns:
// A tensor of dimension [D0, D1], the result of vec * mat.

View File

@ -124,13 +124,12 @@ cc_library(
hdrs = ["bundle_v2.h"],
deps = [
":constants",
"@com_google_absl//absl/container:flat_hash_set",
] + if_not_mobile([
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:strcat",
"//tensorflow/core/util/tensor_bundle",
]),
"@com_google_absl//absl/container:flat_hash_set",
],
)
tf_cc_test(

View File

@ -1,5 +1,6 @@
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
load("//tensorflow/core/platform:build_config.bzl", "if_llvm_aarch64_available")
package(
default_visibility = ["//visibility:private"],
@ -27,9 +28,15 @@ cc_library(
"compile.h",
"flags.h",
],
defines = if_llvm_aarch64_available(["TF_LLVM_AARCH64_AVAILABLE=1"]),
visibility = ["//tensorflow/python:__pkg__"],
deps = [
":aot_only_var_handle_op",
":embedded_protocol_buffers",
"@com_google_absl//absl/base",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"//tensorflow/compiler/tf2xla",
"//tensorflow/compiler/tf2xla:mlir_tf2xla",
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
@ -53,10 +60,13 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
"@llvm-project//llvm:target",
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
] + if_llvm_aarch64_available([
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
]),
)
tf_cc_test(
@ -86,6 +96,19 @@ tf_cc_binary(
deps = [":tfcompile_main"],
)
cc_library(
name = "llvm_targets",
visibility = ["//tensorflow/python:__pkg__"],
deps = [
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
"@llvm-project//llvm:target",
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
] + if_llvm_aarch64_available([
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
]),
)
cc_library(
name = "tfcompile_main",
srcs = ["tfcompile_main.cc"],
@ -104,11 +127,6 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:aarch64_code_gen", # fixdeps: keep
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
"@llvm-project//llvm:target",
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
],
)
@ -214,8 +232,13 @@ cc_library(
cc_library(
name = "aot_only_var_handle_op",
srcs = ["aot_only_var_handle_op.cc"],
hdrs = ["aot_only_var_handle_op.h"],
visibility = [
"//tensorflow/compiler/tf2xla:__pkg__",
],
deps = [
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/core:framework",
],
alwayslink = 1,
)

View File

@ -13,9 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/aot/aot_only_var_handle_op.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace {
@ -51,6 +54,31 @@ void XlaAotOnlyVarHandleOp::Compile(XlaOpKernelContext* context) {
}
} // namespace
REGISTER_XLA_OP(Name("VarHandleOp").CompilationOnly(), XlaAotOnlyVarHandleOp);
REGISTER_OP(tfcompile::kXlaAotOnlyVarHandleOp)
.Doc(R"doc(
Internal VarHandleOp registration used for XLA AOT compilation.
)doc")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.Attr("dtype: type")
.Attr("shape: shape")
.Output("resource: resource")
.SetIsStateful()
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->Scalar());
DataType t;
TF_RETURN_IF_ERROR(c->GetAttr("dtype", &t));
PartialTensorShape p;
TF_RETURN_IF_ERROR(c->GetAttr("shape", &p));
shape_inference::ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s));
c->set_output_handle_shapes_and_types(
0, std::vector<shape_inference::ShapeAndType>{{s, t}});
return Status::OK();
});
REGISTER_XLA_OP(Name(tfcompile::kXlaAotOnlyVarHandleOp).CompilationOnly(),
XlaAotOnlyVarHandleOp);
} // namespace tensorflow

View File

@ -0,0 +1,27 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_
#define TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_
namespace tensorflow {
namespace tfcompile {
static constexpr const char* const kXlaAotOnlyVarHandleOp =
"_XlaAotOnlyVarHandleOp";
} // namespace tfcompile
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_

View File

@ -74,16 +74,16 @@ void DumpStatsToStdout(const Stats& stats) {
const int kBufSize = 1000;
char buf[kBufSize];
snprintf(buf, kBufSize, "Mean with %2.0f%% trimmed:", trim_ratio * 100);
const string label_trimmed(buf);
std::string label_trimmed(buf);
snprintf(buf, kBufSize, "Mean of %2.0f%% best:", best_ratio * 100);
const string label_best(buf);
std::vector<std::pair<string, double>> groups = {
std::string label_best(buf);
std::vector<std::pair<std::string, double>> groups = {
{"Best:", sorted_us.front()},
{"Worst:", sorted_us.back()},
{"Median:", sorted_us[count_us / 2]},
{"Mean:", sum_us / count_us},
{label_trimmed, sum_us_trimmed / count_us_trimmed},
{label_best, sum_us_best / count_us_best},
{std::move(label_trimmed), sum_us_trimmed / count_us_trimmed},
{std::move(label_best), sum_us_best / count_us_best},
};
int max_label_size = 0;
double max_us = 0;
@ -102,7 +102,7 @@ void DumpStatsToStdout(const Stats& stats) {
}
// Dump stats out.
printf("Benchmark ran %zu iterations over %lld us\n", count_us,
stats.total_us);
static_cast<long long>(stats.total_us)); // NOLINT
for (const auto& g : groups) {
printf(" %-*s %*.3f us\n", max_label_size, g.first.c_str(), max_digits + 4,
g.second);
@ -114,7 +114,8 @@ void Benchmark(const Options& options, const BenchmarkFn& fn, Stats* stats) {
const int64 max_us = (options.max_micros <= 0 && options.max_iters <= 0)
? Options::kDefaultMicros
: options.max_micros;
printf("Running benchmark for %lld us\n", max_us);
// NOLINTNEXTLINE
printf("Running benchmark for %lld us\n", static_cast<long long>(max_us));
const int64 start_us = NowMicros();
int64 iters = 0;
while (true) {

View File

@ -423,8 +423,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
GenNameToIndexCode(config.fetch(), opts.gen_name_to_index);
const string include_xla_data_proto =
opts.gen_program_shape
?
R"(#include "tensorflow/compiler/xla/xla_data.pb.h")"
? R"(#include "tensorflow/compiler/xla/xla_data.pb.h")"
: "";
const string include_hlo_profile_printer_data_proto =

View File

@ -20,6 +20,9 @@ limitations under the License.
#include <utility>
#include <vector>
#include "absl/base/call_once.h"
#include "llvm-c/Target.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
@ -90,7 +93,7 @@ Status CompileXla(xla::CompileOnlyClient* client,
} // namespace
Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
const MainFlags& flags, CompileResult* compile_result) {
// Converts the graph into an XLA computation, and compiles the
// computation.
@ -108,8 +111,8 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
if (!flags.mlir_components.empty()) {
return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
}
TF_RETURN_IF_ERROR(
ConvertGraphDefToXla(graph_def, config, client, &computation));
TF_RETURN_IF_ERROR(ConvertGraphDefToXla(std::move(graph_def), config,
client, &computation));
}
if (!flags.out_session_module.empty()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
@ -132,5 +135,96 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
return CompileXla(client, computation, aot_opts, compile_result);
}
static Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
if (absl::EndsWith(fname, ".pbtxt")) {
return ReadTextProto(Env::Default(), fname, proto);
} else {
return ReadBinaryProto(Env::Default(), fname, proto);
}
}
static absl::once_flag targets_init;
static void InitializeTargets() {
// Initialize all LLVM targets so we can cross compile.
#if TF_LLVM_AARCH64_AVAILABLE
LLVMInitializeAArch64Target();
LLVMInitializeAArch64TargetInfo();
LLVMInitializeAArch64TargetMC();
LLVMInitializeAArch64AsmPrinter();
#endif
LLVMInitializeARMTarget();
LLVMInitializeARMTargetInfo();
LLVMInitializeARMTargetMC();
LLVMInitializeARMAsmPrinter();
LLVMInitializePowerPCTarget();
LLVMInitializePowerPCTargetInfo();
LLVMInitializePowerPCTargetMC();
LLVMInitializePowerPCAsmPrinter();
LLVMInitializeX86Target();
LLVMInitializeX86TargetInfo();
LLVMInitializeX86TargetMC();
LLVMInitializeX86AsmPrinter();
}
Status Main(const MainFlags& flags) {
absl::call_once(targets_init, &InitializeTargets);
// Process config.
tf2xla::Config config;
if (flags.config.empty()) {
return errors::InvalidArgument("Must specify --config");
}
TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
TF_RETURN_IF_ERROR(ValidateConfig(config));
if (flags.dump_fetch_nodes) {
std::set<string> nodes;
for (const tf2xla::Fetch& fetch : config.fetch()) {
nodes.insert(fetch.id().node_name());
}
std::cout << absl::StrJoin(nodes, ",");
return Status::OK();
}
// Read and initialize the graph.
if (flags.graph.empty()) {
return errors::InvalidArgument("Must specify --graph");
}
GraphDef graph_def;
TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
CompileResult compile_result;
TF_RETURN_IF_ERROR(
CompileGraph(std::move(graph_def), config, flags, &compile_result));
// Write output files.
Env* env = Env::Default();
const std::vector<char>& obj = compile_result.aot->object_file_data();
TF_RETURN_IF_ERROR(
WriteStringToFile(env, flags.out_function_object,
absl::string_view(obj.data(), obj.size())));
CodegenOpts codegen_opts;
codegen_opts.gen_name_to_index = flags.gen_name_to_index;
codegen_opts.gen_program_shape = flags.gen_program_shape;
codegen_opts.target_triple = flags.target_triple;
if (flags.cpp_class.empty()) {
return errors::InvalidArgument("Must specify --cpp_class");
}
codegen_opts.gen_hlo_profile_printer_data =
xla::GetDebugOptionsFromFlags().xla_hlo_profile();
TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
&codegen_opts.namespaces));
MetadataResult metadata_result;
TF_RETURN_IF_ERROR(
GenerateMetadata(codegen_opts, compile_result, &metadata_result));
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_metadata_object,
metadata_result.object_file_data));
string header;
TF_RETURN_IF_ERROR(GenerateHeader(codegen_opts, config, compile_result,
metadata_result, &header));
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header));
return Status::OK();
}
} // namespace tfcompile
} // namespace tensorflow
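The extracted `Main` makes the compile pipeline reusable as a library; a
minimal sketch, assuming only the flag fields referenced above (all paths
are placeholders):

#include "tensorflow/compiler/aot/compile.h"
#include "tensorflow/compiler/aot/flags.h"

tensorflow::Status RunAotCompile() {
  tensorflow::tfcompile::MainFlags flags;
  flags.graph = "/tmp/test_graph.pb";             // placeholder input graph
  flags.config = "/tmp/test_graph.config.pbtxt";  // placeholder tf2xla config
  flags.cpp_class = "mynamespace::MyComputation";
  flags.out_function_object = "/tmp/out_model.o";
  flags.out_metadata_object = "/tmp/out_helper.o";
  flags.out_header = "/tmp/out_model.h";
  return tensorflow::tfcompile::Main(flags);
}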

View File

@ -42,9 +42,12 @@ struct CompileResult {
// that performs the graph operations.
//
// The XLA compilation options are specified in the flags.
Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
const MainFlags& flags, CompileResult* compile_result);
// The full compilation method, for reuse in a library setting.
Status Main(const MainFlags& flags);
} // namespace tfcompile
} // namespace tensorflow

View File

@ -25,6 +25,7 @@ namespace tensorflow {
namespace tfcompile {
// Flags for the tfcompile binary. See *.cc file for descriptions.
struct MainFlags {
string graph;
string config;

View File

@ -25,6 +25,7 @@ test_suite(
":test_graph_tfmatmulandadd_test",
":test_graph_tfsplits_test",
":test_graph_tftop_k_test",
":test_graph_tfvariable_readonly_test",
":test_graph_tfvariable_sequential_updates_test",
":test_graph_tfvariable_test",
":tfcompile_test",
@ -73,6 +74,7 @@ genrule(
"test_graph_tfsplits.pb",
"test_graph_tftop_k.pb",
"test_graph_tfvariable.pb",
"test_graph_tfvariable_readonly.pb",
"test_graph_tfvariable_sequential_updates.pb",
],
# Set CUDA_VISIBLE_DEVICES='' to prevent the code we launch from using any
@ -238,6 +240,17 @@ tf_library(
],
)
tf_library(
name = "test_graph_tfvariable_readonly",
testonly = 1,
config = "test_graph_tfvariable_readonly.config.pbtxt",
cpp_class = "VariableReadonlyComp",
graph = "test_graph_tfvariable_readonly.pb",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tfvariable_sequential_updates",
testonly = 1,
@ -269,6 +282,7 @@ tf_cc_test(
":test_graph_tfsplits",
":test_graph_tftop_k",
":test_graph_tfvariable",
":test_graph_tfvariable_readonly",
":test_graph_tfvariable_sequential_updates",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@ -323,6 +337,42 @@ tf_library(
],
)
tf_library(
name = "test_graph_tfcond_mlir_bridge",
testonly = 1,
config = "test_graph_tfcond.config.pbtxt",
cpp_class = "CondComp",
graph = "test_graph_tfcond.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tfassert_eq_mlir_bridge",
testonly = 1,
config = "test_graph_tfassert_eq.config.pbtxt",
cpp_class = "AssertComp",
graph = "test_graph_tfassert_eq.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tfgather_mlir_bridge",
testonly = 1,
config = "test_graph_tfgather.config.pbtxt",
cpp_class = "GatherComp",
graph = "test_graph_tfgather.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tfmatmul_mlir_bridge",
testonly = 1,
@ -361,6 +411,66 @@ tf_library(
],
)
tf_library(
name = "test_graph_tfsplits_mlir_bridge",
testonly = 1,
config = "test_graph_tfsplits.config.pbtxt",
cpp_class = "SplitsComp",
graph = "test_graph_tfsplits.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tftop_k_mlir_bridge",
testonly = 1,
config = "test_graph_tftop_k.config.pbtxt",
cpp_class = "TopKComp",
graph = "test_graph_tftop_k.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tfvariable_readonly_mlir_bridge",
testonly = 1,
config = "test_graph_tfvariable_readonly.config.pbtxt",
cpp_class = "VariableReadonlyComp",
graph = "test_graph_tfvariable_readonly.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tfvariable_mlir_bridge",
testonly = 1,
config = "test_graph_tfvariable.config.pbtxt",
cpp_class = "VariableComp",
graph = "test_graph_tfvariable.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tfvariable_sequential_updates_mlir_bridge",
testonly = 1,
config = "test_graph_tfvariable_sequential_updates.config.pbtxt",
cpp_class = "VariableSequentialUpdatesComp",
graph = "test_graph_tfvariable_sequential_updates.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_cc_test(
name = "tfcompile_test_mlir_bridge",
srcs = ["tfcompile_test.cc"],
@ -372,9 +482,17 @@ tf_cc_test(
":test_graph_tfadd_mlir_bridge",
":test_graph_tfadd_with_ckpt_mlir_bridge",
":test_graph_tfadd_with_ckpt_saver_mlir_bridge",
":test_graph_tfassert_eq_mlir_bridge",
":test_graph_tfcond_mlir_bridge",
":test_graph_tfgather_mlir_bridge",
":test_graph_tfmatmul_mlir_bridge",
":test_graph_tfmatmulandadd_mlir_bridge",
":test_graph_tfmatmulandadd_with_profiling_mlir_bridge",
":test_graph_tfsplits_mlir_bridge",
":test_graph_tftop_k_mlir_bridge",
":test_graph_tfvariable_mlir_bridge",
":test_graph_tfvariable_readonly_mlir_bridge",
":test_graph_tfvariable_sequential_updates_mlir_bridge",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:xla_data_proto_cc",

View File

@ -34,6 +34,7 @@ from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables
@ -153,11 +154,21 @@ def tftop_k(_):
array_ops.identity(output[1], name='indices')
def tfvariable(_):
def tfvariable_readonly(_):
x = variables.Variable(1000.0, name='x')
old_x = x.value()
with ops.control_dependencies([old_x]):
new_x = x.assign_add(42.0)
new_value = math_ops.add(old_x, 42.0)
array_ops.identity(new_value, name='result')
# TODO(b/147908587): Change x and the two constants back to have a scalar shape
# when the bug is fixed.
def tfvariable(_):
x = variables.Variable([1000.0], name='x', shape=[1])
old_x = x.value()
with ops.control_dependencies([old_x]):
new_x = x.assign_add([42.0])
array_ops.stack([old_x, new_x], name='result')
@ -184,6 +195,7 @@ def write_graph(build_graph, out_dir):
def main(_):
control_flow_util.enable_control_flow_v2()
write_graph(tfadd, FLAGS.out_dir)
write_graph(tfadd_with_ckpt, FLAGS.out_dir)
write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir)
@ -196,6 +208,7 @@ def main(_):
write_graph(tfsplits, FLAGS.out_dir)
write_graph(tftop_k, FLAGS.out_dir)
write_graph(tfvariable, FLAGS.out_dir)
write_graph(tfvariable_readonly, FLAGS.out_dir)
write_graph(tfvariable_sequential_updates, FLAGS.out_dir)

View File

@ -0,0 +1,12 @@
# Text form of tensorflow.tf2xla.Config proto.
fetch {
id { node_name: "result" }
}
variable {
node_name: "x"
shape {
}
type: DT_FLOAT
readonly: true
}

View File

@ -30,9 +30,17 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfcond_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfgather_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tftop_k_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates_mlir_bridge.h"
#else
#include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
@ -47,6 +55,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h"
#include "tensorflow/compiler/aot/tests/test_graph_tftop_k.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates.h"
#endif
@ -167,8 +176,6 @@ TEST(TFCompileTest, AddWithCkptSaver) {
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
}
// TODO(bixia): the following tests failed with MLIR bridge.
#if !defined(ENABLE_MLIR_BRIDGE_TEST)
TEST(TFCompileTest, Cond) {
CondComp cond;
EXPECT_EQ(cond.arg0_data(), cond.arg_data(0));
@ -233,7 +240,6 @@ TEST(TFCompileTest, Gather) {
EXPECT_EQ(gather_const.result0_data(), gather.results()[0]);
}
}
#endif
TEST(TFCompileTest, MatMul2) {
Eigen::ThreadPool tp(2);
@ -439,6 +445,7 @@ TEST(TFCompileTest, Function) {
EXPECT_EQ(add_fn.result0_data()[0], 3);
EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]);
}
#endif
TEST(TFCompileTest, Splits) {
Eigen::ThreadPool tp(1);
@ -492,6 +499,20 @@ TEST(TFCompileTest, TopK) {
EXPECT_EQ(expected_indices[1], fn.result1(1));
}
TEST(TFCompileTest, VariableReadonly) {
Eigen::ThreadPool tp(1);
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
VariableReadonlyComp fn;
float x = 23;
fn.set_var_x_data(&x);
fn.set_thread_pool(&device);
fn.Run();
EXPECT_EQ(fn.result0(), 65);
EXPECT_EQ(fn.var_x(), 23);
}
TEST(TFCompileTest, Variable) {
Eigen::ThreadPool tp(1);
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
@ -665,6 +686,11 @@ TEST(TFCompileTest, HloProfiling) {
/*clock_rate_ghz=*/1.0);
VLOG(1) << "Original HLO profile string:\n" << hlo_profile_as_string;
// Replace Arg_n with argn when the MLIR bridge is used.
#if defined(ENABLE_MLIR_BRIDGE_TEST)
RE2::GlobalReplace(&hlo_profile_as_string, "(Arg_)([0-9].)", "arg\\2");
#endif
// Strip away identifier details from the profile string to avoid this test
// being a change detector for xla internals. Identifiers such as '%dot.0.7'
// just become '%dot'.
@ -690,7 +716,6 @@ TEST(TFCompileTest, HloProfiling) {
IsSupersetOf({header, total_cycles_profile_line, dot_profile_line,
add_profile_line, tuple_profile_line}));
}
#endif
} // namespace
} // namespace tfcompile

View File

@ -21,7 +21,6 @@ limitations under the License.
#include "absl/strings/match.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "llvm-c/Target.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/compile.h"
#include "tensorflow/compiler/aot/flags.h"
@ -56,88 +55,6 @@ const char kUsageHeader[] =
"--cpp_class=\"mynamespace::MyComputation\"\n"
"\n";
Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
if (absl::EndsWith(fname, ".pbtxt")) {
return ReadTextProto(Env::Default(), fname, proto);
} else {
return ReadBinaryProto(Env::Default(), fname, proto);
}
}
Status Main(const MainFlags& flags) {
// Initialize all LLVM targets so we can cross compile.
LLVMInitializeAArch64Target();
LLVMInitializeAArch64TargetInfo();
LLVMInitializeAArch64TargetMC();
LLVMInitializeAArch64AsmPrinter();
LLVMInitializeARMTarget();
LLVMInitializeARMTargetInfo();
LLVMInitializeARMTargetMC();
LLVMInitializeARMAsmPrinter();
LLVMInitializePowerPCTarget();
LLVMInitializePowerPCTargetInfo();
LLVMInitializePowerPCTargetMC();
LLVMInitializePowerPCAsmPrinter();
LLVMInitializeX86Target();
LLVMInitializeX86TargetInfo();
LLVMInitializeX86TargetMC();
LLVMInitializeX86AsmPrinter();
// Process config.
tf2xla::Config config;
if (flags.config.empty()) {
return errors::InvalidArgument("Must specify --config");
}
TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
TF_RETURN_IF_ERROR(ValidateConfig(config));
if (flags.dump_fetch_nodes) {
std::set<string> nodes;
for (const tf2xla::Fetch& fetch : config.fetch()) {
nodes.insert(fetch.id().node_name());
}
std::cout << absl::StrJoin(nodes, ",");
return Status::OK();
}
// Read and initialize the graph.
if (flags.graph.empty()) {
return errors::InvalidArgument("Must specify --graph");
}
GraphDef graph_def;
TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
CompileResult compile_result;
TF_RETURN_IF_ERROR(CompileGraph(graph_def, config, flags, &compile_result));
// Write output files.
Env* env = Env::Default();
const std::vector<char>& obj = compile_result.aot->object_file_data();
TF_RETURN_IF_ERROR(
WriteStringToFile(env, flags.out_function_object,
absl::string_view(obj.data(), obj.size())));
CodegenOpts codegen_opts;
codegen_opts.gen_name_to_index = flags.gen_name_to_index;
codegen_opts.gen_program_shape = flags.gen_program_shape;
codegen_opts.target_triple = flags.target_triple;
if (flags.cpp_class.empty()) {
return errors::InvalidArgument("Must specify --cpp_class");
}
codegen_opts.gen_hlo_profile_printer_data =
xla::GetDebugOptionsFromFlags().xla_hlo_profile();
TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
&codegen_opts.namespaces));
MetadataResult metadata_result;
TF_RETURN_IF_ERROR(
GenerateMetadata(codegen_opts, compile_result, &metadata_result));
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_metadata_object,
metadata_result.object_file_data));
string header;
TF_RETURN_IF_ERROR(GenerateHeader(codegen_opts, config, compile_result,
metadata_result, &header));
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header));
return Status::OK();
}
} // end namespace tfcompile
} // end namespace tensorflow

View File

@ -2,6 +2,7 @@ load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "tf_cc_
load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library", "tf_jit_compilation_passes_extra_deps")
load("//tensorflow/core/platform:build_config.bzl", "tf_additional_all_protos", "tf_proto_library")
load("//tensorflow/core/platform:build_config_root.bzl", "tf_cuda_tests_tags")
package(
default_visibility = [":internal"],
@ -56,6 +57,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":jit_compilation_passes",
":xla_kernel_creator", # buildcleaner: keep
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
@ -69,6 +71,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = if_cuda_or_rocm([
":jit_compilation_passes",
":xla_kernel_creator", # buildcleaner: keep
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
@ -77,19 +80,6 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "xla_mlir_gpu_jit",
visibility = ["//visibility:public"],
deps = if_cuda_or_rocm([
":jit_compilation_passes",
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/xla/service:mlir_gpu_plugin",
]),
alwayslink = 1,
)
cc_library(
name = "xla_cpu_device",
srcs = ["xla_cpu_device.cc"],
@ -115,6 +105,7 @@ cc_library(
srcs = ["xla_gpu_device.cc"],
visibility = [":friends"],
deps = [
":flags",
":jit_compilation_passes",
":xla_device",
":xla_kernel_creator", # buildcleaner: keep
@ -123,6 +114,7 @@ cc_library(
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:gpu_init",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
@ -167,7 +159,9 @@ XLA_DEVICE_DEPS = [
":common",
":xla_launch_util",
":xla_tensor",
"@com_google_absl//absl/base",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:optional",
"//tensorflow/compiler/jit/ops:xla_ops",
@ -260,13 +254,26 @@ cc_library(
}),
)
# Internal targets below this point.
cc_library(
name = "flags",
srcs = ["flags.cc"],
hdrs = ["flags.h"],
visibility = [":friends"],
deps = [
"//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"@com_google_absl//absl/base",
"@com_google_absl//absl/strings",
],
)
# Header-only version of "flags" library, for linking from the shared object
# without ODR violations.
cc_library(
name = "flags_headers_only",
hdrs = ["flags.h"],
visibility = [":friends"],
deps = [
"//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal",
@ -286,6 +293,8 @@ cc_library(
visibility = [":friends"],
)
# Internal targets below this point.
cc_library(
name = "xla_launch_util",
srcs = ["xla_launch_util.cc"],
@ -407,6 +416,7 @@ cc_library(
"xla_kernel_creator.h",
],
deps = [
":flags",
":jit_compilation_passes",
":xla_kernel_creator_util",
"//tensorflow/core:core_cpu_internal",
@ -635,6 +645,7 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
@ -766,7 +777,7 @@ tf_cc_test(
],
# TODO(b/141643254) Re-enable msan after fixing use-of-uninitialized-value
# error.
tags = ["nomsan"],
tags = ["nomsan"] + tf_cuda_tests_tags(),
deps = [
":common",
":compilation_passes",

View File

@ -1584,7 +1584,6 @@ DeadnessAnalysis::~DeadnessAnalysis() {}
absl::flat_hash_map<TensorId, string, TensorId::Hasher>
DeadnessAnalysisImpl::PredicateMapAsString() const {
absl::flat_hash_map<TensorId, string, TensorId::Hasher> result;
std::vector<TensorId> tensor_ids;
for (const auto& kv_pair : predicate_map_) {
CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second);
}

View File

@ -374,39 +374,6 @@ xla::StatusOr<NodeDef> BuildXlaHostComputeNodeDef(
return new_def;
}
TF_ATTRIBUTE_NOINLINE Status
ValidateOutsideCompilationCallNode(Node* call_node) {
// DT_INT64 as input/output for outside compilation is not supported yet:
// b/120809951.
for (const Edge* e : call_node->in_edges()) {
if (e->IsControlEdge()) {
continue;
}
DataType dtype = e->src()->output_type(e->src_output());
if (dtype == DT_INT64) {
return errors::Unimplemented(
"int64 input for outside compilation is not supported yet: "
"b/120809951. Please cast output of node ",
e->src()->DebugString(),
" to int32 before feeding it into outside compilation.");
}
}
for (const Edge* e : call_node->out_edges()) {
if (e->IsControlEdge()) {
continue;
}
DataType dtype = e->dst()->input_type(e->dst_input());
if (dtype == DT_INT64) {
return errors::Unimplemented(
"int64 output for outside compilation is not supported yet: "
"b/120809951. Please cast input of node ",
e->dst()->DebugString(),
" to int32 before returning it from outside compilation.");
}
}
return Status::OK();
}
// Replace outside compilation function call node with XlaHostCompute node.
TF_ATTRIBUTE_NOINLINE xla::StatusOr<Node*> ReplaceOutsideCompilationCallNode(
Graph* g, Node* call_node, const std::map<string, int>& host_compute_core,
@ -2384,7 +2351,6 @@ Status ExtractOutsideCompilationForFunction(
}
std::map<string, Node*> host_compute_nodes;
for (Node* n : outside_compilation_nodes) {
TF_RETURN_IF_ERROR(ValidateOutsideCompilationCallNode(n));
auto host_compute_node_or = ReplaceOutsideCompilationCallNode(
graph_out.get(), n, host_compute_core, *cluster_deps);
TF_RETURN_IF_ERROR(host_compute_node_or.status());

View File

@ -13,13 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/flags.h"
#include <mutex> // NOLINT
#include "absl/base/call_once.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "absl/strings/strip.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
@ -32,7 +35,7 @@ XlaOpsCommonFlags* ops_flags;
IntroduceFloatingPointJitterPassFlags* jitter_flags;
std::vector<Flag>* flag_list;
std::once_flag flags_init;
absl::once_flag flags_init;
bool SetterForXlaAutoJitFlag(const string& value) {
int32 opt_level;
@ -155,6 +158,7 @@ void AllocateAndParseFlags() {
device_flags = new XlaDeviceFlags;
device_flags->tf_xla_compile_on_demand = false;
device_flags->tf_xla_enable_xla_devices = true;
ops_flags = new XlaOpsCommonFlags;
ops_flags->tf_xla_always_defer_compilation = false;
@ -187,6 +191,12 @@ void AllocateAndParseFlags() {
"Switch a device into 'on-demand' mode, where instead of "
"autoclustering ops are compiled one by one just-in-time."),
Flag("tf_xla_enable_xla_devices",
&device_flags->tf_xla_enable_xla_devices,
"Generate XLA_* devices, where placing a computation on such a "
"device"
"forces compilation by XLA. Deprecated."),
Flag("tf_xla_always_defer_compilation",
&ops_flags->tf_xla_always_defer_compilation, ""),
@ -206,38 +216,45 @@ void AllocateAndParseFlags() {
} // namespace
bool SetXlaAutoJitFlagFromFlagString(const string& value) {
std::call_once(flags_init, &AllocateAndParseFlags);
absl::call_once(flags_init, &AllocateAndParseFlags);
return SetterForXlaAutoJitFlag(value);
}
BuildXlaOpsPassFlags* GetBuildXlaOpsPassFlags() {
std::call_once(flags_init, &AllocateAndParseFlags);
absl::call_once(flags_init, &AllocateAndParseFlags);
return build_ops_flags;
}
MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() {
std::call_once(flags_init, &AllocateAndParseFlags);
absl::call_once(flags_init, &AllocateAndParseFlags);
return mark_for_compilation_flags;
}
XlaDeviceFlags* GetXlaDeviceFlags() {
std::call_once(flags_init, &AllocateAndParseFlags);
absl::call_once(flags_init, &AllocateAndParseFlags);
return device_flags;
}
const XlaOpsCommonFlags& GetXlaOpsCommonFlags() {
std::call_once(flags_init, &AllocateAndParseFlags);
absl::call_once(flags_init, &AllocateAndParseFlags);
return *ops_flags;
}
const IntroduceFloatingPointJitterPassFlags&
GetIntroduceFloatingPointJitterPassFlags() {
std::call_once(flags_init, &AllocateAndParseFlags);
absl::call_once(flags_init, &AllocateAndParseFlags);
return *jitter_flags;
}
void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
std::call_once(flags_init, &AllocateAndParseFlags);
absl::call_once(flags_init, &AllocateAndParseFlags);
AppendMarkForCompilationPassFlagsInternal(flag_list);
}
static bool xla_is_enabled = false;
void SetXlaIsEnabled() { xla_is_enabled = true; }
bool IsXlaEnabled() { return xla_is_enabled; }
} // namespace tensorflow
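
For context, a minimal sketch of the absl::call_once idiom this hunk migrates to; the names (g_once, g_flags, InitFlags, GetFlags) are hypothetical and not part of the change:

#include "absl/base/call_once.h"

static absl::once_flag g_once;   // hypothetical once-flag object
static int* g_flags = nullptr;   // hypothetical lazily-built state

static void InitFlags() { g_flags = new int(0); }

int* GetFlags() {
  // The first caller runs InitFlags(); concurrent callers block until it
  // finishes, so every caller observes fully-initialized state.
  absl::call_once(g_once, &InitFlags);
  return g_flags;
}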

View File

@ -87,6 +87,9 @@ struct XlaDeviceFlags {
// Enabling this mode by a legacy flag is a temporary mechanism. When this
// feature is battle-tested, we will switch this to be a session option.
bool tf_xla_compile_on_demand;
// Enables "XLA" devices if this flag is set.
bool tf_xla_enable_xla_devices;
};
// Flags common to the _Xla* ops and their kernels.
@ -151,6 +154,15 @@ GetIntroduceFloatingPointJitterPassFlags();
// Has the side-effect of parsing TF_XLA_FLAGS if that hasn't happened yet.
void AppendMarkForCompilationPassFlags(
std::vector<tensorflow::Flag>* flag_list);
// Makes all future calls to `IsXlaEnabled()` return `true`.
//
// Should only be called when XLA is linked in.
void SetXlaIsEnabled();
// Returns whether XLA is enabled.
bool IsXlaEnabled();
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_FLAGS_H_
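
A hedged sketch of how a caller might use the new query; MaybeUseXla is invented for illustration and is not part of this change:

#include "tensorflow/compiler/jit/flags.h"

void MaybeUseXla() {  // hypothetical caller
  if (!tensorflow::IsXlaEnabled()) {
    return;  // XLA kernels were not linked into this binary; fall back.
  }
  // ... XLA-specific path ...
}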

View File

@ -21,6 +21,7 @@ limitations under the License.
#include <unordered_map>
#include <unordered_set>
#include "absl/base/call_once.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
@ -1616,8 +1617,8 @@ StatusOr<bool> MarkForCompilationPassImpl::ShouldCompileClusterImpl(
if (!should_compile && global_jit_level_ != OptimizerOptions::OFF &&
device_type.type_string() == DEVICE_CPU) {
static std::once_flag once;
std::call_once(once, [] {
static absl::once_flag once;
absl::call_once(once, [] {
LOG(WARNING)
<< "(One-time warning): Not using XLA:CPU for cluster because envvar "
"TF_XLA_FLAGS=--tf_xla_cpu_global_jit was not set. If you want "
@ -1776,9 +1777,9 @@ absl::flat_hash_map<string, std::vector<string>>* GetWhitelistTable() {
"Lgamma", "Digamma",
// Binary
"Add", "AddV2", "Sub", "Mul", "Div", "Atan2", "Complex", "DivNoNan",
"MulNoNan", "FloorDiv", "Xlogy", "Xdivy", "FloorMod", "BitwiseAnd",
"BitwiseOr", "BitwiseXor", "LeftShift", "RightShift", "LogicalAnd",
"LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
"MulNoNan", "FloorDiv", "Xlogy", "Xlog1py", "Xdivy", "FloorMod",
"BitwiseAnd", "BitwiseOr", "BitwiseXor", "LeftShift", "RightShift",
"LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
"ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "TruncateDiv",
"TruncateMod", "Equal", "NotEqual", "Greater", "GreaterEqual",
"Less", "LessEqual", "SigmoidGrad", "SoftplusGrad", "SoftsignGrad",
@ -1872,6 +1873,8 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"Einsum",
"EmptyTensorList",
"ExtractImagePatches",
"Igamma",
"Igammac",
"FFT",
"FFT2D",
"FFT3D",

View File

@ -163,12 +163,11 @@ Status XlaCompilationCache::BuildExecutable(
build_options.set_device_allocator(options.device_allocator);
build_options.set_alias_passthrough_params(options.alias_passthrough_params);
auto compile_result =
client_->Compile(*result.computation, argument_layouts, build_options);
if (!compile_result.ok()) {
return compile_result.status();
}
*executable = std::move(compile_result.ValueOrDie());
TF_ASSIGN_OR_RETURN(
auto executables,
client_->Compile(*result.computation, argument_layouts, build_options));
TF_RET_CHECK(executables.size() == 1);
*executable = std::move(executables[0]);
return Status::OK();
}
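
For readers unfamiliar with the macro, a rough sketch of what TF_ASSIGN_OR_RETURN does here; this is a simplification (the real macro also avoids name collisions and handles temporaries), and the exact return type of Compile is assumed from the surrounding change:

// TF_ASSIGN_OR_RETURN(auto executables, client_->Compile(...));
// behaves approximately like:
auto statusor =
    client_->Compile(*result.computation, argument_layouts, build_options);
if (!statusor.ok()) return statusor.status();  // early-out on failure
auto executables = std::move(statusor).ValueOrDie();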

View File

@ -36,8 +36,13 @@ class XlaCpuDeviceFactory : public DeviceFactory {
};
Status XlaCpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
devices->push_back(absl::StrCat("/physical_device:", DEVICE_XLA_CPU, ":0"));
XlaDeviceFlags* flags = GetXlaDeviceFlags();
if (!flags->tf_xla_enable_xla_devices) {
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
return Status::OK();
}
devices->push_back(absl::StrCat("/physical_device:", DEVICE_XLA_CPU, ":0"));
return Status::OK();
}
@ -45,6 +50,10 @@ Status XlaCpuDeviceFactory::CreateDevices(
const SessionOptions& session_options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) {
XlaDeviceFlags* flags = GetXlaDeviceFlags();
if (!flags->tf_xla_enable_xla_devices) {
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
return Status::OK();
}
bool compile_on_demand = flags->tf_xla_compile_on_demand;
XlaOpRegistry::DeviceRegistration registration;
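
A hedged usage note: the flag is parsed from the TF_XLA_FLAGS environment variable (via the parse_flags_from_env dependency), so the gate above can be toggled without code changes. A sketch of the pattern, with a hypothetical binary name in the comment:

// e.g.  TF_XLA_FLAGS=--tf_xla_enable_xla_devices ./my_tf_binary
XlaDeviceFlags* device_flags = GetXlaDeviceFlags();  // parses TF_XLA_FLAGS once
if (device_flags->tf_xla_enable_xla_devices) {
  // XLA_CPU / XLA_GPU devices are registered as before.
}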

View File

@ -20,7 +20,9 @@ limitations under the License.
#include <unordered_set>
#include <utility>
#include "absl/base/call_once.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
#include "tensorflow/compiler/jit/xla_device_context.h"
@ -386,14 +388,33 @@ Status XlaDevice::TryGetDeviceContext(DeviceContext** out_context) {
return Status::OK();
}
// Warn about XLA_CPU/XLA_GPU exactly once.
static void ShowXlaDeviceDeprecationWarning(
absl::string_view compilation_device_name) {
static absl::once_flag once;
if (absl::StrContains(compilation_device_name, "CPU") ||
absl::StrContains(compilation_device_name, "GPU")) {
absl::call_once(once, [] {
LOG(WARNING)
<< "XLA_GPU and XLA_CPU devices are deprecated and will be "
"removed in subsequent releases. Instead, use either "
"@tf.function(experimental_compile=True) for must-compile "
"semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 "
"for auto-clustering best-effort compilation.";
});
}
}
void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":"
<< op_kernel->type_string();
ShowXlaDeviceDeprecationWarning(jit_device_name_.type_string());
op_kernel->Compute(context);
}
void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
AsyncOpKernel::DoneCallback done) {
ShowXlaDeviceDeprecationWarning(jit_device_name_.type_string());
VLOG(2) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":"
<< op_kernel->type_string();
op_kernel->ComputeAsync(context, done);

View File

@ -140,7 +140,6 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
// The device tensor should always be fresh.
TF_RET_CHECK(!xla_tensor->has_shaped_buffer());
xla_tensor->set_host_tensor(*cpu_tensor);
TF_RETURN_IF_ERROR(
xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
stream_->parent()->device_ordinal()));

View File

@ -14,17 +14,20 @@ limitations under the License.
==============================================================================*/
// Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
// operators using XLA via the XLA "CUDA" (GPU) backend.
// operators using XLA via the XLA "CUDA" or "ROCM" (GPU) backend.
#include <set>
#include "absl/memory/memory.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
@ -61,7 +64,14 @@ class XlaGpuDeviceFactory : public DeviceFactory {
};
Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
auto platform = se::MultiPlatformManager::PlatformWithName("CUDA");
XlaDeviceFlags* flags = GetXlaDeviceFlags();
if (!flags->tf_xla_enable_xla_devices) {
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
return Status::OK();
}
auto platform =
se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
if (!platform.ok()) {
// Treat failures as non-fatal; there might not be a GPU in the machine.
VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
@ -84,6 +94,12 @@ Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
Status XlaGpuDeviceFactory::CreateDevices(
const SessionOptions& session_options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) {
XlaDeviceFlags* flags = GetXlaDeviceFlags();
if (!flags->tf_xla_enable_xla_devices) {
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
return Status::OK();
}
XlaOpRegistry::DeviceRegistration registration;
registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
registration.autoclustering_policy =
@ -103,7 +119,8 @@ Status XlaGpuDeviceFactory::CreateDevices(
RegisterXlaDeviceKernels(DEVICE_XLA_GPU, DEVICE_GPU_XLA_JIT);
(void)registrations;
auto platform = se::MultiPlatformManager::PlatformWithName("CUDA");
auto platform =
se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
if (!platform.ok()) {
// Treat failures as non-fatal; there might not be a GPU in the machine.
VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_kernel_creator.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_kernel_creator_util.h"
#include "tensorflow/core/common_runtime/function.h"
@ -39,6 +40,10 @@ bool RegisterLaunchOpCreator() {
}
static bool register_me = RegisterLaunchOpCreator();
static bool register_xla = [] {
SetXlaIsEnabled();
return true;
}();
} // end namespace
} // namespace tensorflow

View File

@ -222,8 +222,9 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
OpKernelConstruction construction(
DeviceType(dev->device_type()), dev,
dev->GetAllocator(AllocatorAttributes()), &node_def,
&fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types,
fbody->ret_types, output_memory_types, flr->graph_def_version(), &s);
&fbody->fdef.signature(), flr, dev->resource_manager(), fbody->arg_types,
input_memory_types, fbody->ret_types, output_memory_types,
flr->graph_def_version(), &s);
*kernel = absl::make_unique<XlaLocalLaunchBase>(
&construction, constant_arg_indices, resource_arg_indices, function);

View File

@ -66,6 +66,8 @@ cc_library(
"//tensorflow/compiler/mlir/lite:tensorflow_lite_optimize",
"//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
"//tensorflow/compiler/mlir/lite/quantization:quantization_passes",
"//tensorflow/compiler/mlir/lite/quantization/tensorflow:tf_to_quant",
"//tensorflow/compiler/mlir/lite/quantization/xla:hlo_xla_quantization_passes",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
@ -77,10 +79,10 @@ cc_library(
"//tensorflow/compiler/mlir/xla:lhlo_fuse_linalg",
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine",
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_gpu",
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_linalg",
"//tensorflow/compiler/mlir/xla:xla_dialect_registration",
"//tensorflow/compiler/mlir/xla:xla_legalize_control_flow",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
"//tensorflow/compiler/mlir/xla:xla_legalize_to_linalg",
"//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
"//tensorflow/compiler/mlir/xla:xla_lower",
"//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts",

View File

@ -26,9 +26,11 @@ package_group(
filegroup(
name = "tensorflow_lite_ops_td_files",
srcs = [
"ir/tfl_op_interfaces.td",
"ir/tfl_ops.td",
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:include/mlir/Transforms/LoopLikeInterface.td",
],
)
@ -55,6 +57,25 @@ gentbl(
],
)
gentbl(
name = "tensorflow_lite_op_interfaces_inc_gen",
tbl_outs = [
(
"-gen-op-interface-decls",
"ir/tfl_ops_interface.h.inc",
),
(
"-gen-op-interface-defs",
"ir/tfl_ops_interface.cc.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "ir/tfl_op_interfaces.td",
td_srcs = [
":tensorflow_lite_ops_td_files",
],
)
gentbl(
name = "tensorflow_lite_prepare_tf_inc_gen",
tbl_outs = [
@ -177,11 +198,12 @@ cc_library(
"ir/tfl_ops.cc",
"ir/tfl_ops.cc.inc",
"ir/tfl_ops.h.inc",
"ir/tfl_ops_interface.cc.inc",
"ir/tfl_ops_interface.h.inc",
"utils/attribute_utils.cc",
],
hdrs = [
"ir/tfl_ops.h",
"ir/tfl_traits.h",
"transforms/passes.h",
"utils/attribute_utils.h",
"//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h",
@ -190,8 +212,6 @@ cc_library(
deps = [
":tensorflow_lite_ops_inc_gen",
":validators",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/lite/schema:schema_fbs",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:Dialect",
@ -200,6 +220,10 @@ cc_library(
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
# TODO(jpienaar): Move this out after splitting out LoopLikeOpInterface.
"@llvm-project//mlir:Transforms",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/lite/schema:schema_fbs",
],
alwayslink = 1,
)
@ -258,6 +282,7 @@ tf_cc_test(
cc_library(
name = "tensorflow_lite_legalize_tf",
srcs = [
"transforms/dilated_conv.cc",
"transforms/extract_ophint.cc",
"transforms/generated_legalize_tf.inc",
"transforms/generated_lower_static_tensor_list.inc",
@ -273,6 +298,7 @@ cc_library(
"transforms/unroll_batch_matmul.cc",
],
hdrs = [
"transforms/dilated_conv.h",
"transforms/passes.h",
"transforms/unroll_batch_matmul.h",
],
@ -284,13 +310,16 @@ cc_library(
":validators",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_tensor",
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:tensor_list",
"//tensorflow/core/platform:logging",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
@ -316,6 +345,7 @@ cc_library(
deps = [
":tensorflow_lite",
":validators",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
@ -347,6 +377,7 @@ cc_library(
":validators",
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/lite/quantization/lite:tfl_to_std",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
@ -371,6 +402,8 @@ genrule(
name = "op_quant_spec_getters_inc",
srcs = [
"ir/tfl_ops.td",
"ir/tfl_op_interfaces.td",
"@llvm-project//mlir:include/mlir/Transforms/LoopLikeInterface.td",
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
],
outs = [
@ -673,12 +706,16 @@ cc_library(
],
)
exports_files(
["transforms/passes.h"],
cc_library(
name = "empty_passes",
hdrs = ["transforms/passes.h"],
visibility = [
"//configs/devtools/hawkeye/tflite:__subpackages__",
"//learning/brain/models/app_benchmarks:__subpackages__",
"//tensorflow/compiler/mlir/lite:friends",
"//tensorflow/lite/experimental/mlir:__subpackages__",
],
deps = [
"@llvm-project//llvm:support",
],
)

View File

@ -31,10 +31,11 @@ struct PassConfig {
: emit_builtin_tflite_ops(true),
lower_tensor_list_ops(false),
trim_functions_whitelist({}),
quant_specs(specs),
quant_specs(std::move(specs)),
skip_control_dialect(false),
form_clusters(false),
inline_functions(false) {}
inline_functions(true),
unfold_batch_matmul(true) {}
// If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
// added, which produces TF Lite ops.
@ -57,6 +58,9 @@ struct PassConfig {
// Inline function calls within the main function in the MLIR module, prior
// to legalization to TFLite.
bool inline_functions;
// If `unfold_batch_matmul` is true, tf.BatchMatMul ops are unfolded into a
// set of tfl.fully_connected ops.
bool unfold_batch_matmul;
};
} // namespace TFL
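
A minimal usage sketch, under the assumption that PassConfig and QuantizationSpecs live in mlir::TFL and that QuantizationSpecs default-constructs to "no quantization":

mlir::TFL::QuantizationSpecs specs;       // assumed default: no quantization
mlir::TFL::PassConfig pass_config(specs);
// Opt out of the new default if tf.BatchMatMul should be kept intact
// instead of being unfolded into tfl.fully_connected ops:
pass_config.unfold_batch_matmul = false;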

View File

@ -389,7 +389,6 @@ StatusOr<mlir::ElementsAttr> ConvertIntBuffer(
mlir::RankedTensorType shaped_type, mlir::Type elem_type,
const std::vector<uint8_t>& buffer) {
unsigned bit_width;
mlir::RankedTensorType buffer_type;
if (auto itype = elem_type.dyn_cast<mlir::IntegerType>()) {
bit_width = itype.getWidth();
} else if (auto qtype = elem_type.dyn_cast<QuantizedType>()) {
@ -920,15 +919,13 @@ StatusOr<FuncOp> ConvertSubgraph(
// represents TFLite, this entry point must be called "main"
// TODO(b/131175224,b/132239787) Support multiple entry points
std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) {
if (subgraph.name.empty()) {
if (index == 0) {
return "main";
} else {
return llvm::formatv("fn_{0}", index).str();
}
} else {
return subgraph.name;
if (index == 0) {
return "main";
}
if (subgraph.name.empty()) {
return llvm::formatv("fn_{0}", index).str();
}
return subgraph.name;
}
} // namespace

View File

@ -90,6 +90,7 @@ using mlir::MLIRContext;
using mlir::ModuleOp;
using mlir::NoneType;
using mlir::Operation;
using mlir::Region;
using mlir::StringAttr;
using mlir::TensorType;
using mlir::TranslateFromMLIRRegistration;
@ -309,7 +310,7 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) {
return true;
}
static std::unique_ptr<::tensorflow::NodeDef> getTensorFlowNodeDef(
static std::unique_ptr<::tensorflow::NodeDef> GetTensorFlowNodeDef(
::mlir::Operation* inst) {
// We pass an empty string for the original node_def name since the Flex runtime
// does not care about this being set correctly on node_def. There is no
@ -425,6 +426,11 @@ class Translator {
mlir::TF::WhileOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
// Build while operator where cond & body are regions.
BufferOffset<tflite::Operator> BuildWhileOperator(
mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
// Builds custom operators.
// Templated on a) data type of custom_option to be stored into flatbuffer,
// and b) TFL custom op type.
@ -472,7 +478,10 @@ class Translator {
Operation* inst, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
Optional<BufferOffset<tflite::SubGraph>> BuildSubGraph(FuncOp fn);
// Build a subgraph with the given name out of the region corresponding
// either to a function's body or to a while op.
Optional<BufferOffset<tflite::SubGraph>> BuildSubGraph(
const std::string& name, Region* region);
// Builds Metadata with the given `name` and buffer `content`.
BufferOffset<tflite::Metadata> BuildMetadata(StringRef name,
@ -494,6 +503,12 @@ class Translator {
// Returns a unique name for `val`.
std::string UniqueName(mlir::Value val);
// Returns the names of the subgraphs corresponding to the regions of the op.
// The names are expected to be unique, since the op name is unique and the
// suffix is not a valid name.
std::string GetWhileBodyName(mlir::TFL::WhileOp while_op);
std::string GetWhileCondName(mlir::TFL::WhileOp while_op);
ModuleOp module_;
tensorflow::OpOrArgNameMapper& name_mapper_;
@ -523,7 +538,7 @@ class Translator {
};
std::string Translator::UniqueName(mlir::Value val) {
return name_mapper_.GetUniqueName(val);
return std::string(name_mapper_.GetUniqueName(val));
}
Optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
@ -595,6 +610,7 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
};
std::vector<int32_t> shape;
std::vector<int32_t> shape_signature;
if (type.hasStaticShape()) {
llvm::ArrayRef<int64_t> shape_ref = type.getShape();
if (mlir::failed(check_shape(shape_ref))) return llvm::None;
@ -612,7 +628,17 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
}
} else if (type.hasRank()) {
llvm::ArrayRef<int64_t> shape_ref = type.getShape();
if (mlir::failed(check_shape(shape_ref))) return llvm::None;
shape.reserve(shape_ref.size());
for (auto& dim : shape_ref) {
shape.push_back(dim == -1 ? 1 : dim);
}
shape_signature = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
}
Type element_type = type.getElementType();
tflite::TensorType tflite_element_type =
GetTFLiteType(type.getElementType()).ValueOrDie();
@ -649,10 +675,19 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
break;
}
}
return tflite::CreateTensor(
builder_, builder_.CreateVector(shape), tflite_element_type,
(is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
/*is_variable=*/is_variable);
if (shape_signature.empty()) {
return tflite::CreateTensor(
builder_, builder_.CreateVector(shape), tflite_element_type,
(is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
/*is_variable=*/is_variable);
} else {
return tflite::CreateTensor(
builder_, builder_.CreateVector(shape), tflite_element_type,
(is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
/*is_variable=*/is_variable, /*sparsity=*/0,
/*shape_signature=*/builder_.CreateVector(shape_signature));
}
}
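
A worked example of the new shape_signature handling above, assuming a ranked type with one dynamic dimension:

// For a ranked type tensor<?x4xf32>:
//   shape_ref       = [-1, 4]
//   shape           = [ 1, 4]   // dynamic dims are materialized as 1
//   shape_signature = [-1, 4]   // original signature preserved for runtime
// For a fully static tensor<2x4xf32>, shape_signature stays empty, so the
// shorter CreateTensor call without a signature is emitted.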
BufferOffset<tflite::Operator> Translator::BuildIfOperator(
@ -687,6 +722,30 @@ BufferOffset<tflite::Operator> Translator::BuildWhileOperator(
builtin_options);
}
std::string Translator::GetWhileBodyName(mlir::TFL::WhileOp while_op) {
return (name_mapper_.GetUniqueName(while_op.getOperation()) + "$body").str();
}
std::string Translator::GetWhileCondName(mlir::TFL::WhileOp while_op) {
return (name_mapper_.GetUniqueName(while_op.getOperation()) + "$cond").str();
}
BufferOffset<tflite::Operator> Translator::BuildWhileOperator(
mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results) {
auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE);
int body_subgraph_index = subgraph_index_map_.at(GetWhileBodyName(op));
int cond_subgraph_index = subgraph_index_map_.at(GetWhileCondName(op));
auto builtin_options = tflite::CreateWhileOptions(
builder_, cond_subgraph_index, body_subgraph_index)
.Union();
auto inputs = builder_.CreateVector(operands);
auto outputs = builder_.CreateVector(results);
return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
tflite::BuiltinOptions_WhileOptions,
builtin_options);
}
template <typename CustomOptionType, typename TFLOp>
BufferOffset<tflite::Operator> Translator::BuildCustomOperator(
const CustomOptionType& custom_option, const std::string& opcode_name,
@ -908,6 +967,16 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
return BuildMaxUnpooling2DOperator(inst, max_unpooling_op, operands,
results);
}
if (auto whileOp = dyn_cast<mlir::TFL::WhileOp>(inst)) {
if (inst->getNumOperands() != inst->getNumResults()) {
inst->emitOpError(
"number of operands and results don't match, only canonical "
"TFL While supported");
return llvm::None;
}
return BuildWhileOperator(whileOp, operands, results);
}
inst->emitOpError("is not a supported TFLite op");
return llvm::None;
}
@ -944,7 +1013,7 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
// we emit op as flex.
// if custom is enabled
// we emit the op as custom.
auto node_def = getTensorFlowNodeDef(inst);
auto node_def = GetTensorFlowNodeDef(inst);
if (!node_def) {
return llvm::None;
}
@ -1043,18 +1112,16 @@ void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) {
bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) {
std::vector<int> operand_indices;
// TODO(b/138254427): When the bug is addressed, we'll be able to inspect
// for the presence of a specific OpTrait using mlir::Operation, without
// having to cast it to specific ops like below.
// Until then, when a new RNN/LSTM op is added to TFLite and has stateful
// tensors as operands, they will need to be added here as well.
if (!mlir::TFL::IsStatefulOp(op, &operand_indices)) return false;
return absl::c_find(operand_indices, operand_index) != operand_indices.end();
}
Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
const std::string& name, Region* region) {
bool has_input_attr = false;
InitializeNamesFromAttribute(fn, &has_input_attr);
if (auto fn = dyn_cast<FuncOp>(region->getParentOp())) {
InitializeNamesFromAttribute(fn, &has_input_attr);
}
std::vector<BufferOffset<tflite::Tensor>> tensors;
llvm::DenseMap<Value, int> tensor_index_map;
@ -1086,7 +1153,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
};
std::vector<BufferOffset<tflite::Operator>> operators;
auto& bb = fn.getBlocks().front();
auto& bb = region->front();
// Main function's arguments are first passed to `input` op so they don't
// have associated tensor and buffer. Build FlatBuffer tensor and buffer for
@ -1094,7 +1161,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
for (unsigned i = 0, e = bb.getNumArguments(); i < e; ++i) {
mlir::BlockArgument arg = bb.getArgument(i);
std::string name;
if (has_input_attr) name = name_mapper_.GetUniqueName(arg);
if (has_input_attr) name = std::string(name_mapper_.GetUniqueName(arg));
if (name.empty()) name = absl::StrCat("arg", i);
if (!build_tensor_and_buffer(arg, name)) return llvm::None;
}
@ -1146,7 +1213,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
return tflite::CreateSubGraph(
builder_, builder_.CreateVector(tensors), builder_.CreateVector(inputs),
builder_.CreateVector(outputs), builder_.CreateVector(operators),
/*name=*/builder_.CreateString(fn.getName().str()));
/*name=*/builder_.CreateString(name));
}
BufferOffset<tflite::Metadata> Translator::BuildMetadata(StringRef name,
@ -1189,35 +1256,45 @@ Optional<std::string> Translator::Translate(
}
Optional<std::string> Translator::TranslateInternal() {
// Create a list of functions in the module with main function being the
// first function in the list. This is required as the first subgraph in the
// model is the entry point for the model.
std::vector<FuncOp> functions;
functions.reserve(std::distance(module_.begin(), module_.end()));
// A list of named regions in the module, with the main function being the
// first in the list. The main function must come first, as the first
// subgraph in the model is the entry point for the model.
std::vector<std::pair<std::string, Region*>> named_regions;
named_regions.reserve(std::distance(module_.begin(), module_.end()));
int subgraph_idx = 0;
FuncOp main_fn = module_.lookupSymbol<FuncOp>("main");
subgraph_index_map_[main_fn.getName().str()] = subgraph_idx++;
functions.push_back(main_fn);
for (auto fn : module_.getOps<FuncOp>()) {
if (fn == main_fn) continue;
named_regions.emplace_back("main", &main_fn.getBody());
// Walk over the module, collecting functions and while ops.
module_.walk([&](Operation* op) {
if (auto fn = dyn_cast<FuncOp>(op)) {
if (fn != main_fn) {
subgraph_index_map_[fn.getName().str()] = subgraph_idx++;
named_regions.emplace_back(fn.getName().str(), &fn.getBody());
}
} else if (auto wo = dyn_cast<mlir::TFL::WhileOp>(op)) {
std::string name = GetWhileCondName(wo);
subgraph_index_map_[name] = subgraph_idx++;
named_regions.emplace_back(GetWhileCondName(wo), &wo.cond());
name = GetWhileBodyName(wo);
subgraph_index_map_[name] = subgraph_idx++;
named_regions.emplace_back(name, &wo.body());
}
});
subgraph_index_map_[fn.getName().str()] = subgraph_idx++;
functions.push_back(fn);
}
// Build subgraph for each of the functions.
// Build subgraph for each of the named regions.
std::vector<BufferOffset<tflite::SubGraph>> subgraphs;
subgraphs.reserve(functions.size());
subgraphs.reserve(named_regions.size());
int first_failed_func = -1;
for (int i = 0; i < functions.size(); ++i) {
auto subgraph_or = BuildSubGraph(functions[i]);
for (auto it : llvm::enumerate(named_regions)) {
auto subgraph_or = BuildSubGraph(it.value().first, it.value().second);
if (!subgraph_or) {
if (first_failed_func == -1)
// Record the index of the first function that cannot be converted.
// Record the index of the first region that cannot be converted.
// Keep looping through all subgraphs in the module to make sure that
// we collect the list of missing ops from the entire module.
first_failed_func = i;
first_failed_func = it.index();
} else {
subgraphs.push_back(*subgraph_or);
}
@ -1238,9 +1315,10 @@ Optional<std::string> Translator::TranslateInternal() {
"-emit-custom-ops flag): " +
failed_custom_ops_list;
return functions[first_failed_func].emitError("failed while converting: '")
<< functions[first_failed_func].getName() << "\'\n"
<< err,
auto& failed_region = named_regions[first_failed_func];
return failed_region.second->getParentOp()->emitError()
<< "failed while converting: '" << failed_region.first
<< "': " << err,
llvm::None;
}

View File

@ -0,0 +1,58 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the operation interface definition file for TensorFlow Lite.
#ifndef TFL_OP_INTERFACES
#define TFL_OP_INTERFACES
include "mlir/IR/OpBase.td"
//===----------------------------------------------------------------------===//
// TFL op interface for stateful operands.
def TFL_StatefulOp : OpInterface<"StatefulOpInterface"> {
let description = [{
Interface for ops that are stateful and need to identify stateful operands.
Stateful operands correspond to TF's variable semantics. An op that has one
or more stateful operands is a stateful op.
}];
let methods = [
InterfaceMethod<
[{Returns the indices of stateful operands.}],
"std::vector<int>", "GetStatefulOperands", (ins)
>,
];
}
//===----------------------------------------------------------------------===//
// TFL op interface for output channel index.
def TFL_ChannelDimIndexInterface : OpInterface<"ChannelDimIndexInterface"> {
let description = [{
Interface for defining the index of the output channel dimension.
}];
let methods = [
InterfaceMethod<
[{Returns the dimension index of the output channels.}],
"int", "GetChannelDimIndex", (ins)
>,
];
}
#endif // TFL_OP_INTERFACES
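
A hedged C++ sketch of querying the generated interfaces from a pass. The class names follow the usual tablegen output (mlir::TFL::StatefulOpInterface, mlir::TFL::ChannelDimIndexInterface) but are assumptions here, as is the helper InspectOp:

#include "llvm/Support/Casting.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

void InspectOp(mlir::Operation* op) {  // hypothetical helper
  if (auto stateful = llvm::dyn_cast<mlir::TFL::StatefulOpInterface>(op)) {
    // Indices of operands with TF-variable semantics.
    for (int idx : stateful.GetStatefulOperands()) (void)idx;
  }
  if (auto channeled =
          llvm::dyn_cast<mlir::TFL::ChannelDimIndexInterface>(op)) {
    int dim = channeled.GetChannelDimIndex();  // output-channel dimension
    (void)dim;
  }
}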

View File

@ -797,8 +797,7 @@ struct RemoveAdjacentReshape : public RewritePattern {
// With
// %2 = "tfl.reshape"(%0, %shape1)
rewriter.replaceOpWithNewOp<ReshapeOp>(
{prevOp.getResult()}, op, thisOp.getType(), prevOp.getOperand(0),
thisOp.getOperand(1));
op, thisOp.getType(), prevOp.getOperand(0), thisOp.getOperand(1));
}
};
@ -1302,6 +1301,19 @@ OpFoldResult AbsOp::fold(ArrayRef<Attribute> operands) {
return ConstFoldUnaryOp(result_type, operands[0], compute);
}
//===----------------------------------------------------------------------===//
// NegOp
//===----------------------------------------------------------------------===//
OpFoldResult NegOp::fold(ArrayRef<Attribute> operands) {
Type result_type = getType();
// Only constant folding for f32 tensors is implemented.
if (!IsF32ShapedType(result_type)) return nullptr;
auto compute = [](APFloat value) -> APFloat { return llvm::neg(value); };
return ConstFoldUnaryOp(result_type, operands[0], compute);
}
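
A worked example of the fold, hedged since the exact textual form of the constant depends on the importer:

// Given a constant f32 operand, e.g. dense<[1.0, -2.0]> : tensor<2xf32>,
// folding "tfl.neg" yields dense<[-1.0, 2.0]> : tensor<2xf32>.
// Non-f32 element types fail the IsF32ShapedType() guard, fold returns
// nullptr, and the op is left in place.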
//===----------------------------------------------------------------------===//
// SinOp
//===----------------------------------------------------------------------===//
@ -1724,10 +1736,97 @@ static LogicalResult Verify(TransposeOp op) {
return success();
}
namespace {
struct WhileResultOperandsMatch : public OpRewritePattern<WhileOp> {
using OpRewritePattern<WhileOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(WhileOp while_op,
PatternRewriter &rewriter) const override {
auto size = while_op.body().front().getArguments().size();
Operation *op = while_op.getOperation();
auto old_size = op->getNumResults();
// No change needed as the number of operands matches the number of results.
if (size == old_size) return matchFailure();
// Collect the new types by combining results of old op with additional
// operand results.
llvm::SmallVector<Type, 4> types;
types.reserve(size);
for (auto type : while_op.getResultTypes()) types.push_back(type);
for (auto arg : while_op.body().front().getArguments().drop_front(old_size))
types.push_back(arg.getType());
// Collect operands.
llvm::SmallVector<Value, 8> operands;
operands.reserve(while_op.getNumOperands());
for (auto operand : while_op.getOperands()) operands.push_back(operand);
// Replace with new While with matching operands and results.
Operation *new_op = rewriter.insert(
Operation::create(op->getLoc(), op->getName(), types, operands,
op->getAttrs(), {}, /*numRegions=*/2,
/*resizableOperandList=*/true));
for (int i = 0; i < 2; ++i) new_op->getRegion(i).takeBody(op->getRegion(i));
rewriter.replaceOp(op,
new_op->getResults().take_front(op->getNumResults()));
return matchSuccess();
}
};
} // namespace
void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<WhileResultOperandsMatch>(context);
}
Region &WhileOp::getLoopBody() { return body(); }
bool WhileOp::isDefinedOutsideOfLoop(Value value) {
// TODO(jpienaar): This is overly conservative and initially disables anything
// other than constant hoisting.
return false;
}
LogicalResult WhileOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) {
// TODO(jpienaar): This can be removed post the resizable trait is added.
Operation &while_op = *this->getOperation();
if (!while_op.hasResizableOperandsList()) return failure();
if (ops.empty()) return success();
// Operands to the while op.
llvm::SmallVector<Value, 4> operands(getOperands());
// Results that have to be returned by the body.
llvm::SmallVector<Value, 4> results(
body().front().getTerminator()->getOperands());
for (auto op : ops) {
// Move the hoisted value to just before the while.
op->moveBefore(&while_op);
// Each result of the hoisted op becomes an input to while, cond and body.
for (auto result : op->getResults()) {
operands.push_back(result);
auto type = result.getType();
auto arg = body().front().addArgument(type);
// Loop invariant value passes through the body function unchanged.
result.replaceAllUsesWith(arg);
results.push_back(arg);
// Operand types match for body and cond. The value is hoisted out of the
// body and so not necessarily used in cond. This could be expanded to
// consider common usage across cond and body.
cond().front().addArgument(type);
}
}
body().front().getTerminator()->setOperands(results);
while_op.setOperands(operands);
return success();
}
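
A hedged before/after summary of the hoisting performed above, for a single loop-invariant op %c (pseudo-IR in comments, not exact syntax):

// before: %r   = tfl.while(%a) { body: %c = const; use(%c) } { cond: ... }
// after : %c   = const                  // moved to just before the while
//         %r:2 = tfl.while(%a, %c) {
//           body(%a', %c'): use(%c'); yield(..., %c')  // passed through
//         } { cond(%a'', %c''): ... }   // gains a matching, unused argument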
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.cc.inc"
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"

View File

@ -27,7 +27,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
#include "mlir/Transforms/LoopLikeInterface.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/lite/schema/schema_generated.h"
@ -44,6 +44,7 @@ class TensorFlowLiteDialect : public Dialect {
Location loc) override;
};
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.h.inc"
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc"

View File

@ -19,6 +19,8 @@ limitations under the License.
#define TFL_OPS
include "mlir/IR/OpBase.td"
include "mlir/Transforms/LoopLikeInterface.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td"
include "tensorflow/compiler/mlir/lite/quantization/quantization.td"
def TFL_Dialect : Dialect {
@ -248,16 +250,6 @@ def TFL_ComparisonBinaryBuilder : OpBuilder<
buildComparisonBinOp(builder, result, lhs, rhs);
}]>;
//===----------------------------------------------------------------------===//
// TFL native op trait for stateful operands and channel indices.
class StatefulOperands<list<int> operands>
: ParamNativeOpTrait<"TFL::StatefulOperands", StrJoinInt<operands>.result>;
class ChannelDimIndex<int index>
: ParamNativeOpTrait<"TFL::ChannelDimIndex", !cast<string>(index)>;
//===----------------------------------------------------------------------===//
// TFL op base class.
//===----------------------------------------------------------------------===//
@ -285,7 +277,7 @@ class TFL_Op<string mnemonic, list<OpTrait> traits = []> :
class TFL_ConvOp<string mnemonic, string opSummary, int index> :
TFL_Op<mnemonic, [NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
ChannelDimIndex<index>, AffineOpCoefficient<index, 1>]> {
TFL_ChannelDimIndexInterface, AffineOpCoefficient<index, 1>]> {
let summary = opSummary # " operator";
let description = [{
@ -335,7 +327,7 @@ an output element, this operation computes \\(y = |x|\\).
let hasFolder = 1;
}
def TFL_AddOp : TFL_Op<"add", [Broadcastable, NoSideEffect, Commutative]> {
def TFL_AddOp : TFL_Op<"add", [ResultsBroadcastableShape, NoSideEffect, Commutative]> {
let summary = "Addition operator";
let description = [{
@ -486,8 +478,7 @@ def TFL_ArgMaxOp : TFL_Op<"arg_max", [NoSideEffect]> {
}];
let arguments = (
// TODO: Add support for uint8.
ins TensorOf<[F32, I32, I8]>:$input,
ins TensorOf<[F32, I32, I8, TFL_Uint8, QI8, QUI8]>:$input,
TFL_I32OrI64Tensor:$dim
);
@ -515,8 +506,7 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> {
}];
let arguments = (
// TODO(pkanwar): Add support for uint8.
ins TensorOf<[F32, I32, I8]>:$input,
ins TensorOf<[F32, I32, I8, TFL_Uint8, QI8, QUI8]>:$input,
TFL_I32OrI64Tensor:$dim
);
@ -617,7 +607,12 @@ def TFL_ExternalConstOp : Op<TFL_Dialect, "external_const", [NoSideEffect]> {
let results = (outs AnyTensor:$output);
}
def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0>;
def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0> {
let extraClassDeclaration = [{
// ChannelDimIndexInterface:
int GetChannelDimIndex() { return 0; }
}];
}
def TFL_CosOp: TFL_Op<"cos", [
NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
@ -637,6 +632,11 @@ def TFL_CosOp: TFL_Op<"cos", [
def TFL_DepthwiseConv2DOp :
TFL_ConvOp<"depthwise_conv_2d", "Depthwise-separable convolution", 3> {
let arguments = !con(TFL_Conv2DOp.arguments, (ins I32Attr:$depth_multiplier));
let extraClassDeclaration = [{
// ChannelDimIndexInterface:
int GetChannelDimIndex() { return 3; }
}];
}
def TFL_FCWO_Default : StrEnumAttrCase<"DEFAULT">;
@ -650,7 +650,8 @@ def TFL_FullyConnectedOptionsWeightFormatAttr :
// TODO(jpienaar): Update post discussion on semantics of FC OP.
def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
NoSideEffect, AccumulatorUniformScale<2, 0, 1>, ChannelDimIndex<0>,
NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
TFL_ChannelDimIndexInterface,
AffineOpCoefficient<-1, 1>]> {
let summary = "Fully connected op";
@ -672,6 +673,11 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
let verifier = [{ return Verify(*this); }];
let hasOptions = 1;
let extraClassDeclaration = [{
// ChannelDimIndexInterface:
int GetChannelDimIndex() { return 0; }
}];
}
def TFL_GatherOp : TFL_Op<"gather", [
@ -679,7 +685,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
SameOperandsAndResultsScale,
TFL_OperandHasAtleastRank<0, 1>,
PredOpTrait<"params and output must have same element type",
TCresVTEtIsSameAsOp<0, 0>>
TFL_TCresVTEtIsSameAsOp<0, 0>>
]> {
let summary = "Gather operator";
@ -688,7 +694,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
}];
let arguments = (ins
TensorOf<[F32, I1, I8, I32, I64, TFL_Str, QI8, QUI8]>:$params,
TensorOf<[F32, I1, I8, I32, I64, TFL_Str, QI8, QUI8, QI16]>:$params,
TensorOf<[I32, I64]>:$indices,
I32Attr:$axis
);
@ -701,7 +707,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
];
let results = (outs
TensorOf<[F32, I1, I8, I32, I64, TFL_Str, QI8, QUI8]>:$output
TensorOf<[F32, I1, I8, I32, I64, TFL_Str, QI8, QUI8, QI16]>:$output
);
let hasOptions = 1;
@ -724,9 +730,9 @@ def TFL_GatherNdOp : TFL_Op<"gather_nd", [NoSideEffect]> {
);
}
// Same type check of lhs and rhs is handled by the Broadcastable trait.
// Same type check of lhs and rhs is handled by the ResultsBroadcastableShape trait.
def TFL_LessEqualOp : TFL_Op<"less_equal", [
Broadcastable, NoSideEffect, NoQuantizableResult]> {
ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
let summary = "Less_equal operator";
let description = [{
@ -782,7 +788,7 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag
}
def TFL_GreaterEqualOp : TFL_Op<"greater_equal", [
Broadcastable, NoSideEffect, NoQuantizableResult]> {
ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
let summary = "Greater_equal operator";
let description = [{
@ -943,7 +949,7 @@ larger than 0.
}
def TFL_NotEqualOp : TFL_Op<"not_equal", [
Broadcastable, Commutative, NoSideEffect, NoQuantizableResult]> {
ResultsBroadcastableShape, Commutative, NoSideEffect, NoQuantizableResult]> {
let summary = "Not_equal operator";
let description = [{
@ -970,7 +976,7 @@ def TFL_NotEqualOp : TFL_Op<"not_equal", [
let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
}
def TFL_DivOp : TFL_Op<"div", [Broadcastable, NoSideEffect]> {
def TFL_DivOp : TFL_Op<"div", [ResultsBroadcastableShape, NoSideEffect]> {
let summary = "Division operator";
let description = [{
@ -1029,7 +1035,7 @@ def TFL_EmbeddingLookupOp: TFL_Op<"embedding_lookup",
let results = (outs TensorOf<[F32, I8, TFL_Uint8]>:$output);
}
def TFL_EqualOp: TFL_Op<"equal", [Commutative, Broadcastable,
def TFL_EqualOp: TFL_Op<"equal", [Commutative, ResultsBroadcastableShape,
NoQuantizableResult,
PredOpTrait<"Operands have same value type", TCopVTEtIsSameAs<0, 1>>]> {
let summary = "Equal operator";
@ -1063,7 +1069,8 @@ def TFL_ExpOp: TFL_Op<"exp", [NoSideEffect, SameOperandsAndResultType]> {
let hasOptions = 0b1;
}
def TFL_ExpandDimsOp: TFL_Op<"expand_dims", [NoSideEffect]> {
def TFL_ExpandDimsOp: TFL_Op<"expand_dims", [
NoSideEffect, SameOperandsAndResultsScale]> {
let summary = "Inserts a dimension of 1 into a tensor's shape.";
let description = [{
@ -1173,7 +1180,7 @@ def TFL_FloorOp: TFL_Op<"floor", [NoSideEffect, SameOperandsAndResultType]> {
}
def TFL_FloorDivOp : TFL_Op<"floor_div", [
Broadcastable, NoSideEffect, BinaryOpSameElementTypeConstraint]> {
ResultsBroadcastableShape, NoSideEffect, BinaryOpSameElementTypeConstraint]> {
let summary = "Floor div operator";
let description = [{
@ -1192,7 +1199,7 @@ def TFL_FloorDivOp : TFL_Op<"floor_div", [
let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
}
def TFL_FloorModOp : TFL_Op<"floor_mod", [Broadcastable, NoSideEffect]> {
def TFL_FloorModOp : TFL_Op<"floor_mod", [ResultsBroadcastableShape, NoSideEffect]> {
let summary = "Division reminder";
let description = [{
@ -1209,7 +1216,7 @@ def TFL_FloorModOp : TFL_Op<"floor_mod", [Broadcastable, NoSideEffect]> {
}
def TFL_GreaterOp : TFL_Op<"greater", [
Broadcastable, NoSideEffect, NoQuantizableResult]> {
ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
let summary = "Greater operator";
let description = [{
@ -1291,7 +1298,7 @@ def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [NoSideEffect, SameOperandsAndResultTy
}
def TFL_LessOp : TFL_Op<"less", [
Broadcastable, NoSideEffect, NoQuantizableResult]> {
ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
let summary = "Less operator";
let description = [{
@ -1516,7 +1523,7 @@ def TFL_MaxUnpooling2DOp :
}
def TFL_MaximumOp : TFL_Op<"maximum", [
Broadcastable, NoSideEffect, Commutative, SameOperandsAndResultsScale,
ResultsBroadcastableShape, NoSideEffect, Commutative, SameOperandsAndResultsScale,
TFL_OperandHasRankLessThan<0, 4>, TFL_OperandHasRankLessThan<1, 4>]> {
let summary = "Max operator";
let description = [{
@ -1655,7 +1662,8 @@ def TFL_SumOp: TFL_Op<"sum", [NoSideEffect]> {
let customOption = "ReducerOptions";
}
def TFL_ReduceMinOp: TFL_Op<"reduce_min", [NoSideEffect]> {
def TFL_ReduceMinOp: TFL_Op<"reduce_min", [
NoSideEffect, SameOperandsAndResultsScale]> {
let summary = "Min-reduction operator";
let description = [{
@ -1674,7 +1682,8 @@ def TFL_ReduceMinOp: TFL_Op<"reduce_min", [NoSideEffect]> {
let customOption = "ReducerOptions";
}
def TFL_ReduceMaxOp: TFL_Op<"reduce_max", [NoSideEffect]> {
def TFL_ReduceMaxOp: TFL_Op<"reduce_max", [
NoSideEffect, SameOperandsAndResultsScale]> {
let summary = "Max-reduction operator";
let description = [{
@ -1713,7 +1722,7 @@ def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [NoSideEffect]> {
}
def TFL_MinimumOp : TFL_Op<"minimum", [
Broadcastable, NoSideEffect, Commutative, SameOperandsAndResultsScale,
ResultsBroadcastableShape, NoSideEffect, Commutative, SameOperandsAndResultsScale,
TFL_OperandHasRankLessThan<0, 4>, TFL_OperandHasRankLessThan<1, 4>]> {
let summary = "Min operator";
let description = [{
@ -1734,7 +1743,7 @@ def TFL_MinimumOp : TFL_Op<"minimum", [
let hasOptions = 0;
}
def TFL_MulOp : TFL_Op<"mul", [Broadcastable, NoSideEffect, Commutative]> {
def TFL_MulOp : TFL_Op<"mul", [ResultsBroadcastableShape, NoSideEffect, Commutative]> {
let summary = "Multiplication operator";
let description = [{
@ -1771,6 +1780,8 @@ def TFL_NegOp: TFL_Op<"neg", [NoSideEffect, SameOperandsAndResultType]> {
let results = (outs AnyTensor:$y);
let hasOptions = 0b1;
let hasFolder = 1;
}
def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> {
@ -1804,14 +1815,14 @@ def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> {
}];
let arguments = (ins
Variadic<TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8]>>:$values,
Variadic<TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8, QI16]>>:$values,
I32Attr:$values_count,
I32Attr:$axis
);
let results = (outs
TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8]>:$output
TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8, QI16]>:$output
);
let verifier = [{ return Verify(*this); }];
@ -1909,7 +1920,7 @@ def TFL_PadV2Op : TFL_Op<"padv2", [
let hasOptions = 1;
}
def TFL_PowOp : TFL_Op<"pow", [Broadcastable, NoSideEffect, NoQuantizableResult]> {
def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
let summary = "Power operator";
let description = [{
@ -2278,7 +2289,7 @@ def TFL_SquareOp: TFL_Op<"square", [
let hasFolder = 1;
}
def TFL_SubOp : TFL_Op<"sub", [Broadcastable, NoSideEffect]> {
def TFL_SubOp : TFL_Op<"sub", [ResultsBroadcastableShape, NoSideEffect]> {
let summary = "Subtraction operator";
let description = [{
@ -2306,7 +2317,7 @@ def TFL_SubOp : TFL_Op<"sub", [Broadcastable, NoSideEffect]> {
// TODO(jpienaar): Expand the kernel implementation to support all types besides
// I32 and F32.
def TFL_SquaredDifferenceOp : TFL_Op<"squared_difference", [
Broadcastable, NoSideEffect, NoQuantizableResult]> {
ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
let summary = "Squared difference operator";
let description = [{
@ -2345,9 +2356,9 @@ def TFL_TanhOp: TFL_Op<"tanh", [
let results = (outs TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$y);
}
def TFL_TileOp: TFL_Op<"tile", [NoSideEffect,
def TFL_TileOp: TFL_Op<"tile", [NoSideEffect, SameOperandsAndResultsScale,
PredOpTrait<"resultant element type needs to match first operand type",
TCresVTEtIsSameAsOp<0,0>>]> {
TFL_TCresVTEtIsSameAsOp<0,0>>]> {
let summary = "Tile operator.";
let description = [{
Constructs a tensor by tiling a given tensor.
@ -2360,10 +2371,11 @@ def TFL_TileOp: TFL_Op<"tile", [NoSideEffect,
}];
let arguments = (ins
TensorOf<[F32, I1, I32, I64, TFL_Uint8]>:$input,
TensorOf<[F32, I1, I32, I64, TFL_Uint8, QUI8]>:$input,
TFL_I32OrI64Tensor:$multiples);
let results = (outs TensorOf<[F32, I1, I32, I64, TFL_Uint8]>:$output);
let results = (outs
TensorOf<[F32, I1, I32, I64, TFL_Uint8, QUI8]>:$output);
let hasOptions = 0;
}
@ -2373,7 +2385,7 @@ def TFL_TileOp: TFL_Op<"tile", [NoSideEffect,
// TODO(jpienaar): Check that k is less or equal the internal dimension
def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>,
PredOpTrait<"result and input element type match",
TCresVTEtIsSameAsOp<0,0>>]> {
TCresVTEtIsSameAsOp<0,0>>, SameOperandsAndResultsScale]> {
let summary = "TopK operator";
let description = [{
@ -2383,11 +2395,11 @@ def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>,
}];
let arguments = (ins
TensorOf<[F32, I8, I32, I64, TFL_Uint8]>:$input,
TensorOf<[F32, I8, I32, I64, TFL_Uint8, QI8, QUI8]>:$input,
I32Tensor:$k);
let results = (outs
AnyTensor:$values,
TensorOf<[F32, I8, I32, I64, TFL_Uint8, QI8, QUI8]>:$values,
I32Tensor:$indices);
let builders = [OpBuilder<"Builder *builder, OperationState &result, "
@ -2426,7 +2438,7 @@ def TFL_TransposeOp : TFL_Op<"transpose",
let hasFolder = 1;
}
def TFL_UnpackOp : TFL_Op<"unpack", [NoSideEffect]> {
def TFL_UnpackOp : TFL_Op<"unpack", [NoSideEffect, SameOperandsAndResultsScale]> {
let summary = "Unpacks a tensor along a dimension into multiple tensors";
let description = [{
@ -2642,7 +2654,9 @@ def TFL_ResizeBilinearOp: TFL_Op<"resize_bilinear", [
// TODO(ycling): Support quantized types.
TensorOf<[F32, I32, QI8, QUI8]>:$input,
TensorOf<[I32]>:$size,
BoolAttr:$align_corners);
BoolAttr:$align_corners,
DefaultValuedAttr<BoolAttr, "false">:$half_pixel_centers
);
let results = (outs
TensorOf<[F32, QI8, QUI8]>:$output
@ -2751,12 +2765,11 @@ def TFL_CastOp : TFL_Op<"cast", [
Casts input from input type to output type.
}];
// TODO(b/135538711): Add complex types here.
let arguments = (ins
TensorOf<[F32, I1, I32, I64, TFL_Quint8, TFL_Uint8]>:$input
TensorOf<[F32, I1, I32, I64, TFL_Quint8, TFL_Uint8, Complex<F<32>>]>:$input
);
let results = (outs TensorOf<[F32, I1, I32, I64]>:$output);
let results = (outs TensorOf<[F32, I1, I32, I64, Complex<F<32>>]>:$output);
// TFLite's cast op does not utilize CastOptions, instead derives types
// from the TfLiteTensors.
@ -2856,7 +2869,9 @@ def TFL_FakeQuantOp : TFL_Op<"fake_quant", [NoSideEffect]> {
let arguments = (
ins AnyTensor:$input,
// The expected [min, max] range of values.
MinMaxAttr:$minmax,
F32Attr:$min,
F32Attr:$max,
// The bitwidth of the quantization; between 2 and 16, inclusive.
I32Attr:$num_bits,
// Quantization range starts from 0 or 1; starts from 1 if true.
@ -2865,6 +2880,8 @@ def TFL_FakeQuantOp : TFL_Op<"fake_quant", [NoSideEffect]> {
let results = (outs AnyTensor:$output);
let hasCanonicalizer = 0b1;
let hasOptions = 1;
}
def TFL_QConstOp : Op<TFL_Dialect, "pseudo_qconst", [
@ -2911,6 +2928,20 @@ def TFL_QuantizeOp: TFL_Op<"quantize", [
let results = (outs AnyTensor:$output);
}
def TFL_DensifyOp: TFL_Op<"densify", [NoSideEffect,
SameOperandsAndResultType,
NoQuantizableResult]> {
let summary = "Densify operator";
let description = [{
Converts a sparse tensor to dense format.
}];
let arguments = (ins AnyTensor:$input);
let results = (outs AnyTensor:$output);
}
//===----------------------------------------------------------------------===//
// LSTM Ops
//===----------------------------------------------------------------------===//
@ -3000,7 +3031,7 @@ def TFL_LSTMOp :
LstmOptionalPeepholeWeightConstraint,
LstmProjectionWeightBiasConstraint,
LstmResultConstraint,
StatefulOperands<[18, 19]>]> {
TFL_StatefulOp]> {
let summary = "The full lstm operator";
let description = [{
@ -3084,6 +3115,11 @@ Ba et al. “Layer Normalization”
let hasOptions = 1;
let verifier = [{ return Verify(*this); }];
let extraClassDeclaration = [{
// StatefulOpInterface:
std::vector<int> GetStatefulOperands() { return {18, 19}; }
}];
}
// UnidirectionalSequenceLstm op.
@ -3095,7 +3131,7 @@ def TFL_UnidirectionalSequenceLSTMOp :
LstmOptionalPeepholeWeightConstraint,
LstmProjectionWeightBiasConstraint,
LstmResultConstraint,
StatefulOperands<[18, 19]>]> {
TFL_StatefulOp]> {
let summary = "Unidirectional sequence lstm operator";
let description = [{
@ -3164,6 +3200,11 @@ def TFL_UnidirectionalSequenceLSTMOp :
let hasOptions = 1;
let verifier = [{ return Verify(*this); }];
let extraClassDeclaration = [{
// StatefulOpInterface:
std::vector<int> GetStatefulOperands() { return {18, 19}; }
}];
}
def RnnResultConstraint : PredOpTrait<
@ -3173,7 +3214,7 @@ def RnnResultConstraint : PredOpTrait<
// UnidirectionalSequenceRNN op.
def TFL_UnidirectionalSequenceRNNOp :
TFL_Op<"unidirectional_sequence_rnn",
[RnnResultConstraint, StatefulOperands<[4]>]> {
[RnnResultConstraint, TFL_StatefulOp]> {
let summary = "Unidirectional sequence rnn operator";
@ -3217,6 +3258,11 @@ def TFL_UnidirectionalSequenceRNNOp :
let customOption = "SequenceRNNOptions";
let verifier = [{ return Verify(*this); }];
let extraClassDeclaration = [{
// StatefulOpInterface:
std::vector<int> GetStatefulOperands() { return {4}; }
}];
}
def TFL_WhereOp : TFL_Op<"where", [NoSideEffect]> {
@ -3268,7 +3314,7 @@ def SVDFResultConstraint: PredOpTrait<
// SVDF op.
def TFL_SVDFOp :
TFL_Op<"svdf",
[SVDFResultConstraint, StatefulOperands<[4]>]> {
[SVDFResultConstraint, TFL_StatefulOp]> {
let summary = "Single value decomposition filter operator";
@ -3304,6 +3350,72 @@ def TFL_SVDFOp :
let hasOptions = 1;
let verifier = [{ return Verify(*this); }];
let extraClassDeclaration = [{
// StatefulOpInterface:
std::vector<int> GetStatefulOperands() { return {4}; }
}];
}
def TFL_SegmentSumOp: TFL_Op<"segment_sum", [NoSideEffect]> {
let summary = "SegmentSum operator";
let description = [{
Computes the sum along segments of a tensor.
}];
let arguments = (ins
TensorOf<[F32, I32]>:$data,
I32Tensor:$segment_ids
);
let results = (outs TensorOf<[F32, I32]>:$output);
}
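For reference, the semantics reduce to summing the data entries that share a segment id. A minimal 1-D sketch, assuming sorted, non-negative segment ids and matching input lengths:

#include <cstddef>
#include <vector>

// output[s] = sum of data[i] over all i with segment_ids[i] == s.
std::vector<float> SegmentSum(const std::vector<float>& data,
                              const std::vector<int>& segment_ids) {
  if (data.empty()) return {};
  std::vector<float> out(segment_ids.back() + 1, 0.0f);
  for (std::size_t i = 0; i < data.size(); ++i)
    out[segment_ids[i]] += data[i];
  return out;
}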
def TFL_YieldOp : Op<TFL_Dialect, "yield", [Terminator]> {
let summary = "Yield operation";
let description = [{
The "yield" operation represents a return operation within the conditional
and body of structured control flow (e.g., while). The operation takes
variable number of operands and produces no results. The operand number and
types must match the signature of the region that contains the operation.
}];
let arguments = (ins Variadic<AnyType>:$operands);
}
def TFL_WhileOp : Op<TFL_Dialect, "while", [
DeclareOpInterfaceMethods<LoopLikeOpInterface>,
SingleBlockImplicitTerminator<"YieldOp">,
// Marked IsolatedFromAbove to force values to flow through operands, which
// simplifies exporting to subgraphs.
IsolatedFromAbove]> {
let summary = [{While loop}];
let description = [{
output = input; while (cond(output)) { output = body(output) }
While loop where all values are passed through arguments, with no implicit
capture.
input: A list of input tensors whose types are T.
output: A list of output tensors whose types are T.
cond: A region that takes 'input' and returns a boolean scalar tensor.
body: A region that takes a list of tensors and returns another
list of tensors. Both lists have the same types.
}];
let arguments = (ins
Variadic<AnyTensor>:$input,
// Used to map the StatelessWhile and While ops defined in TensorFlow to a
// common op.
DefaultValuedAttr<BoolAttr, "false">:$is_stateless
);
let results = (outs Variadic<AnyTensor>:$output);
let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
let hasCanonicalizer = 1;
}
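The semantics in the description are the ordinary functional while loop. A self-contained sketch over an arbitrary value type (names are illustrative):

#include <functional>

// output = input; while (cond(output)) { output = body(output) }
template <typename T>
T RunWhile(T input, const std::function<bool(const T&)>& cond,
           const std::function<T(const T&)>& body) {
  T output = input;
  while (cond(output)) output = body(output);
  return output;
}

// e.g. RunWhile(0, [](const int& v) { return v < 10; },
//               [](const int& v) { return v + 1; }) yields 10.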
#endif // TFL_OPS

View File

@ -1,67 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines the op traits used in the MLIR TensorFlow Lite dialect.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h" // TF:llvm-project
namespace mlir {
namespace OpTrait {
namespace TFL {
// The trait to specify that the specified operands of the TFL op are stateful.
// This is used as a trait like this:
//
// class LSTMOp
// : public Op<LSTMOp, OpTrait::TFL::StatefulOperands<18, 19>::Impl> {
//
template <int... Operands>
class StatefulOperands {
public:
template <typename ConcreteType>
class Impl
: public TraitBase<ConcreteType, StatefulOperands<Operands...>::Impl> {
public:
static std::vector<int> GetStatefulOperands() {
return std::vector<int>({Operands...});
}
};
};
// The trait to specify the channel dimension index of the input (first operand)
// of an affine TFL op (Conv2D, DepthwiseConv2D, FullyConnected).
//
// class Conv2DOp
// : public Op<Conv2DOp, OpTrait::TFL::ChannelDimIndex<0>::Impl> {
//
template <int Index>
class ChannelDimIndex {
public:
template <typename ConcreteType>
class Impl : public TraitBase<ConcreteType, ChannelDimIndex<Index>::Impl> {
public:
static int GetChannelDimIndex() { return Index; }
};
};
} // namespace TFL
} // namespace OpTrait
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_

View File

@ -41,13 +41,20 @@ limitations under the License.
#include "tensorflow/lite/delegates/flex/delegate.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/optional_debug_tools.h"
using llvm::cl::desc;
using llvm::cl::init;
using llvm::cl::opt;
// NOLINTNEXTLINE
static opt<std::string> inputFileName(llvm::cl::Positional,
llvm::cl::desc("<input file>"),
llvm::cl::init("-"));
static opt<std::string> input_filename(llvm::cl::Positional,
desc("<input file>"), init("-"));
// NOLINTNEXTLINE
static opt<bool> dump_state("dump-interpreter-state",
desc("dump interpreter state post execution"),
init(false));
// TODO(jpienaar): Move these functions to some debug utils.
static std::string TfLiteTensorDimString(const TfLiteTensor& tensor) {
@ -82,9 +89,9 @@ int main(int argc, char** argv) {
llvm::InitLLVM y(argc, argv);
llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR TFLite runner\n");
auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(inputFileName.c_str());
auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(input_filename.c_str());
if (std::error_code error = file_or_err.getError()) {
LOG(ERROR) << argv[0] << ": could not open input file '" << inputFileName
LOG(ERROR) << argv[0] << ": could not open input file '" << input_filename
<< "': " << error.message() << "\n";
return 1;
}
@ -133,5 +140,7 @@ int main(int argc, char** argv) {
TfLiteTensorString(out).c_str());
}
if (dump_state) tflite::PrintInterpreterState(interpreter.get());
return 0;
}

View File

@ -122,7 +122,7 @@ static void EmitOptionBuilders(const RecordKeeper &record_keeper,
os << formatv(
" auto {0} = Convert{1}ForOptionWriter(op.{0}(), fbb);\n",
val.getName(), record->getClasses()[0]->getName());
options.push_back(val.getName());
options.push_back(std::string(val.getName()));
}
}
}

View File

@ -71,18 +71,17 @@ cc_library(
"quantization_utils.cc",
],
hdrs = [
"quantization_traits.h",
"quantization_utils.h",
],
deps = [
"//tensorflow/core:lib_proto_parsing",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
# TODO(fengliuai): remove this dependence.
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/core:lib_proto_parsing",
],
)

View File

@ -206,10 +206,17 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
std::unique_ptr<OpPassBase<FuncOp>>
CreateImportQuantStatsPassForTFControlDialect(const std::string &stats_str) {
auto get_name_func = [](Operation *op) {
if (auto name = op->getAttrOfType<StringAttr>("name"))
return name.getValue();
else
return llvm::StringRef("");
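// Prefer a NameLoc attached to the op; for fused locations, fall back to the
// first named sub-location.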
Location loc = op->getLoc();
if (auto name = loc.dyn_cast<NameLoc>()) {
return name.getName().strref();
} else if (auto fused_name = loc.dyn_cast<FusedLoc>()) {
for (auto sub_loc : fused_name.getLocations()) {
if (auto named_sub_loc = sub_loc.dyn_cast<NameLoc>()) {
return named_sub_loc.getName().strref();
}
}
}
return llvm::StringRef("");
};
return CreateImportQuantStatsPass(get_name_func, stats_str);

View File

@ -12,6 +12,7 @@ package_group(
includes = ["//third_party/mlir:subpackages"],
packages = [
"//learning/brain/experimental/mlir/...",
"//tensorflow/compiler/mlir/lite/...",
"//tensorflow/lite/...",
],
)
@ -23,7 +24,6 @@ cc_library(
],
hdrs = [
"quantize_model.h",
"//tensorflow/compiler/mlir/lite:transforms/passes.h",
],
deps = [
"//tensorflow/compiler/mlir/lite:common",
@ -42,6 +42,24 @@ cc_library(
],
)
cc_library(
name = "tfl_to_std",
srcs = [
"tfl_to_std.cc",
],
hdrs = [
"tfl_to_std.h",
],
deps = [
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
],
)
# Binary to apply quantization on the annotated files.
tf_cc_binary(
name = "tfl_quantizer",

View File

@ -73,19 +73,19 @@ TfLiteStatus QuantizeModel(
// Apply quantization passes
PassManager pm(module->getContext());
TFL::QuantizationSpecs pass_config;
pass_config.inference_type = tensorflow::DT_QINT8;
pass_config.post_training_quantization = true;
TFL::QuantizationSpecs quant_specs;
quant_specs.inference_type = tensorflow::DT_QINT8;
quant_specs.post_training_quantization = true;
bool emit_adaptor = false;
auto input_tf_type = tflite::TflTypeToTfType(input_type);
if (input_tf_type == tensorflow::DT_FLOAT) {
emit_adaptor = true;
} else if (input_tf_type == tensorflow::DT_UINT8) {
pass_config.inference_type = tensorflow::DT_QUINT8;
quant_specs.inference_type = tensorflow::DT_QUINT8;
}
pm.addPass(TFL::CreatePrepareQuantizePass(pass_config));
pm.addPass(TFL::CreatePrepareQuantizePass(quant_specs));
pm.addPass(TFL::CreateQuantizePass());
pm.addPass(TFL::CreatePostQuantizePass(emit_adaptor));

View File

@ -0,0 +1,62 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
namespace mlir {
namespace TFL {
void ConvertTFLQuantOpsToMlirQuantOps(FuncOp func) {
OpBuilder b(func);
func.walk([&](Operation* op) {
b.setInsertionPoint(op);
if (auto dq = llvm::dyn_cast<DequantizeOp>(op)) {
auto dcast = b.create<quant::DequantizeCastOp>(
dq.getLoc(), dq.output().getType(), dq.input());
dq.output().replaceAllUsesWith(dcast);
dq.erase();
} else if (auto q = llvm::dyn_cast<QuantizeOp>(op)) {
auto qcast = b.create<quant::QuantizeCastOp>(
q.getLoc(), q.output().getType(), q.input());
q.output().replaceAllUsesWith(qcast);
q.erase();
}
});
}
void ConvertMlirQuantOpsToTFLQuantOps(FuncOp func) {
OpBuilder b(func);
func.walk([&](Operation* op) {
b.setInsertionPoint(op);
if (auto dq = llvm::dyn_cast<quant::DequantizeCastOp>(op)) {
auto dcast = b.create<DequantizeOp>(dq.getLoc(), dq.getResult().getType(),
dq.arg());
dq.getResult().replaceAllUsesWith(dcast);
dq.erase();
} else if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(op)) {
auto out_type = q.getResult().getType();
auto qcast = b.create<QuantizeOp>(q.getLoc(), out_type, q.arg(),
TypeAttr::get(out_type));
q.getResult().replaceAllUsesWith(qcast);
q.erase();
}
});
}
} // namespace TFL
} // namespace mlir

View File

@ -0,0 +1,34 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_
#include "mlir/IR/Function.h" // TF:llvm-project
namespace mlir {
namespace TFL {
// Converts all the tfl.quantize/tfl.dequantize ops in the function to the
// corresponding ops in the mlir.quant dialect.
void ConvertTFLQuantOpsToMlirQuantOps(FuncOp func);
// Converts all the mlir.quant dialect ops to the tfl.quantize/tfl.dequantize
// ops in the function.
void ConvertMlirQuantOpsToTFLQuantOps(FuncOp func);
} // namespace TFL
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_
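A hedged sketch of the intended call pattern: translate the TFL quant ops into the mlir.quant dialect, run passes that understand the quant dialect, then translate back. Everything besides the two functions declared above is illustrative:

// `func` is a FuncOp still carrying tfl.quantize/tfl.dequantize ops.
mlir::TFL::ConvertTFLQuantOpsToMlirQuantOps(func);
// ... run passes that operate on quant.qcast/quant.dcast here ...
mlir::TFL::ConvertMlirQuantOpsToTFLQuantOps(func);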

View File

@ -22,21 +22,6 @@ limitations under the License.
include "mlir/IR/OpBase.td"
include "mlir/Dialect/QuantOps/QuantPredicates.td"
//===----------------------------------------------------------------------===//
// Min-max range pair definitions.
//===----------------------------------------------------------------------===//
// A pair of floating point values which defines the min and max of a value
// range for quantization. The attribute is allowed to be empty or
// have 2 elements.
def MinMaxAttr : Attr<Or<[CPred<"$_self.cast<ArrayAttr>().size() == 0">,
CPred<"$_self.cast<ArrayAttr>().size() == 2">]>,
"min-max range pair"> {
let storageType = [{ ArrayAttr }];
let returnType = [{ ArrayRef<Attribute> }];
}
//===----------------------------------------------------------------------===//
// QuantizedType definitions.
//===----------------------------------------------------------------------===//

View File

@ -23,6 +23,8 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
@ -34,14 +36,14 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/core/platform/logging.h"
#define DEBUG_TYPE "quantization-driver"
namespace mlir {
namespace TFL {
namespace quant {
namespace {
static bool EmptyParams(QuantParams p) { return p == quant::QuantizedType(); }
@ -282,6 +284,37 @@ class QuantizationDriver {
cached.first->second = InitializeState(op, index, res, /*as_result=*/true);
}
void DumpStates(Operation *current_op) {
if (current_op) {
llvm::errs() << "\n\n\n" << current_op->getName() << "\n";
}
fn_.walk([&](Operation *op) {
if (llvm::isa<quant::QuantizeCastOp>(op) ||
llvm::isa<quant::DequantizeCastOp>(op) || llvm::isa<ConstantOp>(op))
return;
if (current_op == op) llvm::errs() << "===>>>";
llvm::errs() << op->getName() << " : (";
for (auto i = 0; i < op->getNumOperands(); ++i) {
if (auto params = GetOperandQuantState(op, i).params)
params.print(llvm::errs());
else
op->getOperand(i).getType().cast<ShapedType>().getElementType().print(
llvm::errs());
llvm::errs() << ",";
}
llvm::errs() << ") -> (";
for (auto i = 0; i < op->getNumResults(); ++i) {
if (auto params = GetResultQuantState(op, i).params)
params.print(llvm::errs());
else
op->getResult(i).getType().cast<ShapedType>().getElementType().print(
llvm::errs());
llvm::errs() << ",";
}
llvm::errs() << ")\n";
});
}
FuncOp fn_;
OpBuilder builder_;
bool is_signed_;
@ -351,7 +384,7 @@ int QuantizationDriver::InitializeState(Operation *op, int index, Value val,
}
bool QuantizationDriver::SetConstantResultParams(Operation *op) {
ElementsAttr attr;
DenseFPElementsAttr attr;
Value res = op->getResult(0);
if (!matchPattern(res, m_Constant(&attr))) {
return false;
@ -458,11 +491,9 @@ void QuantizationDriver::QuantizeValue(Value value, QuantParams params,
// This value isn't an expressed type (float), skip.
if (!new_type) return;
TypeAttr type_attr = TypeAttr::get(new_type);
auto quantize =
builder_.create<TFL::QuantizeOp>(loc, new_type, value, type_attr);
auto dequantize = builder_.create<TFL::DequantizeOp>(loc, expressed_type,
quantize.output());
auto quantize = builder_.create<quant::QuantizeCastOp>(loc, new_type, value);
auto dequantize = builder_.create<quant::DequantizeCastOp>(
loc, expressed_type, quantize.getResult());
// `original_result` is used by `quantize`, so this will replace that use
// with the result of `dequantize`. Remember to reset that use afterwards.
value.replaceAllUsesWith(dequantize);
@ -476,7 +507,7 @@ void QuantizationDriver::RequantizeOpResult(Operation *op, int index,
Value value = op->getResult(index);
if (state->pos == RequantizeState::ON_OUTPUT) {
Operation *user = value.getUses().begin().getUser();
if (llvm::isa<TFL::QuantizeOp>(user)) {
if (llvm::isa<quant::QuantizeCastOp>(user)) {
// The requantize op is inserted between `quantize` and `dequantize` ops.
value = user->getResult(0);
builder_.setInsertionPointAfter(user);
@ -491,8 +522,8 @@ void QuantizationDriver::RequantizeArg(BlockArgument arg,
builder_.setInsertionPointToStart(arg.getOwner());
if (value.hasOneUse()) {
auto user = value.use_begin().getUser();
if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
value = q.output();
if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
value = q.getResult();
builder_.setInsertionPoint(arg.getOwner(), ++Block::iterator(user));
}
}
@ -519,9 +550,8 @@ void QuantizationDriver::RequantizeValue(Value value, RequantizeState *state,
// This value isn't an expressed type (float), skip.
if (!new_type) return;
TypeAttr type_attr = TypeAttr::get(new_type);
auto requantize_op =
builder_.create<TFL::QuantizeOp>(loc, new_type, value, type_attr);
builder_.create<quant::QuantizeCastOp>(loc, new_type, value);
value.replaceAllUsesWith(requantize_op);
requantize_op.getOperation()->replaceUsesOfWith(requantize_op, value);
}
@ -651,8 +681,8 @@ void QuantizationDriver::SetupAllStates() {
// If the argument is quantized, it should only have one user.
if (arg.hasOneUse()) {
auto user = value.use_begin().getUser();
if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
value = q.output();
if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
value = q.getResult();
}
}
InitializeArgState(arg, value, &value_to_state);
@ -660,7 +690,9 @@ void QuantizationDriver::SetupAllStates() {
fn_.walk([&](Operation *op) {
if (op->isKnownTerminator() ||
op->hasTrait<OpTrait::quant::NoQuantizableResult>())
op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
llvm::isa<quant::DequantizeCastOp>(op) ||
llvm::isa<quant::QuantizeCastOp>(op))
return;
work_list_.push_back(op);
@ -669,8 +701,8 @@ void QuantizationDriver::SetupAllStates() {
if (auto *inst = operand.getDefiningOp()) {
// If the operand comes from a tfl.dequantize op, we use the quantized
// input of this tfl.dequantize op to set the state.
if (auto dq = llvm::dyn_cast<TFL::DequantizeOp>(inst)) {
operand = dq.input();
if (auto dq = llvm::dyn_cast<quant::DequantizeCastOp>(inst)) {
operand = dq.arg();
}
}
InitializeOperandState(op, i, operand, &value_to_state);
@ -683,8 +715,8 @@ void QuantizationDriver::SetupAllStates() {
// create the state and mark it immutable.
if (result.hasOneUse()) {
auto user = result.use_begin().getUser();
if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
result = q.output();
if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
result = q.getResult();
}
}
InitializeResultState(op, res, result, &value_to_state);
@ -714,6 +746,8 @@ bool QuantizationDriver::PropagateParams() {
Operation *op = work_list_.back();
work_list_.pop_back();
LLVM_DEBUG(DumpStates(op));
// This op has been quantized, so we should not consider it again.
if (llvm::is_contained(quantized_, op)) continue;
quantized_.insert(op);
@ -738,12 +772,23 @@ bool QuantizationDriver::PropagateParams() {
}
// Use the final state to set all the operands' parameters.
for (int i = 0, e = op->getNumOperands(); i != e; ++i)
changed |= SetOperandParams(op, i, params);
for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
if (auto type = op->getOperand(i).getType().dyn_cast<ShapedType>()) {
// Without this check, it would accidentally propagate quantization
// information through shared non-float tensors.
if (type.getElementType().isa<FloatType>())
changed |= SetOperandParams(op, i, params);
}
}
// Use the final state to set all the results' parameters.
for (int res = 0, e = op->getNumResults(); res != e; ++res)
changed |= SetResultParams(op, res, params);
if (auto type = op->getResult(res).getType().dyn_cast<ShapedType>()) {
// Without this check, it would accidentally propagate quantization
// information through shared non-float tensors.
if (type.getElementType().isa<FloatType>())
changed |= SetResultParams(op, res, params);
}
}
// TODO(fengliuai): make the bit width configurable.
@ -822,5 +867,5 @@ void ApplyQuantizationParamsPropagation(
.Run();
}
} // namespace TFL
} // namespace quant
} // namespace mlir

View File

@ -70,7 +70,8 @@ class FixedResultUniformScale {
QuantizedType GetResultQuantizedType(int index) {
auto op = this->getOperation();
auto result_type =
op->getResult(index).getType().template cast<TensorType>();
op->getResult(index).getType().template cast<ShapedType>();
if (!result_type.getElementType().template isa<FloatType>()) return {};
Builder builder(op->getContext());
IntegerType storage_type = builder.getIntegerType(BitWidth);
const double scale = static_cast<double>(ScaleMantissa) *

View File

@ -30,10 +30,9 @@ limitations under the License.
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
namespace mlir {
namespace TFL {
namespace quant {
const float kNearZeroTolerance = 1.0e-6;
@ -66,6 +65,37 @@ static Type GetQuantizedType(Builder builder, Type input_type,
return converter.convert(quantizedEleType);
}
// TODO(fengliuai): promote this utility method to mlir QuantOps.
TypeAttr RescaleQuantizedType(Type input, Attribute factor) {
auto factor_values = factor.dyn_cast_or_null<DenseFPElementsAttr>();
if (!factor_values) return {};
auto ele_type = quant::QuantizedType::getQuantizedElementType(input);
if (!ele_type) return {};
if (auto qtype = ele_type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
ArrayRef<double> scales = qtype.getScales();
// Broadcasting hasn't been implemented yet.
if (scales.size() != factor_values.getNumElements()) return {};
SmallVector<double, 4> new_scales;
new_scales.reserve(scales.size());
auto scales_iter = scales.begin();
for (auto f : factor_values) {
new_scales.push_back(*(scales_iter++) *
std::fabs(FloatAttr::getValueAsDouble(f)));
}
// We are assuming symmetric quantization.
auto new_ele_type = quant::UniformQuantizedPerAxisType::get(
qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(),
new_scales, qtype.getZeroPoints(), qtype.getQuantizedDimension(),
qtype.getStorageTypeMin(), qtype.getStorageTypeMax());
if (auto new_type = new_ele_type.castFromExpressedType(
quant::QuantizedType::castToExpressedType(input))) {
return TypeAttr::get(new_type);
}
}
// Currently, we only support per-axis quantized type.
return {};
}
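Stripped of the MLIR type plumbing, the rescale above is an elementwise multiply of the existing scales by the magnitude of each factor. A self-contained sketch of just that arithmetic (names are illustrative):

#include <cmath>
#include <cstddef>
#include <vector>

// new_scales[i] = scales[i] * |factor[i]|; sizes must match (no broadcast).
std::vector<double> RescaleScales(const std::vector<double>& scales,
                                  const std::vector<double>& factor) {
  if (scales.size() != factor.size()) return {};
  std::vector<double> new_scales(scales.size());
  for (std::size_t i = 0; i < scales.size(); ++i)
    new_scales[i] = scales[i] * std::fabs(factor[i]);
  return new_scales;
}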
TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min,
Attribute max, int quant_dim,
IntegerAttr num_bits, BoolAttr narrow_range,
@ -369,7 +399,7 @@ static bool PreferResultScale(Operation* op) {
for (auto operand : op->getOperands()) {
if (auto operand_type = operand.getType().dyn_cast<ShapedType>()) {
if (operand_type.getElementType().isa<FloatType>()) {
if (float_operands++ > 1) return true;
if (++float_operands > 1) return true;
}
}
}
@ -429,7 +459,7 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
}
// Step 2: backward pass: For the ops skipped in the forward pass, propagate
// their result scales backwards.
// their result scales backwards as far as possible.
func.walk([&](quant::StatisticsOp stats_op) {
if (redundant_stats_ops.find(stats_op) == redundant_stats_ops.end()) {
all_stats_ops.push_back(stats_op);
@ -441,8 +471,7 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
all_stats_ops.pop_back();
if (auto def = stats_op.arg().getDefiningOp()) {
if (def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() &&
PreferResultScale(def)) {
if (def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>()) {
for (auto input : def->getOperands()) {
if (auto next_stats = llvm::dyn_cast_or_null<quant::StatisticsOp>(
input.getDefiningOp())) {
@ -465,5 +494,5 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
// Returns false if the steps finish without errors.
return false;
}
} // namespace TFL
} // namespace quant
} // namespace mlir

View File

@ -38,7 +38,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
namespace mlir {
namespace TFL {
namespace quant {
using QuantParams = quant::QuantizedType;
using SignedInteger = std::pair<unsigned, unsigned>; // bitwidth and sign
@ -113,8 +113,7 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
rewriter.setInsertionPointAfter(op);
Type result_type = quant_type.castFromExpressedType(op.getType());
auto q = rewriter.create<Q>(op.getLoc(), result_type, op.arg(),
TypeAttr::get(result_type));
auto q = rewriter.create<Q>(op.getLoc(), result_type, op.arg());
auto dq = rewriter.create<DQ>(op.getLoc(), op.getType(), q);
op.getResult().replaceAllUsesWith(dq);
q.getOperation()->replaceUsesOfWith(dq, op.arg());
@ -168,9 +167,12 @@ struct QuantizationPattern : public RewritePattern {
return matchFailure();
}
// If it is terminator or not quantizable, we shouldn't rewrite.
// If it is a terminator, not quantizable, or an op from the mlir quant
// ops dialect, we shouldn't rewrite.
if (quantized_op->isKnownTerminator() ||
quantized_op->hasTrait<OpTrait::quant::NoQuantizableResult>()) {
quantized_op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
llvm::isa<quant::QuantizeCastOp>(quantized_op) ||
llvm::isa<quant::DequantizeCastOp>(quantized_op)) {
return matchFailure();
}
@ -316,7 +318,7 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
PatternMatchResult matchAndRewrite(Q op,
PatternRewriter& rewriter) const override {
Type output_type = op.output().getType();
Type output_type = op.getResult().getType();
auto qtype = QType::getQuantizedElementType(output_type);
if (!qtype || qtype.isSigned()) return this->matchFailure();
@ -352,14 +354,19 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
return this->matchFailure();
}
if (!new_qtype) return this->matchFailure();
Type new_output_type = new_qtype.castFromExpressedType(
QType::castToExpressedType(output_type));
rewriter.replaceOpWithNewOp<Q>(op, new_output_type, op.input(),
TypeAttr::get(new_output_type));
rewriter.replaceOpWithNewOp<Q>(op, new_output_type, op.arg());
return this->matchSuccess();
}
};
// Given a quantized type `input`, magnifies its scales by the factors stored
// in `factor`. If `input` isn't a quantized type, or `factor` isn't
// floating-point or doesn't match the dimension size of `input`, a null
// attribute is returned.
TypeAttr RescaleQuantizedType(Type input, Attribute factor);
// Converts the min/max/num_bits/narrow_range information to a
// QuantizedType, and then returns the attribute containing the QuantizedType.
// The `min` and `max` arguments can be FloatAttr or DenseFPElementsAttr and
@ -438,7 +445,7 @@ void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed,
bool RemoveRedundantStatsOps(mlir::FuncOp func,
OpQuantSpecGetter op_quant_spec_getter);
} // namespace TFL
} // namespace quant
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_

View File

@ -0,0 +1,36 @@
package(
default_visibility = [
":friends",
],
licenses = ["notice"], # Apache 2.0
)
package_group(
name = "friends",
includes = ["//third_party/mlir:subpackages"],
packages = [
"//tensorflow/compiler/mlir/...",
"//tensorflow/compiler/mlir/lite/...",
],
)
cc_library(
name = "tf_to_quant",
srcs = [
"tf_to_quant.cc",
],
hdrs = [
"passes.h",
],
deps = [
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
],
alwayslink = 1,
)

View File

@ -0,0 +1,32 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_TENSORFLOW_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_TENSORFLOW_PASSES_H_
#include <memory>
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
namespace mlir {
namespace TF {
// Legalize the tf ops to the quant ops, so the quantization passes can work.
std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFToQuantPass();
} // namespace TF
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_TENSORFLOW_PASSES_H_

View File

@ -0,0 +1,19 @@
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
package(licenses = ["notice"])
glob_lit_tests(
data = [":test_utilities"],
driver = "@llvm-project//mlir:run_lit.sh",
test_file_exts = ["mlir"],
)
# Bundle together all of the test utilities that are used by tests.
filegroup(
name = "test_utilities",
testonly = True,
data = [
"//tensorflow/compiler/mlir:tf-opt",
"@llvm-project//llvm:FileCheck",
],
)

View File

@ -0,0 +1,148 @@
// RUN: tf-opt -tf-to-quant %s | FileCheck %s
// CHECK-LABEL: fakeQuantPerChannelForActivation
func @fakeQuantPerChannelForActivation(%arg0: tensor<8x3xf32>) -> (tensor<8x3xf32>) {
%arg1 = constant dense<[0.0, -1.0, 1.0]> : tensor<3xf32>
%arg2 = constant dense<[255.0, 254.0, 256.0]> : tensor<3xf32>
%0 = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8x3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<8x3xf32>
return %0 : tensor<8x3xf32>
// CHECK: %[[fq:.*]] = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %cst, %cst_0)
// CHECK: %[[q:.*]] = "quant.qcast"(%[[fq]]) : (tensor<8x3xf32>) -> tensor<8x3x!quant.uniform<i8:f32:1, {1.000000e+00:-128,1.000000e+00:-127,1.000000e+00:-128}>>
// CHECK: %[[dq:.*]] = "quant.dcast"(%[[q]])
// CHECK: return %[[dq]]
}
// CHECK-LABEL: fakeQuantForActivation
func @fakeQuantForActivation(tensor<8xf32>) -> (tensor<8xf32>) {
^bb0(%arg0: tensor<8xf32>):
%arg1 = constant dense<0.0> : tensor<f32>
%arg2 = constant dense<255.0> : tensor<f32>
%0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
return %0 : tensor<8xf32>
// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0)
// CHECK: %1 = "quant.qcast"(%0) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
// CHECK: %2 = "quant.dcast"(%1)
// CHECK: return %2
}
// CHECK-LABEL: fakeQuantForActivationNoDuplication
func @fakeQuantForActivationNoDuplication(tensor<8xf32>) -> (tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>) {
^bb0(%arg0: tensor<8xf32>):
%arg1 = constant dense<0.0> : tensor<f32>
%arg2 = constant dense<255.0> : tensor<f32>
%0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
%1 = "quant.qcast"(%0) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
return %1 : tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) {narrow_range = false, num_bits = 3 : i64}
// CHECK: %1 = "quant.qcast"(%0) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
// CHECK: return %1
}
// CHECK-LABEL: fakeQuantFolded
func @fakeQuantFolded() -> (tensor<8xf32>) {
%in = constant dense<0.0> : tensor<8xf32>
%min = constant dense<0.0> : tensor<f32>
%max = constant dense<255.0> : tensor<f32>
%mini = "tf.Identity"(%min) : (tensor<f32>) -> tensor<f32>
%maxi = "tf.Identity"(%max) : (tensor<f32>) -> tensor<f32>
%rst = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
return %rst : tensor<8xf32>
// CHECK: %[[CONSTANT:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<8xf32>}
// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT]]) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
// CHECK: return %[[DEQUANTIZE]] : tensor<8xf32>
}
// CHECK-LABEL: fakeQuantNotFolded
func @fakeQuantNotFolded(tensor<8xf32>, tensor<f32>, tensor<f32>) -> (tensor<8xf32>) {
^bb0(%arg0: tensor<8xf32>, %arg3: tensor<f32>, %arg4: tensor<f32>):
%1 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg3, %arg4) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
return %1 : tensor<8xf32>
// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2)
// CHECK: return %0 : tensor<8xf32>
}
// CHECK-LABEL: fakeQuantWithConv2D
func @fakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
^bb0(%arg: tensor<256x32x32x3xf32>) :
%in = constant dense<0.0> : tensor<3x3x3x16xf32>
%min = constant dense<0.0> : tensor<f32>
%max = constant dense<255.0> : tensor<f32>
%mini = "tf.Identity"(%min) : (tensor<f32>) -> tensor<f32>
%maxi = "tf.Identity"(%max) : (tensor<f32>) -> tensor<f32>
%fq = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<f32>, tensor<f32>) -> tensor<3x3x3x16xf32>
%rst = "tf.Conv2D"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
return %rst : tensor<256x30x30x16xf32>
// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32, 1.000000e+00:-128>>
// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
// CHECK: %[[CONV:.*]] = "tf.Conv2D"(%arg0, %[[DEQUANTIZE]])
// CHECK: return %[[CONV]]
}
// CHECK-LABEL: perChannelFakeQuantWithConv2D
func @perChannelFakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
^bb0(%arg: tensor<256x32x32x3xf32>) :
%in = constant dense<0.0> : tensor<3x3x3x16xf32>
%min = constant dense<0.0> : tensor<16xf32>
%max = constant dense<255.0> : tensor<16xf32>
%mini = "tf.Identity"(%min) : (tensor<16xf32>) -> tensor<16xf32>
%maxi = "tf.Identity"(%max) : (tensor<16xf32>) -> tensor<16xf32>
%fq = "tf.FakeQuantWithMinMaxVarsPerChannel"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<3x3x3x16xf32>
%rst = "tf.Conv2D"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
return %rst : tensor<256x30x30x16xf32>
// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32:3,
// CHECK-SAME: {1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,
// CHECK-SAME: 1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128}>>
// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
// CHECK: %[[CONV:.*]] = "tf.Conv2D"(%arg0, %[[DEQUANTIZE]])
// CHECK: return %[[CONV]] : tensor<256x30x30x16xf32>
}
// CHECK-LABEL: fakeQuantWithDepthwiseConv2D
func @fakeQuantWithDepthwiseConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
^bb0(%arg: tensor<256x32x32x3xf32>) :
%in = constant dense<0.0> : tensor<3x3x3x16xf32>
%min = constant dense<0.0> : tensor<f32>
%max = constant dense<255.0> : tensor<f32>
%mini = "tf.Identity"(%min) : (tensor<f32>) -> tensor<f32>
%maxi = "tf.Identity"(%max) : (tensor<f32>) -> tensor<f32>
%fq = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<f32>, tensor<f32>) -> tensor<3x3x3x16xf32>
%rst = "tf.DepthwiseConv2dNative"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
return %rst : tensor<256x30x30x16xf32>
// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32, 1.000000e+00:-128>>
// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
// CHECK: %[[CONV:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[DEQUANTIZE]])
// CHECK: return %[[CONV]]
}
// CHECK-LABEL: perChannelFakeQuantWithDepthwiseConv2D
func @perChannelFakeQuantWithDepthwiseConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
^bb0(%arg: tensor<256x32x32x3xf32>) :
%in = constant dense<0.0> : tensor<3x3x3x16xf32>
%min = constant dense<0.0> : tensor<16xf32>
%max = constant dense<255.0> : tensor<16xf32>
%mini = "tf.Identity"(%min) : (tensor<16xf32>) -> tensor<16xf32>
%maxi = "tf.Identity"(%max) : (tensor<16xf32>) -> tensor<16xf32>
%fq = "tf.FakeQuantWithMinMaxVarsPerChannel"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<3x3x3x16xf32>
%rst = "tf.DepthwiseConv2dNative"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
return %rst : tensor<256x30x30x16xf32>
// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32:3,
// CHECK-SAME: {1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,
// CHECK-SAME: 1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128}>>
// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
// CHECK: %[[CONV:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[DEQUANTIZE]])
// CHECK: return %[[CONV]]
}

View File

@ -0,0 +1,162 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
namespace TF {
//===----------------------------------------------------------------------===//
// The pass to legalize the quantization emulation ops from TF.
//
namespace {
// Legalizes TF quantization emulation ops to their counterparts in the Quant
// ops dialect.
struct LegalizeTFToQuant : public FunctionPass<LegalizeTFToQuant> {
explicit LegalizeTFToQuant() = default;
LegalizeTFToQuant(const LegalizeTFToQuant &) {}
/// Performs the lowering to Quant ops dialect.
void runOnFunction() override;
};
// TODO(fengliuai): move this rule to PreparePatterns.td
// TODO(b/140968741): propagate the sign from the command line. Currently all
// the FakeQuant ops are assumed to target UINT8, but the per-channel kernel
// is actually INT8.
// Inserts a "tfl.quantize" and "tfl.dequantize" op pair (QDQs) after the
// "tf.FakeQuantWithMinMaxVarsOp" to be constant folded. Since the constant
// folding logic will use a "std.constant" op to replace the
// "tf.FakeQuantWithMinMaxVarsOp", the "tfl.quantize" op is used to preserve
// the quantization parameters as a TypeAttr and "tfl.dequantize" op used to
// convert the output type to the next op. Here are the transformations:
//
// input min cst max cst input min cst max cst
// \ | | \ | |
// \ (tf.Identity) (tf.Identity) => \ (tf.Identity) (tf.Identity)
// \ | | \ | |
// tf.FakeQuantWithMinMaxVars tf.FakeQuantWithMinMaxVars
// | |
// tf.quantize
// |
// tf.dequantize
// |
// If the input is a constant, the result pattern will eventually be converted to
//
// quant-emulated input
// |
// tf.quantize
// |
// tf.dequantize
// |
template <typename TFFakeQuantOp, bool PerAxis>
struct InsertQuantOpsAfterTFFakeQuantOp
: public OpRewritePattern<TFFakeQuantOp> {
using BaseType = InsertQuantOpsAfterTFFakeQuantOp<TFFakeQuantOp, PerAxis>;
explicit InsertQuantOpsAfterTFFakeQuantOp<TFFakeQuantOp, PerAxis>(
MLIRContext *ctx)
: OpRewritePattern<TFFakeQuantOp>(ctx) {}
PatternMatchResult matchAndRewrite(TFFakeQuantOp tf_op,
PatternRewriter &rewriter) const override {
// We don't want to insert quantize/dequantize ops if a quantize op already
// exists.
auto res = tf_op.outputs();
if (!res.hasOneUse() || isa<quant::QuantizeCastOp>(*res.user_begin()))
return this->matchFailure();
// Extract the min/max constant values from the operands. We also handle the
// special case in which tf.Identity ops sit between the min/max constants
// and the tf.FakeQuantWithMinMaxVarsOp.
Value min = tf_op.min(), max = tf_op.max();
DenseFPElementsAttr min_value, max_value;
if (auto id1 = dyn_cast_or_null<TF::IdentityOp>(min.getDefiningOp())) {
id1.replaceAllUsesWith(id1.input());
min = tf_op.min();
rewriter.eraseOp(id1);
}
if (auto id2 = dyn_cast_or_null<TF::IdentityOp>(max.getDefiningOp())) {
id2.replaceAllUsesWith(id2.input());
max = tf_op.max();
rewriter.eraseOp(id2);
}
if (!matchPattern(min, m_Constant(&min_value))) return this->matchFailure();
if (!matchPattern(max, m_Constant(&max_value))) return this->matchFailure();
int quant_dim = -1;
if (PerAxis) {
// This is a special case: the quant_dim is the last dimension, following
// the semantics of tf.FakeQuantWithMinMaxVarsPerChannel.
quant_dim = res.getType().template cast<ShapedType>().getRank() - 1;
}
// Use the min/max from the operands and the num_bits and narrow_range
// attribute to create the quantization parameter for the new quantize op.
rewriter.setInsertionPointAfter(tf_op);
IntegerAttr num_bits =
rewriter.getI64IntegerAttr(tf_op.num_bits().getSExtValue());
BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.narrow_range());
Type res_type = tf_op.getType();
TypeAttr qtype = quant::GetQuantizedTypeAttr(
rewriter, res_type, min_value, max_value, quant_dim, num_bits,
narrow_range, /*is_signed=*/true);
if (!qtype) return this->matchFailure();
// Finally, use the quantization parameter to create the quantize and
// dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp
// and its users.
Value value = tf_op.outputs();
auto quantize = rewriter.create<quant::QuantizeCastOp>(
tf_op.getLoc(), qtype.getValue(), value);
auto dequantize = rewriter.create<quant::DequantizeCastOp>(
tf_op.getLoc(), res_type, quantize.getResult());
value.replaceAllUsesWith(dequantize);
quantize.getOperation()->replaceUsesOfWith(dequantize, value);
return this->matchSuccess();
}
};
using PreparePerTensorFakeQuant =
InsertQuantOpsAfterTFFakeQuantOp<TF::FakeQuantWithMinMaxVarsOp, false>;
using PreparePerChannelFakeQuant =
InsertQuantOpsAfterTFFakeQuantOp<TF::FakeQuantWithMinMaxVarsPerChannelOp,
true>;
// TODO(fengliuai): add the support of the tf.QuantizeAndDequantize*
// legalization.
void LegalizeTFToQuant::runOnFunction() {
OwningRewritePatternList patterns;
auto func = getFunction();
auto *ctx = func.getContext();
patterns.insert<PreparePerTensorFakeQuant, PreparePerChannelFakeQuant>(ctx);
applyPatternsGreedily(func, patterns);
}
} // namespace
// Creates an instance of the TensorFlow dialect to QuantOps dialect pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFToQuantPass() {
return std::make_unique<LegalizeTFToQuant>();
}
static PassRegistration<LegalizeTFToQuant> pass(
"tf-to-quant", "Legalize TF to quant ops dialect");
} // namespace TF
} // namespace mlir
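The new tf-to-quant test file earlier in this diff exercises this pass with tf-opt -tf-to-quant, checking that quant.qcast/quant.dcast pairs replace the TF fake-quant ops.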

View File

@ -3,7 +3,8 @@
// CHECK-LABEL: import_stats_skip
func @import_stats_skip(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "skip"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
loc(fused["skip1", "skip2.cc":10:8, callsite("op" at "skip3.cc":10:8)])
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
// CHECK-NEXT: "tfl.split"
@ -12,7 +13,8 @@ func @import_stats_skip(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf3
// CHECK-LABEL: import_stats_name
func @import_stats_name(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "op"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
loc(fused["skip1.cc":10:8, "op", callsite("skip2" at "skip3.cc":10:8)])
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
// CHECK-NEXT: %[[split:.*]]:2 = "tfl.split"
@ -23,7 +25,8 @@ func @import_stats_name(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf3
// CHECK-LABEL: import_stats_name_port
func @import_stats_name_port(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "op_0"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
loc(fused["skip1.cc":10:8, "op_0", callsite("skip2" at "skip3.cc":10:8)])
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
// CHECK-NEXT: %[[split:.*]]:2 = "tfl.split"
@ -34,6 +37,7 @@ func @import_stats_name_port(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor
// CHECK-LABEL: import_stats_name_regex
func @import_stats_name_regex(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "op_regex"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
loc(fused["skip1.cc":10:8, "op_regex", callsite("skip2" at "skip3.cc":10:8)])
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
// CHECK-NEXT: %[[split:.*]]:2 = "tfl.split"

View File

@ -46,9 +46,9 @@ static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) {
std::vector<Record *> defs = records.getAllDerivedDefinitions("Op");
llvm::sort(defs, LessRecord());
OUT(0) << "static std::unique_ptr<OpQuantSpec> "
OUT(0) << "static std::unique_ptr<quant::OpQuantSpec> "
"GetOpQuantSpec(mlir::Operation *op) {\n";
OUT(2) << "auto spec = absl::make_unique<OpQuantSpec>();\n";
OUT(2) << "auto spec = absl::make_unique<quant::OpQuantSpec>();\n";
llvm::SmallVector<llvm::StringRef, 3> matches;
for (auto *def : defs) {
Operator op(def);
@ -74,7 +74,7 @@ static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) {
if (acc_uniform_trait_regex.match(trait_str, &matches)) {
OUT(4) << "spec->biases_params.emplace(std::make_pair(" << matches[1]
<< ", std::make_pair(tfl.GetAllNonBiasOperands(),"
<< "GetUniformQuantizedTypeForBias)));\n";
<< "quant::GetUniformQuantizedTypeForBias)));\n";
matches.clear();
}
// There is a "QuantChannelDim" trait, set the quantization dimension.

View File

@ -0,0 +1,36 @@
package(
default_visibility = [
":friends",
],
licenses = ["notice"], # Apache 2.0
)
package_group(
name = "friends",
includes = ["//third_party/mlir:subpackages"],
packages = [
"//tensorflow/compiler/mlir/...",
"//tensorflow/compiler/mlir/lite/...",
],
)
cc_library(
name = "hlo_xla_quantization_passes",
srcs = [
"op_quant_spec.inc",
"propagate.cc",
],
hdrs = [
"passes.h",
],
deps = [
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
],
alwayslink = 1,
)

View File

@ -0,0 +1,7 @@
// TODO(fengliuai): automatically generate this file
// TODO(fengliuai): add all the xla_hlo ops
static std::unique_ptr<quant::OpQuantSpec> GetOpQuantSpec(mlir::Operation *op) {
auto spec = absl::make_unique<quant::OpQuantSpec>();
return spec;
}

Some files were not shown because too many files have changed in this diff.