Merge branch 'master' into google_upstream_rocblas_complex

Commit: b64dde60e8
Author: ekuznetsov139
Date: 2020-01-21 21:21:25 -08:00 (committed by GitHub)
1602 changed files with 63774 additions and 20626 deletions

View File

@@ -279,7 +279,6 @@ build:windows --host_linkopt=/OPT:REF
 build:windows --linkopt=/OPT:ICF
 build:windows --host_linkopt=/OPT:ICF
 build:windows --experimental_strict_action_env=true
-build:windows --incompatible_windows_native_test_wrapper
 # Verbose failure logs when something goes wrong
 build:windows --verbose_failures
@@ -344,6 +343,7 @@ build:rbe_linux --config=avx_linux
 build:rbe_linux --config=short_logs
 # TODO(gunan): Check why we need this specified in rbe, but not in other builds.
 build:rbe_linux --linkopt=-lrt
+build:rbe_linux --linkopt=-lm
 build:rbe_cpu_linux --config=rbe_linux
 build:rbe_cpu_linux --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain"

View File

@@ -1 +1 @@
-1.1.0
+1.2.1

File diff suppressed because one or more lines are too long

View File

@@ -1,11 +1,13 @@
 workspace(name = "org_tensorflow")
 
-load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_file")
+load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
+load("//third_party:repo.bzl", "tf_http_archive")
 
-http_archive(
+tf_http_archive(
     name = "io_bazel_rules_closure",
     sha256 = "5b00383d08dd71f28503736db0500b6fb4dda47489ff5fc6bed42557c07c6ba9",
     strip_prefix = "rules_closure-308b05b2419edb5c8ee0471b67a40403df940149",
+    patch_file = "@org_tensorflow//third_party:rules_closure.patch",
     urls = [
         "https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz",
         "https://github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz",  # 2019-06-13
@@ -48,38 +50,6 @@ load("//third_party/toolchains/preconfig/generate:workspace.bzl",
 remote_config_workspace()
 
-# Apple and Swift rules.
-http_archive(
-    name = "build_bazel_rules_apple",
-    sha256 = "a045a436b642c70fb0c10ca84ff0fd2dcbd59cc89100d597a61e8374afafb366",
-    urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.18.0/rules_apple.0.18.0.tar.gz"],
-)  # https://github.com/bazelbuild/rules_apple/releases
-http_archive(
-    name = "build_bazel_rules_swift",
-    sha256 = "18cd4df4e410b0439a4935f9ca035bd979993d42372ba79e7f2d4fafe9596ef0",
-    urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.12.1/rules_swift.0.12.1.tar.gz"],
-)  # https://github.com/bazelbuild/rules_swift/releases
-http_archive(
-    name = "build_bazel_apple_support",
-    sha256 = "122ebf7fe7d1c8e938af6aeaee0efe788a3a2449ece5a8d6a428cb18d6f88033",
-    urls = ["https://github.com/bazelbuild/apple_support/releases/download/0.7.1/apple_support.0.7.1.tar.gz"],
-)  # https://github.com/bazelbuild/apple_support/releases
-http_archive(
-    name = "bazel_skylib",
-    sha256 = "1dde365491125a3db70731e25658dfdd3bc5dbdfd11b840b3e987ecf043c7ca0",
-    urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/0.9.0/bazel-skylib.0.9.0.tar.gz"],
-)  # https://github.com/bazelbuild/bazel-skylib/releases
-http_archive(
-    name = "com_github_apple_swift_swift_protobuf",
-    type = "zip",
-    strip_prefix = "swift-protobuf-1.6.0/",
-    urls = ["https://github.com/apple/swift-protobuf/archive/1.6.0.zip"],
-)  # https://github.com/apple/swift-protobuf/releases
-http_file(
-    name = "xctestrunner",
-    executable = 1,
-    urls = ["https://github.com/google/xctestrunner/releases/download/0.2.9/ios_test_runner.par"],
-)  # https://github.com/google/xctestrunner/releases
-
 # Use `swift_rules_dependencies` to fetch the toolchains. With the
 # `git_repository` rules above, the following call will skip redefining them.
 load("@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies")

View File

@@ -49,8 +49,8 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
 _TF_WORKSPACE_ROOT = ''
 _TF_BAZELRC = ''
 _TF_CURRENT_BAZEL_VERSION = None
-_TF_MIN_BAZEL_VERSION = '1.0.0'
-_TF_MAX_BAZEL_VERSION = '1.1.0'
+_TF_MIN_BAZEL_VERSION = '1.2.1'
+_TF_MAX_BAZEL_VERSION = '1.2.1'
 
 NCCL_LIB_PATHS = [
     'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''

View File

@@ -2,6 +2,7 @@
 # TensorFlow is a computational framework, primarily for use in machine
 # learning applications.
 
+load("@bazel_skylib//lib:selects.bzl", "selects")
 load("//tensorflow:tensorflow.bzl", "VERSION", "tf_cc_shared_object", "tf_custom_op_library_additional_deps_impl", "tf_native_cc_binary")
 load(
     "//tensorflow/core/platform:build_config.bzl",
@@ -478,6 +479,7 @@ bzl_library(
     visibility = ["//visibility:public"],
     deps = [
         "//tensorflow/core/platform:build_config_root_bzl",
+        "//tensorflow/core/platform:rules_cc_bzl",
         "//tensorflow/core/platform/default:cuda_build_defs_bzl",
         "//third_party/mkl:build_defs_bzl",
         "//third_party/mkl_dnn:build_defs_bzl",

View File

@@ -23,10 +23,6 @@ from __future__ import print_function
 # pylint: disable=g-bad-import-order
 from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
 
-from tensorflow.python.util.lazy_loader import LazyLoader
-contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')
-del LazyLoader
-
 from tensorflow.python.platform import flags  # pylint: disable=g-import-not-at-top
 from tensorflow.python.platform import app  # pylint: disable=g-import-not-at-top
 app.flags = flags

View File

@@ -302,6 +302,7 @@ tf_cuda_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/common_runtime/eager:attr_builder",
+        "//tensorflow/core/common_runtime/eager:context",
        "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
         "//tensorflow/core/platform",
         "@com_google_absl//absl/strings",

View File

@@ -458,7 +458,7 @@ static void TF_Run_Helper(
           EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
       continue;
     }
-    c_outputs[i] = TF_TensorFromTensor(src, status);
+    c_outputs[i] = TF_TensorFromTensor(src, &status->status);
     if (!status->status.ok()) return;
   }
 }
@@ -1493,7 +1493,7 @@ void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
   Tensor t;
   status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
   if (!status->status.ok()) return;
-  *value = TF_TensorFromTensor(t, status);
+  *value = TF_TensorFromTensor(t, &status->status);
 }
 
 void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
@@ -1504,7 +1504,7 @@ void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
   if (!status->status.ok()) return;
   const auto len = std::min(max_values, static_cast<int>(ts.size()));
   for (int i = 0; i < len; ++i) {
-    values[i] = TF_TensorFromTensor(ts[i], status);
+    values[i] = TF_TensorFromTensor(ts[i], &status->status);
   }
 }
@@ -2398,7 +2398,7 @@ unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output,
       graph->graph.versions().producer(), &evaluated, &result_tensor);
   if (evaluated) {
     DCHECK(status->status.ok());
-    *result = TF_TensorFromTensor(result_tensor, status);
+    *result = TF_TensorFromTensor(result_tensor, &status->status);
     if (!status->status.ok()) evaluated = false;
   }
   return evaluated;

View File

@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/c/eager/c_api_internal.h"
 #include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
+#include "tensorflow/core/common_runtime/eager/context.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/shape_inference.h"
@@ -549,7 +550,7 @@ TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(TFE_Op* op,
                                                     TF_Status* status) {
   TFE_ExecuteOpNotification* n = new TFE_ExecuteOpNotification;
 
-  n->thread.reset(op->operation.EagerContext()->TFEnv()->StartThread(
+  n->thread.reset(op->operation.EagerContext().TFEnv()->StartThread(
       tensorflow::ThreadOptions(), "ExecuteOpThread",
       [op, retvals, num_retvals, n]() {
        TFE_Execute(op, retvals, num_retvals, n->status.get());
@@ -634,7 +635,7 @@ TF_Tensor* TF_CheckpointReaderGetTensor(TF_CheckpointReader* reader,
   std::unique_ptr<tensorflow::Tensor> tensor;
   reader->GetTensor(name, &tensor, status);
   if (!status->status.ok()) return nullptr;
-  return tensorflow::TF_TensorFromTensor(*tensor, status);
+  return tensorflow::TF_TensorFromTensor(*tensor, &status->status);
 }
 
 void TF_CheckpointReaderGetVariableShape(TF_CheckpointReader* reader,
@@ -767,8 +768,9 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
   } while (0);
 
   // New server created for new server_def. Unused if updating server_def.
+  tensorflow::EagerContext* context = ctx->context;
   tensorflow::GrpcServer* grpc_server =
-      dynamic_cast<tensorflow::GrpcServer*>(ctx->context->GetServer());
+      dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
   if (grpc_server == nullptr) {
     std::unique_ptr<tensorflow::ServerInterface> new_server;
     LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
@@ -779,12 +781,12 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
     }
     LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
 
-    LOG_AND_RETURN_IF_ERROR(ctx->context->StoreCollectiveOpsServer(
+    LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
         std::move(new_server), grpc_server->worker_env()->device_mgr,
         grpc_server->worker_env()->collective_executor_mgr));
   } else {
     LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
-    LOG_AND_RETURN_IF_ERROR(ctx->context->StoreCollectiveOpsServer(
+    LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
         /*new_server=*/nullptr, grpc_server->worker_env()->device_mgr,
         grpc_server->worker_env()->collective_executor_mgr));
   }

View File

@@ -1260,11 +1260,10 @@ TEST_F(CApiFunctionTest, GraphToFunctionDefWithPlaceholderAttr) {
   NodeWithPlaceholderAttrHelper(func_graph.get(), s.get(), "node3", "v2",
                                 &node3);
 
-  TF_Output inputs[] = {};
   TF_Output outputs[] = {{node1, 0}, {node2, 0}, {node3, 0}};
   func_ = TF_GraphToFunction(
       func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
-      /*opers=*/nullptr, 0, inputs, 3, outputs,
+      /*opers=*/nullptr, 0, nullptr, 3, outputs,
       /*output_names=*/nullptr,
       /*opts=*/nullptr, /*description=*/nullptr, s.get());
   ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
@@ -1300,10 +1299,9 @@ TEST_F(CApiFunctionTest, GraphToFunctionDefWithArgAttr) {
                                 &node);
 
   TF_Output inputs[] = {{node, 0}};
-  TF_Output outputs[] = {};
   func_ = TF_GraphToFunction(
       func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
-      /*opers=*/nullptr, 1, inputs, 0, outputs,
+      /*opers=*/nullptr, 1, inputs, 0, nullptr,
       /*output_names=*/nullptr,
       /*opts=*/nullptr, /*description=*/nullptr, s.get());
   ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
@@ -1603,11 +1601,10 @@ void DefineStatefulFunction(const char* name, TF_Function** func) {
   TF_Operation* random =
       RandomUniform(shape, TF_FLOAT, func_graph.get(), s.get());
 
-  TF_Output inputs[] = {};
   TF_Output outputs[] = {{random, 0}};
   *func = TF_GraphToFunction(func_graph.get(), name,
                              /*append_hash_to_fn_name=*/false, -1,
                             /*opers=*/nullptr, 0, nullptr, 1, outputs,
                              /*output_names=*/nullptr,
                              /*opts=*/nullptr, "", s.get());
   ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());

View File

@@ -188,7 +188,7 @@ namespace tensorflow {
 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
 
-TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
+TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
 
 Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
                        TF_Buffer* out);
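The hunk above switches the internal converter to report errors through a tensorflow::Status out-parameter rather than a TF_Status*. A minimal sketch of the resulting call pattern, assuming the usual TF_Status wrapper whose `status` member holds a tensorflow::Status (as used throughout the c_api.cc hunks earlier); the helper name ToCTensorOrNull is illustrative and not part of this change:

#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/framework/tensor.h"

// Sketch: convert a tensorflow::Tensor to a TF_Tensor with the new
// Status-based signature, surfacing failures through the caller's TF_Status.
static TF_Tensor* ToCTensorOrNull(const tensorflow::Tensor& src,
                                  TF_Status* status) {
  TF_Tensor* out = tensorflow::TF_TensorFromTensor(src, &status->status);
  if (!status->status.ok()) return nullptr;  // e.g. unsupported dtype
  return out;
}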

View File

@@ -51,7 +51,7 @@ limitations under the License.
 #include "tensorflow/core/util/equal_graph_def.h"
 
 namespace tensorflow {
-TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
+TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
 
 namespace {
@@ -227,7 +227,7 @@ TEST(CAPI, LibraryLoadFunctions) {
 
 void TestEncodeDecode(int line, const std::vector<string>& data) {
   const tensorflow::int64 n = data.size();
-  TF_Status* status = TF_NewStatus();
+  Status status;
   for (const std::vector<tensorflow::int64>& dims :
        std::vector<std::vector<tensorflow::int64>>{
            {n}, {1, n}, {n, 1}, {n / 2, 2}}) {
@@ -236,8 +236,8 @@ void TestEncodeDecode(int line, const std::vector<string>& data) {
     for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) {
       src.flat<tstring>()(i) = data[i];
     }
-    TF_Tensor* dst = TF_TensorFromTensor(src, status);
-    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    TF_Tensor* dst = TF_TensorFromTensor(src, &status);
+    ASSERT_TRUE(status.ok()) << status.error_message();
 
     // Convert back to a C++ Tensor and ensure we get expected output.
     Tensor output;
@@ -249,7 +249,6 @@ void TestEncodeDecode(int line, const std::vector<string>& data) {
     TF_DeleteTensor(dst);
   }
-  TF_DeleteStatus(status);
 }
 
 TEST(CAPI, TensorEncodeDecodeStrings) {
@@ -1394,8 +1393,9 @@ TEST(CAPI, SavedModel) {
     TF_Operation* input_op =
         TF_GraphOperationByName(graph, input_op_name.c_str());
     ASSERT_TRUE(input_op != nullptr);
-    csession.SetInputs({{input_op, TF_TensorFromTensor(input, s)}});
-    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+    Status status;
+    csession.SetInputs({{input_op, TF_TensorFromTensor(input, &status)}});
+    ASSERT_TRUE(status.ok()) << status.error_message();
 
     const tensorflow::string output_op_name(
         tensorflow::ParseTensorName(output_name).first);
@@ -2522,12 +2522,11 @@ TEST(CAPI, TestTensorIsNotAligned) {
 
   // Take an unaligned slice.
   Tensor y = x.Slice(1, 13);
-  TF_Status* status = TF_NewStatus();
-  TF_Tensor* a = TF_TensorFromTensor(y, status);
+  Status status;
+  TF_Tensor* a = TF_TensorFromTensor(y, &status);
   if (EIGEN_MAX_ALIGN_BYTES > 0) {
     EXPECT_FALSE(TF_TensorIsAligned(a));
   }
-  TF_DeleteStatus(status);
   TF_DeleteTensor(a);
 }

View File

@@ -17,7 +17,7 @@ limitations under the License.
 #include <memory.h>
 #include <stdio.h>
 #include <stdlib.h>
-#include <sys/time.h>
+#include <time.h>
 #include <unistd.h>
 
 #include "tensorflow/c/c_api.h"
@@ -58,12 +58,8 @@ int main(int argc, char** argv) {
   }
 
   char file_name[100];
-  struct timeval t;
-  if (gettimeofday(&t, NULL)) {
-    perror("gettimeofday failed");
-    return 1;
-  }
-  snprintf(file_name, sizeof(file_name), "test-%d-%ld.txt", getpid(), t.tv_sec);
+  time_t t = time(NULL);
+  snprintf(file_name, sizeof(file_name), "test-%d-%ld.txt", getpid(), t);
 
   size_t length = 2 + strlen(path) + strlen(file_name);
   char* full_path = malloc(length);

View File

@@ -26,8 +26,8 @@ tf_cuda_library(
         "c_api.cc",
         "c_api_debug.cc",
         "c_api_experimental.h",
-        "c_api_internal.cc",
         "c_api_internal.h",
+        "tensor_handle_interface.h",
     ],
     hdrs = ["c_api.h"],
     copts = tf_copts() + tfe_xla_copts(),
@@ -93,6 +93,7 @@ filegroup(
     srcs = [
         "c_api_experimental.h",
         "c_api_internal.h",
+        "tensor_handle_interface.h",
     ],
     visibility = [
         "//tensorflow/core:__pkg__",
@@ -102,7 +103,10 @@ filegroup(
 
 tf_cuda_library(
     name = "c_api_internal",
-    srcs = ["c_api_experimental.h"],
+    srcs = [
+        "c_api_experimental.h",
+        "tensor_handle_interface.h",
+    ],
     hdrs = ["c_api_internal.h"],
     visibility = [
         "//learning/deepmind/courier:__subpackages__",

View File

@@ -31,6 +31,7 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "tensorflow/c/c_api.h"
 #include "tensorflow/c/c_api_internal.h"
+#include "tensorflow/c/eager/tensor_handle_interface.h"
 #include "tensorflow/c/tf_tensor_internal.h"
 #include "tensorflow/c/eager/c_api_experimental.h"
 #include "tensorflow/c/eager/c_api_internal.h"
@@ -81,6 +82,7 @@ limitations under the License.
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/casts.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/thread_annotations.h"
@@ -93,10 +95,8 @@ using tensorflow::string;
 namespace {
 
 const tensorflow::OpDef* GetOpDef(TFE_Op* op, TF_Status* status) {
-  if (op->inference_ctx) {
-    return op->inference_ctx->op_def;
-  }
-  const tensorflow::OpDef* op_def;
+  const tensorflow::OpDef* op_def = op->operation.OpDef();
+  if (op_def) return op_def;
   status->status =
       tensorflow::OpDefForOp(op->operation.Name().c_str(), &op_def);
   return op_def;
@@ -409,6 +409,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
   // New server created for new server_def. Unused if updating server_def.
   std::unique_ptr<tensorflow::ServerInterface> new_server;
+  tensorflow::EagerContext* context = ctx->context;
   tensorflow::GrpcServer* grpc_server;
   if (reset_context) {
     LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
@@ -416,26 +417,25 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
     LOG_AND_RETURN_IF_ERROR(
         ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
   } else {
-    LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(
-        ctx->context->GetServer(), worker_name, &curr_remote_workers));
+    LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name,
+                                              &curr_remote_workers));
     // No need to check the cast here, since `ListRemoteWorkers` already checks
     // if the server is a GRPC server or not.
-    grpc_server =
-        dynamic_cast<tensorflow::GrpcServer*>(ctx->context->GetServer());
+    grpc_server = dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
     LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
     LOG_AND_RETURN_IF_ERROR(
         ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
   }
 
-  tensorflow::uint64 context_id = ctx->context->GetContextId();
-  tensorflow::uint64 context_view_id = ctx->context->GetContextViewId();
+  tensorflow::uint64 context_id = context->GetContextId();
+  tensorflow::uint64 context_view_id = context->GetContextViewId();
   if (reset_context) {
     context_id = tensorflow::EagerContext::NewContextId();
     context_view_id = 0;
     // Make master eager context accessible by local eager service, which might
     // receive send tensor requests from remote workers.
-    LOG_AND_RETURN_IF_ERROR(grpc_server->AddMasterEagerContextToEagerService(
-        context_id, ctx->context));
+    LOG_AND_RETURN_IF_ERROR(
+        grpc_server->AddMasterEagerContextToEagerService(context_id, context));
   }
 
   std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
@@ -464,11 +464,11 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
         &new_remote_device_mgr));
     remote_device_mgr = new_remote_device_mgr.get();
   } else {
-    ctx->context->ClearCachesAndDefaultExecutor();
+    context->ClearCachesAndDefaultExecutor();
     // TODO(b/143914772): Potential memory leak if rendezvous has pending
     // tensors for removed / replaced workers.
-    remote_device_mgr = ctx->context->GetOwnedRemoteDeviceMgr();
+    remote_device_mgr = context->GetOwnedRemoteDeviceMgr();
     if (remote_device_mgr == nullptr) {
       LOG_AND_RETURN_IF_ERROR(tensorflow::errors::InvalidArgument(
           "Updating context with an invalid set of remote devices."));
@@ -479,8 +479,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
                            &added_workers, &removed_workers,
                            &existing_workers);
     LOG_AND_RETURN_IF_ERROR(GetReplacedFromExistingWorkers(
-        &existing_workers, context_id, ctx->context->GetContextViewId(),
-        server_def, remote_eager_workers.get(), &replaced_workers));
+        &existing_workers, context_id, context->GetContextViewId(), server_def,
+        remote_eager_workers.get(), &replaced_workers));
     if (VLOG_IS_ON(1)) {
       VLOG(1) << "Updating cluster with following changes";
       for (const string& w : added_workers) VLOG(1) << "  Added worker " << w;
@@ -516,7 +516,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
   grpc_server->worker_env()->device_mgr->ListDeviceAttributes(
       &local_device_attributes);
 
-  // This request make sure that we can create Rendevzous properly between
+  // This request make sure that we can create Rendezvous properly between
   // Local and Remote context.
   tensorflow::eager::CreateContextRequest base_request;
   for (const auto& da : cluster_device_attributes) {
@@ -534,9 +534,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
   if (reset_context) {
     LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
         remote_workers, context_id, context_view_id, keep_alive_secs,
-        server_def, remote_eager_workers.get(),
-        ctx->context->Executor().Async(),
-        ctx->context->LazyCopyFunctionRemoteInputs(), base_request));
+        server_def, remote_eager_workers.get(), context->Executor().Async(),
+        context->LazyCopyFunctionRemoteInputs(), base_request));
   } else {
     // The master's context_view_id will be incremented by one
     // the UpdateRemoteMaster call later. We want all new workers and
@@ -545,9 +544,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
     // context_view_id + 1.
     LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
         added_workers, context_id, context_view_id + 1, keep_alive_secs,
-        server_def, remote_eager_workers.get(),
-        ctx->context->Executor().Async(),
-        ctx->context->LazyCopyFunctionRemoteInputs(), base_request));
+        server_def, remote_eager_workers.get(), context->Executor().Async(),
+        context->LazyCopyFunctionRemoteInputs(), base_request));
     if (!existing_workers.empty()) {
       if (VLOG_IS_ON(1)) {
         for (const string& w : existing_workers) {
@@ -578,12 +576,12 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
     TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
 
     tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
-        tensorflow::eager::CreateClusterFLR(context_id, ctx->context,
+        tensorflow::eager::CreateClusterFLR(context_id, context,
                                             worker_session.get());
     auto remote_mgr = absl::make_unique<tensorflow::eager::RemoteMgr>(
-        /*is_master=*/true, ctx->context);
+        /*is_master=*/true, context);
 
-    LOG_AND_RETURN_IF_ERROR(ctx->context->InitializeRemoteMaster(
+    LOG_AND_RETURN_IF_ERROR(context->InitializeRemoteMaster(
         std::move(new_server), grpc_server->worker_env(), worker_session,
         std::move(remote_eager_workers), std::move(new_remote_device_mgr),
         remote_workers, context_id, r, device_mgr, keep_alive_secs, cluster_flr,
@@ -601,9 +599,9 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
         grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
             session_name, &worker_session));
     tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
-        tensorflow::eager::CreateClusterFLR(context_id, ctx->context,
+        tensorflow::eager::CreateClusterFLR(context_id, context,
                                             worker_session.get());
-    LOG_AND_RETURN_IF_ERROR(ctx->context->UpdateRemoteMaster(
+    LOG_AND_RETURN_IF_ERROR(context->UpdateRemoteMaster(
         grpc_server->worker_env(), std::move(remote_eager_workers),
         added_workers, removed_workers, context_id, r, device_mgr,
         keep_alive_secs, cluster_flr));
@@ -614,77 +612,6 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
 }
 #endif  // !IS_MOBILE_PLATFORM
 
-tensorflow::Status OpInferSingleInputAttrs(TFE_Op* op,
-                                           TFE_TensorHandle* input) {
-  TFE_OpInferenceContext* ictx = op->inference_ctx.get();
-  const auto& input_def = ictx->op_def->input_arg(ictx->input_arg_idx++);
-  if (!input_def.number_attr().empty() || !input_def.type_list_attr().empty()) {
-    // Some clients that are still setting their input attributes manually are
-    // adding input list to their op by calling `TFE_OpAddInput` for each of
-    // its elements instead of calling `TFE_OpAddInputList`. When this happens,
-    // we cannot detect the end of such list, thus lose track of the input
-    // arguments in the op definition. To guarantee backward compatibility with
-    // those clients, disable automatic inference in this case.
-    op->inference_ctx.reset(nullptr);
-    return tensorflow::Status::OK();
-  }
-  const std::string& type_attr = input_def.type_attr();
-  if (!type_attr.empty() && ictx->attrs.find(type_attr) == ictx->attrs.end()) {
-    op->operation.MutableAttrs()->Set(type_attr, input->handle->dtype);
-    ictx->attrs.insert(type_attr);
-  }
-  return tensorflow::Status::OK();
-}
-
-void OpInferSingleTypeInputListAttrs(TFE_Op* op,
-                                     const tensorflow::OpDef::ArgDef& input_def,
-                                     TFE_TensorHandle** inputs,
-                                     int num_inputs) {
-  TFE_OpInferenceContext* ictx = op->inference_ctx.get();
-  if (ictx->attrs.find(input_def.number_attr()) == ictx->attrs.end()) {
-    op->operation.MutableAttrs()->Set(input_def.number_attr(), num_inputs);
-    ictx->attrs.insert(input_def.number_attr());
-  }
-  if (ictx->attrs.find(input_def.type_attr()) == ictx->attrs.end()) {
-    op->operation.MutableAttrs()->Set(input_def.type_attr(),
-                                      inputs[0]->handle->dtype);
-    ictx->attrs.insert(input_def.type_attr());
-  }
-}
-
-void OpInferMixedTypeInputListAttrs(TFE_Op* op,
-                                    const tensorflow::OpDef::ArgDef& input_def,
-                                    TFE_TensorHandle** inputs, int num_inputs) {
-  TFE_OpInferenceContext* ictx = op->inference_ctx.get();
-  if (ictx->attrs.find(input_def.type_list_attr()) == ictx->attrs.end()) {
-    std::unique_ptr<tensorflow::DataType[]> dtypes(
-        new tensorflow::DataType[num_inputs]);
-    for (int i = 0; i < num_inputs; ++i) {
-      dtypes[i] = inputs[i]->handle->dtype;
-    }
-    op->operation.MutableAttrs()->Set(
-        input_def.type_list_attr(),
-        tensorflow::gtl::ArraySlice<const tensorflow::DataType>(dtypes.get(),
-                                                                num_inputs));
-    ictx->attrs.insert(input_def.type_list_attr());
-  }
-}
-
-tensorflow::Status OpInferInputListAttrs(TFE_Op* op, TFE_TensorHandle** inputs,
-                                         int num_inputs) {
-  TFE_OpInferenceContext* ictx = op->inference_ctx.get();
-  const auto& input_def = ictx->op_def->input_arg(ictx->input_arg_idx++);
-  if (!input_def.type_list_attr().empty()) {
-    OpInferMixedTypeInputListAttrs(op, input_def, inputs, num_inputs);
-  } else if (!input_def.type_attr().empty() &&
-             !input_def.number_attr().empty()) {
-    OpInferSingleTypeInputListAttrs(op, input_def, inputs, num_inputs);
-  } else {
-    return tensorflow::errors::InvalidArgument("Invalid input list definition");
-  }
-  return tensorflow::Status::OK();
-}
-
 }  // namespace
 
 extern "C" {
@@ -720,12 +647,14 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
   tensorflow::Rendezvous* r =
       new tensorflow::IntraProcessRendezvous(device_mgr.get());
 
-  return new TFE_Context(opts->session_options.options,
-                         opts->device_placement_policy, opts->mirroring_policy,
-                         opts->async, opts->lazy_remote_inputs_copy,
-                         device_mgr.release(),
-                         /*device_mgr_owned*/ true, r,
-                         tensorflow::GetDefaultCustomKernelCreator());
+  return new TFE_Context{new tensorflow::EagerContext(
+      opts->session_options.options,
+      static_cast<tensorflow::ContextDevicePlacementPolicy>(
+          opts->device_placement_policy),
+      static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
+      opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
+      /*device_mgr_owned*/ true, r,
+      tensorflow::GetDefaultCustomKernelCreator())};
 }
 
 TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
@@ -736,22 +665,28 @@ TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
   tensorflow::Rendezvous* r =
       new tensorflow::IntraProcessRendezvous(device_mgr);
 
-  return new TFE_Context(opts->session_options.options,
-                         opts->device_placement_policy, opts->mirroring_policy,
-                         opts->async, opts->lazy_remote_inputs_copy, device_mgr,
-                         /*device_mgr_owned*/ false, r,
-                         tensorflow::GetDefaultCustomKernelCreator());
+  return new TFE_Context{new tensorflow::EagerContext(
+      opts->session_options.options,
+      static_cast<tensorflow::ContextDevicePlacementPolicy>(
+          opts->device_placement_policy),
+      static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
+      opts->async, opts->lazy_remote_inputs_copy, device_mgr,
+      /*device_mgr_owned*/ false, r,
+      tensorflow::GetDefaultCustomKernelCreator())};
 }
 
-void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; }
+void TFE_DeleteContext(TFE_Context* ctx) {
+  // context->RefCountIsOne() should be true here.
+  // TODO(iga): Remove EagerContext refcounting.
+  ctx->context->Unref();
+  delete ctx;
+}
 
 TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
-  TF_DeviceList* list = new TF_DeviceList;
-  ctx->context->local_device_mgr()->ListDeviceAttributes(&list->response);
-  if (ctx->context->remote_device_mgr()) {
-    ctx->context->remote_device_mgr()->ListDeviceAttributes(&list->response);
-  }
-  return list;
+  TF_DeviceList* l = new TF_DeviceList;
+  ctx->context->ListDevices(&l->response);
+  return l;
 }
 
 void TFE_ContextClearCaches(TFE_Context* ctx) {
@@ -812,8 +747,9 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
       "TFE_ContextSetServerDef not supported on mobile");
   return false;
 #else   // !defined(IS_MOBILE_PLATFORM)
+  tensorflow::EagerContext* context = ctx->context;
   tensorflow::GrpcServer* grpc_server =
-      static_cast<tensorflow::GrpcServer*>(ctx->context->GetServer());
+      static_cast<tensorflow::GrpcServer*>(context->GetServer());
 
   std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
   status->status = grpc_server->master_env()->worker_cache->GetEagerClientCache(
@@ -832,7 +768,7 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
 
   // Send a rpc request to the worker to check aliveness.
   tensorflow::eager::KeepAliveRequest request;
-  request.set_context_id(ctx->context->GetContextId());
+  request.set_context_id(context->GetContextId());
   tensorflow::eager::KeepAliveResponse response;
 
   tensorflow::Status keep_alive_status;
@@ -887,108 +823,180 @@ void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
   if (h == nullptr) return;
   tensorflow::profiler::TraceMe activity(
       "TFE_DeleteTensorHandle", tensorflow::profiler::TraceMeLevel::kInfo);
-  VLOG(1) << "Deleting tensor handle " << h << " with internal handle "
-          << h->handle;
-  if (h->handle) {
-    h->handle->Unref();
-  }
   delete h;
 }
 
+tensorflow::TensorHandleInterface::~TensorHandleInterface() {
+  VLOG(1) << "Deleting tensor handle " << this << " with internal handle "
+          << handle_;
+  if (handle_) {
+    handle_->Unref();
+  }
+}
+
+bool tensorflow::TensorHandleInterface::IsValid(Status* status) const {
+  if (handle_ == nullptr) {
+    *status = tensorflow::errors::InvalidArgument(
+        "The passed in handle is a nullptr");
+    return false;
+  }
+  return true;
+}
+
 TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
-  return static_cast<TF_DataType>(h->handle->dtype);
+  return h->handle->DataType();
+}
+
+TF_DataType tensorflow::TensorHandleInterface::DataType() const {
+  return static_cast<TF_DataType>(handle_->dtype);
 }
 
 int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
-  if (h == nullptr || h->handle == nullptr) {
+  if (h == nullptr) {
     status->status = tensorflow::errors::InvalidArgument(
         "The passed in handle is a nullptr");
     return -1;
   }
+  return h->handle->NumDims(&status->status);
+}
+
+int tensorflow::TensorHandleInterface::NumDims(Status* status) const {
+  if (!IsValid(status)) {
+    return -1;
+  }
   int result;
-  status->status = h->handle->NumDims(&result);
+  *status = handle_->NumDims(&result);
   return result;
 }
 
 int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
-  if (h == nullptr || h->handle == nullptr) {
+  if (h == nullptr) {
     status->status = tensorflow::errors::InvalidArgument(
         "The passed in handle is a nullptr");
     return -1;
   }
+  return h->handle->NumElements(&status->status);
+}
+
+int64_t tensorflow::TensorHandleInterface::NumElements(Status* status) const {
+  if (!IsValid(status)) {
+    return -1;
+  }
   tensorflow::int64 result;
-  status->status = h->handle->NumElements(&result);
+  *status = handle_->NumElements(&result);
   return result;
 }
 
 int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
                             TF_Status* status) {
-  if (h == nullptr || h->handle == nullptr) {
+  if (h == nullptr) {
     status->status = tensorflow::errors::InvalidArgument(
         "The passed in handle is a nullptr");
     return -1;
   }
+  return h->handle->Dim(dim_index, &status->status);
+}
+
+int64_t tensorflow::TensorHandleInterface::Dim(int dim_index,
+                                               Status* status) const {
+  if (!IsValid(status)) {
+    return -1;
+  }
   tensorflow::int64 result;
-  status->status = h->handle->Dim(dim_index, &result);
+  *status = handle_->Dim(dim_index, &result);
   return result;
 }
 
 const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
-  if (h == nullptr || h->handle == nullptr) {
+  if (h == nullptr) {
     status->status = tensorflow::errors::InvalidArgument(
         "The passed in handle is a nullptr");
     return nullptr;
   }
-  tensorflow::Device* d = h->handle->op_device();
+  return h->handle->DeviceName(&status->status);
+}
+
+const char* tensorflow::TensorHandleInterface::DeviceName(
+    Status* status) const {
+  if (!IsValid(status)) {
+    return nullptr;
+  }
+  tensorflow::Device* d = handle_->op_device();
   return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
                         : d->name().c_str();
 }
 
 const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
                                               TF_Status* status) {
-  if (h == nullptr || h->handle == nullptr) {
+  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
         "The passed in handle is a nullptr");
     return nullptr;
   }
-  tensorflow::Device* d = h->handle->device();
+  return h->handle->BackingDeviceName(&status->status);
+}
+
+const char* tensorflow::TensorHandleInterface::BackingDeviceName(
+    Status* status) const {
+  if (!IsValid(status)) {
+    return nullptr;
+  }
+  tensorflow::Device* d = handle_->device();
   return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
                         : d->name().c_str();
 }
 
 TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
     TFE_TensorHandle* h, TF_Status* status) {
-  if (h == nullptr || h->handle == nullptr) {
+  if (h == nullptr || !h->handle->IsValid(&status->status)) {
     status->status = tensorflow::errors::InvalidArgument(
         "The passed in handle is a nullptr");
     return nullptr;
   }
 
-  h->handle->Ref();
-
-  return new TFE_TensorHandle(h->handle);
+  return new TFE_TensorHandle{
+      std::unique_ptr<AbstractTensorHandleInterface>(h->handle->Copy())};
+}
+
+AbstractTensorHandleInterface* tensorflow::TensorHandleInterface::Copy() {
+  handle_->Ref();
+  return new TensorHandleInterface(handle_);
 }
 
 TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
-  if (h == nullptr || h->handle == nullptr) {
+  if (h == nullptr) {
     status->status = tensorflow::errors::InvalidArgument(
         "The passed in handle is a nullptr");
     return nullptr;
   }
-  tensorflow::TensorHandle* handle = h->handle;
 
+  return h->handle->Resolve(&status->status);
+}
+
+TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
+  if (!IsValid(status)) {
+    return nullptr;
+  }
+
   // TODO(agarwal): move this implementation inside TFE_TensorHandle.
-  if (handle->IsRemote()) {
+  if (handle_->IsRemote()) {
     const tensorflow::Tensor* t = nullptr;
     tensorflow::TensorHandle* h_cpu = nullptr;
-    status->status = EagerCopyToDevice(
-        handle, handle->Context(), &handle->Context()->Executor(),
-        handle->Context()->HostCPU(), false, &h_cpu);
-    if (!status->status.ok()) {
+    *status = EagerCopyToDevice(handle_, handle_->Context(),
+                                &handle_->Context()->Executor(),
+                                handle_->Context()->HostCPU(), false, &h_cpu);
+    if (!status->ok()) {
      return nullptr;
     }
-    status->status = h_cpu->Tensor(&t);
-    if (!status->status.ok()) {
+    *status = h_cpu->Tensor(&t);
+    if (!status->ok()) {
       h_cpu->Unref();
       return nullptr;
     }
@@ -997,28 +1005,30 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
     return retval;
   } else {
     tensorflow::Tensor tensor;
-    if (IsCPU(handle->device())) {
+    if (IsCPU(handle_->device())) {
       const tensorflow::Tensor* src = nullptr;
-      status->status = handle->Tensor(&src);
-      if (!status->status.ok()) return nullptr;
+      *status = handle_->Tensor(&src);
+      if (!status->ok()) return nullptr;
       tensor = *src;
     } else {
-      tensorflow::EagerContext* ctx = handle->Context();
+      tensorflow::EagerContext* ctx = handle_->Context();
       CHECK_NE(ctx, nullptr);
-      status->status = h->handle->CopyToDevice(ctx, ctx->HostCPU(), &tensor);
-      if (!status->status.ok()) return nullptr;
+      *status = handle_->CopyToDevice(*ctx, ctx->HostCPU(), &tensor);
+      if (!status->ok()) return nullptr;
     }
     return tensorflow::TF_TensorFromTensor(tensor, status);
   }
 }
 
 void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
-  if (h == nullptr || h->handle == nullptr) {
+  if (h == nullptr || !h->handle->IsValid(&status->status)) {
     status->status = tensorflow::errors::InvalidArgument(
         "The passed in handle is a nullptr");
     return nullptr;
   }
-  tensorflow::TensorHandle* handle = h->handle;
+  tensorflow::TensorHandle* handle =
+      tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
+          ->Handle();
 
   if (handle->IsRemote()) {
     status->status = tensorflow::errors::InvalidArgument(
@@ -1047,7 +1057,8 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
     void (*deallocator)(void* data, size_t len, void* arg),
     void* deallocator_arg, TF_Status* status) {
   tensorflow::Device* device;
-  status->status = ctx->context->FindDeviceFromName(device_name, &device);
+  tensorflow::EagerContext* context = ctx->context;
+  status->status = context->FindDeviceFromName(device_name, &device);
   if (!status->status.ok()) {
     deallocator(data, len, deallocator_arg);
     return nullptr;
@@ -1075,11 +1086,12 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
   buf->Unref();
   tensorflow::TensorHandle* ret_handle;
   status->status = tensorflow::TensorHandle::CreateLocalHandle(
-      t, device, ctx->context, &ret_handle);
+      t, device, context, &ret_handle);
   if (!status->status.ok()) {
     return nullptr;
   }
-  return new TFE_TensorHandle(ret_handle);
+  return new TFE_TensorHandle{
+      std::make_unique<tensorflow::TensorHandleInterface>(ret_handle)};
 }
 
 // This function will block till the operation that produces `h` has
@@ -1087,12 +1099,14 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
 // bytes of the memory pointed to by the device pointer returned above.
 size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
                                         TF_Status* status) {
-  if (h == nullptr || h->handle == nullptr) {
+  if (h == nullptr || !h->handle->IsValid(&status->status)) {
     status->status = tensorflow::errors::InvalidArgument(
         "The passed in handle is a nullptr");
     return 0;
   }
-  tensorflow::TensorHandle* handle = h->handle;
+  tensorflow::TensorHandle* handle =
+      tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
+          ->Handle();
 
   if (handle->IsRemote()) {
     status->status = tensorflow::errors::InvalidArgument(
@@ -1110,8 +1124,14 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
 
 TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
                   TF_Status* status) {
-  return NewOrResetOp(ctx, op_or_function_name, nullptr, status,
-                      /* op_to_reset= */ nullptr);
+  std::unique_ptr<TFE_Op> new_op(
+      new TFE_Op{tensorflow::EagerOperation(ctx->context)});
+  status->status =
+      new_op->operation.Reset(op_or_function_name, nullptr, false, nullptr);
+  if (!status->status.ok()) {
+    new_op.reset();
+  }
+  return new_op.release();
 }
 
 void TFE_DeleteOp(TFE_Op* op) { delete op; }
@@ -1122,7 +1142,7 @@ void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
 
 const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
   tensorflow::Device* device = (op->operation.Device() == nullptr)
-                                   ? op->operation.EagerContext()->HostCPU()
+                                   ? op->operation.EagerContext().HostCPU()
                                    : op->operation.Device();
   return device->name().c_str();
 }
@@ -1136,20 +1156,23 @@ void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
 }
 
 void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
-  op->operation.AddInput(input->handle);
-  if (op->inference_ctx) {
-    status->status = OpInferSingleInputAttrs(op, input);
-  }
+  tensorflow::TensorHandle* h =
+      tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
+          input->handle.get())
+          ->Handle();
+  op->operation.AddInput(h);
+  status->status = op->operation.MaybeInferSingleInputAttrs(h);
 }
 
 void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
                         TF_Status* status) {
   for (int i = 0; i < num_inputs; ++i) {
-    op->operation.AddInput(inputs[i]->handle);
-  }
-  if (op->inference_ctx) {
-    status->status = OpInferInputListAttrs(op, inputs, num_inputs);
+    op->operation.AddInput(
+        tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
+            inputs[i]->handle.get())
+            ->Handle());
   }
+  status->status = op->operation.InferInputListAttrs(num_inputs);
 }
 
 TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
@ -1382,15 +1405,16 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) { TF_Status* status) {
VLOG(1) << "Calling TFE_Execute() on op " << op;
absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals); absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
VLOG(1) << "Calling TFE_Execute() on op " << op;
status->status = tensorflow::EagerExecute(&op->operation, status->status = tensorflow::EagerExecute(&op->operation,
handle_retvals.data(), num_retvals); handle_retvals.data(), num_retvals);
if (!status->status.ok()) { if (!status->status.ok()) {
return; return;
} }
for (int i = 0; i < *num_retvals; ++i) { for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = new TFE_TensorHandle(handle_retvals[i]); retvals[i] = new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle_retvals[i])};
} }
} }
@ -1400,15 +1424,18 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
TF_Status* status) { TF_Status* status) {
tensorflow::TensorHandle* handle = nullptr; tensorflow::TensorHandle* handle = nullptr;
tensorflow::Device* device; tensorflow::Device* device;
status->status = ctx->context->FindDeviceFromName(device_name, &device); tensorflow::EagerContext* context = ctx->context;
status->status = context->FindDeviceFromName(device_name, &device);
if (!status->status.ok()) { if (!status->status.ok()) {
return nullptr; return nullptr;
} }
status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context, status->status = tensorflow::EagerCopyToDevice(
&ctx->context->Executor(), tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
device, false, &handle); ->Handle(),
context, &context->Executor(), device, false, &handle);
if (status->status.ok()) { if (status->status.ok()) {
return new TFE_TensorHandle(handle); return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
} }
return nullptr; return nullptr;
} }
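For completeness, a short caller-side sketch of the copy path above via the public C API; the device name is a placeholder, and ctx, h and status are assumed to exist.

// Copy a handle to another device; the target name below is illustrative only.
TFE_TensorHandle* on_gpu = TFE_TensorHandleCopyToDevice(
    h, ctx, "/job:localhost/replica:0/task:0/device:GPU:0", status);
if (TF_GetCode(status) != TF_OK) {
  // Copy failed (e.g. no such device); on_gpu is nullptr in that case.
}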
@ -1456,11 +1483,12 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
TF_Status* status) { TF_Status* status) {
status->status = ctx->context->Executor().WaitForAllPendingNodes(); tensorflow::EagerContext* context = ctx->context;
status->status = context->Executor().WaitForAllPendingNodes();
if (!status->status.ok()) return; if (!status->status.ok()) return;
tensorflow::mutex_lock ml(*ctx->context->MetadataMu()); tensorflow::mutex_lock ml(*context->MetadataMu());
status->status = MessageToBuffer(*ctx->context->RunMetadataProto(), buf); status->status = MessageToBuffer(*context->RunMetadataProto(), buf);
ctx->context->ClearRunMetadata(); context->ClearRunMetadata();
} }
namespace { namespace {
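A small usage sketch of the metadata export shown above; it assumes an existing TFE_Context* ctx and TF_Status* status, and that collection is enabled before the ops whose metadata should be captured.

TFE_ContextEnableRunMetadata(ctx);
// ... execute some ops ...
TF_Buffer* run_metadata = TF_NewBuffer();
TFE_ContextExportRunMetadata(ctx, run_metadata, status);
if (TF_GetCode(status) == TF_OK) {
  // run_metadata->data now holds a serialized RunMetadata proto.
}
TF_DeleteBuffer(run_metadata);
TFE_ContextDisableRunMetadata(ctx);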

View File

@ -206,14 +206,14 @@ typedef struct TFE_TensorDebugInfo TFE_TensorDebugInfo;
// error and nullptr is returned. This function can block till the operation // error and nullptr is returned. This function can block till the operation
// that produces `handle` has completed. // that produces `handle` has completed.
TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
TFE_TensorHandle* handle, TF_Status* status); TFE_TensorHandle* h, TF_Status* status);
// Deletes `debug_info`. // Deletes `debug_info`.
TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo( TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo(
TFE_TensorDebugInfo* debug_info); TFE_TensorDebugInfo* debug_info);
// Returns the number of dimensions used to represent the tensor on its device. // Returns the number of dimensions used to represent the tensor on its device.
// The number of dimensions used to reprensent the tensor on device can be // The number of dimensions used to represent the tensor on device can be
// different from the number returned by TFE_TensorHandleNumDims. // different from the number returned by TFE_TensorHandleNumDims.
// The return value was current at the time of TFE_TensorDebugInfo creation. // The return value was current at the time of TFE_TensorDebugInfo creation.
TF_CAPI_EXPORT extern int TFE_TensorDebugInfoOnDeviceNumDims( TF_CAPI_EXPORT extern int TFE_TensorDebugInfoOnDeviceNumDims(

View File

@ -28,19 +28,22 @@ using tensorflow::string;
namespace { namespace {
std::vector<int64> TensorShapeAsVector(TFE_TensorHandle* handle, std::vector<int64> TensorShapeAsVector(const tensorflow::TensorHandle& handle,
TF_Status* status) { tensorflow::Status* status) {
std::vector<int64> shape; std::vector<int64> shape;
int rank = TFE_TensorHandleNumDims(handle, status); int rank = -1;
if (TF_GetCode(status) != TF_OK) { *status = handle.NumDims(&rank);
if (!status->ok()) {
return shape; return shape;
} }
shape.reserve(rank); shape.reserve(rank);
for (int i = 0; i < rank; ++i) { for (int i = 0; i < rank; ++i) {
shape.push_back(TFE_TensorHandleDim(handle, i, status)); tensorflow::int64 dim;
if (TF_GetCode(status) != TF_OK) { *status = handle.Dim(i, &dim);
if (!status->ok()) {
return shape; return shape;
} }
shape.push_back(dim);
} }
return shape; return shape;
} }
@ -50,15 +53,20 @@ std::vector<int64> TensorShapeAsVector(TFE_TensorHandle* handle,
extern "C" { extern "C" {
TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
TFE_TensorHandle* handle, TF_Status* status) { TFE_TensorHandle* h, TF_Status* status) {
return h->handle->TensorDebugInfo(&status->status);
}
TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
Status* status) {
const tensorflow::Tensor* tensor; const tensorflow::Tensor* tensor;
status->status = handle->handle->Tensor(&tensor); *status = handle_->Tensor(&tensor);
if (TF_GetCode(status) != TF_OK) { if (!status->ok()) {
return nullptr; return nullptr;
} }
#ifdef TENSORFLOW_EAGER_USE_XLA #ifdef TENSORFLOW_EAGER_USE_XLA
tensorflow::Device* device = handle->handle->device(); tensorflow::Device* device = handle_->device();
// If tensor resides on an XLA device, use XLA device's PaddedShapeFn. // If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
tensorflow::XlaDevice* xla_device = tensorflow::XlaDevice* xla_device =
@ -67,15 +75,15 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
tensorflow::XlaDevice::PaddedShapeFn shape_fn = tensorflow::XlaDevice::PaddedShapeFn shape_fn =
xla_device->metadata().padded_shape_fn(); xla_device->metadata().padded_shape_fn();
xla::Shape padded_shape; xla::Shape padded_shape;
status->status = shape_fn(*tensor, &padded_shape); *status = shape_fn(*tensor, &padded_shape);
if (!status->status.ok()) { if (!status->ok()) {
return nullptr; return nullptr;
} }
if (VLOG_IS_ON(3)) { if (VLOG_IS_ON(3)) {
std::vector<int64> shape_to_log = TensorShapeAsVector(handle, status); std::vector<int64> shape_to_log = TensorShapeAsVector(*handle_, status);
if (!status->status.ok()) { if (!status->ok()) {
// Ignore the status here as we are simply logging. // Ignore the status here as we are simply logging.
status->status = tensorflow::Status::OK(); *status = tensorflow::Status::OK();
} else { } else {
VLOG(3) << "Fully padded shape of [" VLOG(3) << "Fully padded shape of ["
<< absl::StrJoin(shape_to_log, ", ") << "] is " << absl::StrJoin(shape_to_log, ", ") << "] is "
@ -88,7 +96,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
// Currently, the only case of XlaTensor containing a tuple shape is to // Currently, the only case of XlaTensor containing a tuple shape is to
// represent 64 bit ints, doubles, and complex numbers (we don't support // represent 64 bit ints, doubles, and complex numbers (we don't support
// 64bit complex numbers). // 64bit complex numbers).
status->status = tensorflow::errors::InvalidArgument( *status = tensorflow::errors::InvalidArgument(
"XlaTensors should only contain tuples of size 2. Shape: ", "XlaTensors should only contain tuples of size 2. Shape: ",
padded_shape.DebugString()); padded_shape.DebugString());
return nullptr; return nullptr;
@ -100,13 +108,13 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
const xla::Shape& shape1 = const xla::Shape& shape1 =
xla::ShapeUtil::GetTupleElementShape(padded_shape, 1); xla::ShapeUtil::GetTupleElementShape(padded_shape, 1);
if (shape0.IsTuple() || shape1.IsTuple()) { if (shape0.IsTuple() || shape1.IsTuple()) {
status->status = tensorflow::errors::InvalidArgument( *status = tensorflow::errors::InvalidArgument(
"XlaTensors should not contain nested tuples. Shape: ", "XlaTensors should not contain nested tuples. Shape: ",
padded_shape.DebugString()); padded_shape.DebugString());
return nullptr; return nullptr;
} }
if (!xla::ShapeUtil::Equal(shape0, shape1)) { if (!xla::ShapeUtil::Equal(shape0, shape1)) {
status->status = tensorflow::errors::InvalidArgument( *status = tensorflow::errors::InvalidArgument(
"Subshapes of XlaTensors should be the same. Shape: ", "Subshapes of XlaTensors should be the same. Shape: ",
padded_shape.DebugString()); padded_shape.DebugString());
return nullptr; return nullptr;
@ -131,15 +139,15 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
dev_dims.push_back(padded_shape.dimensions(dim_index)); dev_dims.push_back(padded_shape.dimensions(dim_index));
} }
} }
status->status = tensorflow::Status::OK(); *status = tensorflow::Status::OK();
return new TFE_TensorDebugInfo(dev_dims); return new TFE_TensorDebugInfo(dev_dims);
} }
#endif // TENSORFLOW_EAGER_USE_XLA #endif // TENSORFLOW_EAGER_USE_XLA
// If the tensor is not an XLA tensor, the device shape is // If the tensor is not an XLA tensor, the device shape is
// the same as regular tensor shape. // the same as regular tensor shape.
std::vector<int64> dev_dims = TensorShapeAsVector(handle, status); std::vector<int64> dev_dims = TensorShapeAsVector(*handle_, status);
if (TF_GetCode(status) != TF_OK) { if (!status->ok()) {
return nullptr; return nullptr;
} }
return new TFE_TensorDebugInfo(dev_dims); return new TFE_TensorDebugInfo(dev_dims);
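For reference, a caller-side sketch of the debug-info API whose implementation is shown above; it assumes an existing TFE_TensorHandle* h and TF_Status* status.

TFE_TensorDebugInfo* debug_info = TFE_TensorHandleTensorDebugInfo(h, status);
if (TF_GetCode(status) == TF_OK) {
  int on_device_rank = TFE_TensorDebugInfoOnDeviceNumDims(debug_info);
  for (int i = 0; i < on_device_rank; ++i) {
    int64_t dim = TFE_TensorDebugInfoOnDeviceDim(debug_info, i);
    // ... inspect the padded/on-device dimension `dim` ...
  }
  TFE_DeleteTensorDebugInfo(debug_info);
}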

View File

@ -22,18 +22,18 @@ limitations under the License.
#include "tensorflow/core/lib/monitoring/gauge.h" #include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/sampler.h" #include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/profiler/rpc/client/capture_profile.h" #include "tensorflow/core/profiler/rpc/client/capture_profile.h"
#include "tensorflow/core/profiler/rpc/profiler_server.h" #include "tensorflow/core/profiler/rpc/profiler_server.h"
using tensorflow::string; using tensorflow::string;
void TFE_OpReset(TFE_Context* ctx, const char* op_or_function_name, void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status, const char* raw_device_name, TF_Status* status) {
TFE_Op* op_to_reset) {
if (op_to_reset) { if (op_to_reset) {
NewOrResetOp(ctx, op_or_function_name, raw_device_name, status, status->status = op_to_reset->operation.Reset(
op_to_reset); op_or_function_name, raw_device_name, false, nullptr);
} else { } else {
TF_SetStatus(status, TF_INVALID_ARGUMENT, TF_SetStatus(status, TF_INVALID_ARGUMENT,
"op_to_reset should not be nullptr"); "op_to_reset should not be nullptr");
@ -41,7 +41,9 @@ void TFE_OpReset(TFE_Context* ctx, const char* op_or_function_name,
} }
void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) { void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
op->operation.ConsumeInput(h->handle); op->operation.ConsumeInput(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle());
} }
TFE_Profiler* TFE_NewProfiler() { return new TFE_Profiler(); } TFE_Profiler* TFE_NewProfiler() { return new TFE_Profiler(); }

View File

@ -29,10 +29,10 @@ extern "C" {
// and set the device name. It's effectively `TFE_OpSetDevice`, but it is faster // and set the device name. It's effectively `TFE_OpSetDevice`, but it is faster
// than separately calling it because if the existing op has the same // than separately calling it because if the existing op has the same
// `raw_device_name`, it skips parsing and just leaves it as is. // `raw_device_name`, it skips parsing and just leaves it as is.
TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Context* ctx, TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Op* op_to_reset,
const char* op_or_function_name, const char* op_or_function_name,
const char* raw_device_name, const char* raw_device_name,
TF_Status* status, TFE_Op* op_to_reset); TF_Status* status);
TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,
TF_Status* status); TF_Status* status);
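A hedged sketch of how a caller might reuse a single op with the new TFE_OpReset signature; ctx, inputs, retvals and num_iters are assumed to exist, and error handling is omitted.

// Reuse one TFE_Op across iterations instead of allocating a new op each time.
TFE_Op* op = TFE_NewOp(ctx, "Identity", status);
for (int i = 0; i < num_iters; ++i) {
  TFE_OpReset(op, "Identity", /*raw_device_name=*/nullptr, status);
  TFE_OpAddInput(op, inputs[i], status);
  int num_retvals = 1;
  TFE_Execute(op, &retvals[i], &num_retvals, status);
}
TFE_DeleteOp(op);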

View File

@ -1,66 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/host_info.h"
TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status,
TFE_Op* op_to_reset) {
const char* name = op_or_function_name; // Shorthand
const tensorflow::AttrTypeMap* types;
bool is_function = false;
status->status = tensorflow::AttrTypeMapForOp(name, &types, &is_function);
if (!status->status.ok()) {
return nullptr;
}
if (op_to_reset && op_to_reset->ctx != ctx) {
status->status = tensorflow::errors::Internal(
"Cannot reset a TFE_Op from another TFE_Context");
return nullptr;
}
std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
if (!is_function) {
const tensorflow::OpDef* op_def;
status->status = tensorflow::OpDefForOp(op_or_function_name, &op_def);
if (!status->status.ok()) {
return nullptr;
}
inference_ctx.reset(new TFE_OpInferenceContext(op_def));
} else if (!ctx->context->FindFunctionByName(name)) {
status->status = tensorflow::errors::NotFound(
"'", name,
"' is neither a type of a primitive operation nor a name "
"of a function registered in binary running on ",
tensorflow::port::Hostname(),
". Make sure the operation or function is "
"registered in the binary running in this process.");
return nullptr;
}
if (op_to_reset) {
status->status = op_to_reset->Reset(
name, is_function, types, raw_device_name, std::move(inference_ctx));
return op_to_reset;
}
TFE_Op* new_op =
new TFE_Op(ctx, name, is_function, types, std::move(inference_ctx));
status->status = new_op->operation.SetDeviceName(raw_device_name);
return new_op;
}

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/context.h"
@ -62,36 +63,10 @@ struct TFE_ContextOptions {
}; };
struct TFE_Context { struct TFE_Context {
TFE_Context(const tensorflow::SessionOptions& opts,
TFE_ContextDevicePlacementPolicy default_device_placement_policy,
TFE_ContextMirroringPolicy default_mirroring_policy, bool async,
const bool lazy_remote_inputs_copy,
const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned,
tensorflow::Rendezvous* rendezvous,
const tensorflow::CustomKernelCreator* custom_kernel_creator)
: context(new tensorflow::EagerContext(
opts,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
default_device_placement_policy),
static_cast<tensorflow::ContextMirroringPolicy>(
default_mirroring_policy),
async, lazy_remote_inputs_copy, device_mgr, device_mgr_owned,
rendezvous, custom_kernel_creator)) {}
~TFE_Context() {
// TODO(iga): Add a separate API method to shutdown TFE_Context so that we
// don't send RPCs and block in destructor.
context->WaitForAndCloseRemoteContexts();
// context->RefCountIsOne() should be true here.
// TODO(iga): Remove EagerContext refcounting.
context->Unref();
}
tensorflow::EagerContext* context; tensorflow::EagerContext* context;
}; };
struct TFE_TensorHandle { struct TFE_TensorHandle {
explicit TFE_TensorHandle(tensorflow::TensorHandle* h) : handle(h) {}
static TFE_TensorHandle* CreateLocalHandle(const class tensorflow::Tensor& t, static TFE_TensorHandle* CreateLocalHandle(const class tensorflow::Tensor& t,
TF_Status* s) { TF_Status* s) {
tensorflow::TensorHandle* handle; tensorflow::TensorHandle* handle;
@ -99,10 +74,11 @@ struct TFE_TensorHandle {
if (!s->status.ok()) { if (!s->status.ok()) {
return nullptr; return nullptr;
} }
return new TFE_TensorHandle(handle); return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
} }
tensorflow::TensorHandle* handle; std::unique_ptr<AbstractTensorHandleInterface> handle;
}; };
struct TFE_TensorDebugInfo { struct TFE_TensorDebugInfo {
@ -113,46 +89,10 @@ struct TFE_TensorDebugInfo {
std::vector<tensorflow::int64> dev_dims; std::vector<tensorflow::int64> dev_dims;
}; };
struct TFE_OpInferenceContext {
explicit TFE_OpInferenceContext(const tensorflow::OpDef* op_def)
: op_def(op_def) {}
const tensorflow::OpDef* op_def; // op definition from protobuf
int input_arg_idx = 0; // arg definition index for the next input to be added
tensorflow::gtl::FlatSet<std::string> attrs; // attributes inferred so far
};
struct TFE_Op { struct TFE_Op {
TFE_Op(TFE_Context* ctx, const char* op, bool is_function,
const tensorflow::AttrTypeMap* t,
std::unique_ptr<TFE_OpInferenceContext> inference_ctx)
: ctx(ctx),
operation(ctx->context, op, is_function, t),
inference_ctx(std::move(inference_ctx)) {}
void Clear() {
operation.Clear();
inference_ctx.reset();
}
tensorflow::Status Reset(const char* op, bool is_function,
const tensorflow::AttrTypeMap* t,
const char* raw_device_name,
std::unique_ptr<TFE_OpInferenceContext> infer_ctx) {
inference_ctx = std::move(infer_ctx);
return operation.Reset(ctx->context, op, is_function, t, raw_device_name,
nullptr);
}
TFE_Context* ctx;
tensorflow::EagerOperation operation; tensorflow::EagerOperation operation;
std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
}; };
TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status,
TFE_Op* op_to_reset = nullptr);
struct TFE_Profiler { struct TFE_Profiler {
explicit TFE_Profiler() { profiler = tensorflow::ProfilerSession::Create(); } explicit TFE_Profiler() { profiler = tensorflow::ProfilerSession::Create(); }

View File

@ -1362,10 +1362,11 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) {
TFE_TensorHandle* inputs[] = {input1, input2}; TFE_TensorHandle* inputs[] = {input1, input2};
TFE_OpAddInput(concatOp, dim, status); TFE_OpAddInput(concatOp, dim, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
CHECK(concatOp->inference_ctx); CHECK(concatOp->operation.OpDef());
TFE_OpAddInput(concatOp, inputs[0], status); TFE_OpAddInput(concatOp, inputs[0], status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_FALSE(concatOp->inference_ctx) << "Inference context is still present"; EXPECT_FALSE(concatOp->operation.OpDef())
<< "Inference context is still present";
TFE_OpAddInput(concatOp, inputs[1], status); TFE_OpAddInput(concatOp, inputs[1], status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

View File

@ -284,7 +284,7 @@ class ForwardAccumulator {
// Temporarily push or pop transient state for this accumulator. // Temporarily push or pop transient state for this accumulator.
// //
// Allows an accumulator which is currently processing an operation to // Allows an accumulator which is currently processing an operation to
// temporarily reset its state. Without pushing and poping, accumulators // temporarily reset its state. Without pushing and popping, accumulators
// ignore operations executed as a direct result of their own jvp // ignore operations executed as a direct result of their own jvp
// computations. // computations.
void PushState() { call_state_.emplace(nullptr, false); } void PushState() { call_state_.emplace(nullptr, false); }

View File

@ -0,0 +1,90 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
#define TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
// Abstract interface to a TensorHandle.
//
// A TensorHandle is a management class around a Tensor which may track additional
// metadata and synchronization.
//
// This allows us to hide concrete implementations of TensorHandle from header
// files. The interface lists the common functionality that must be provided by
// any concrete implementation. However, in cases where the true concrete class
// is needed, a static_cast can be applied.
class AbstractTensorHandleInterface {
public:
virtual ~AbstractTensorHandleInterface() {}
// Check if the handle is in a valid initialized state.
virtual bool IsValid(tensorflow::Status* status) const = 0;
// Returns tensor dtype.
virtual TF_DataType DataType() const = 0;
// Returns number of dimensions.
virtual int NumDims(tensorflow::Status* status) const = 0;
// Returns number of elements across all dimensions.
virtual int64_t NumElements(tensorflow::Status* status) const = 0;
// Returns size of specified dimension
virtual int64_t Dim(int dim_index, tensorflow::Status* status) const = 0;
// Returns the device which created the handle.
virtual const char* DeviceName(tensorflow::Status* status) const = 0;
// Returns the device where the tensor was placed.
virtual const char* BackingDeviceName(tensorflow::Status* status) const = 0;
// Returns a tensor for the handle. If tensor is remote, it will be copied.
virtual TF_Tensor* Resolve(tensorflow::Status* status) = 0;
// Returns debug information about the tensor.
virtual TFE_TensorDebugInfo* TensorDebugInfo(tensorflow::Status* status) = 0;
// Return a copy of the handle.
virtual AbstractTensorHandleInterface* Copy() = 0;
};
namespace tensorflow {
class TensorHandleInterface : public AbstractTensorHandleInterface {
public:
explicit TensorHandleInterface(TensorHandle* h) : handle_(h) {}
~TensorHandleInterface() override;
bool IsValid(Status* status) const override;
TF_DataType DataType() const override;
int NumDims(Status* status) const override;
int64_t NumElements(Status* status) const override;
int64_t Dim(int dim_index, Status* status) const override;
const char* DeviceName(Status* status) const override;
const char* BackingDeviceName(Status* status) const override;
TF_Tensor* Resolve(Status* status) override;
TFE_TensorDebugInfo* TensorDebugInfo(Status* status) override;
AbstractTensorHandleInterface* Copy() override;
// TODO(gjn): This is not a very generic interface, but is needed for specific
// use cases.
TensorHandle* Handle() { return handle_; }
private:
TensorHandle* handle_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
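As the c_api.cc changes above show, C API entry points now unwrap this abstract interface before reaching the runtime TensorHandle. A minimal helper expressing that pattern might look like the following; it is illustrative only (it belongs in c_api.cc-style code that includes c_api_internal.h and tensorflow/core/platform/casts.h, not in this header) and assumes the handle was created by this TensorFlow build, so the concrete type is TensorHandleInterface.

// Unwrap a TFE_TensorHandle to the runtime TensorHandle, mirroring the
// down_cast pattern used throughout c_api.cc.
tensorflow::TensorHandle* UnwrapHandle(TFE_TensorHandle* h) {
  return tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
             h->handle.get())
      ->Handle();
}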

View File

@ -18,37 +18,23 @@ cc_library(
], ],
) )
# Core TensorFlow depends on this, this will be included in main library
cc_library(
name = "filesystem_interface_impl",
srcs = ["filesystem_interface.cc"],
hdrs = ["filesystem_interface.h"],
deps = [
":modular_filesystem",
"//tensorflow/c:tf_file_statistics",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_internal",
"//tensorflow/core:ptr_util",
"//tensorflow/core/platform:env",
"//tensorflow/core/platform:logging",
"//tensorflow/core/platform:strcat",
"//tensorflow/core/platform:stringpiece",
],
alwayslink = 1,
)
# Core TensorFlow depends on this, will be included in main library # Core TensorFlow depends on this, will be included in main library
cc_library( cc_library(
name = "modular_filesystem", name = "modular_filesystem",
srcs = ["modular_filesystem.cc"], srcs = [
"modular_filesystem.cc",
"modular_filesystem_registration.cc",
"modular_filesystem_registration.h",
],
hdrs = ["modular_filesystem.h"], hdrs = ["modular_filesystem.h"],
deps = [ deps = [
":filesystem_interface", ":filesystem_interface",
"//tensorflow/c:tf_status_helper", "//tensorflow/c:tf_status_helper",
"//tensorflow/core:lib", "//tensorflow/c:tf_status_internal",
"//tensorflow/core:ptr_util", "//tensorflow/core:ptr_util",
"//tensorflow/core/platform:env", "//tensorflow/core/platform:env",
"//tensorflow/core/platform:strcat", "//tensorflow/core/platform:errors",
"//tensorflow/core/platform:status",
], ],
) )
@ -63,16 +49,12 @@ tf_cc_test(
"notap", # b/139060984, requires implementing modular support for Google filesystem "notap", # b/139060984, requires implementing modular support for Google filesystem
], ],
deps = [ deps = [
":filesystem_interface_impl", ":modular_filesystem",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_internal",
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",
"//tensorflow/core/lib/io:path", "//tensorflow/core/lib/io:path",
"//tensorflow/core/platform:env", "//tensorflow/core/platform:env",
"//tensorflow/core/platform:error", "//tensorflow/core/platform:error",
"//tensorflow/core/platform:stacktrace_handler", "//tensorflow/core/platform:stacktrace_handler",
"//tensorflow/core/platform:str_util",
"//tensorflow/core/platform:strcat",
"//tensorflow/core/platform:test", "//tensorflow/core/platform:test",
], ],
) )

View File

@ -1,366 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/util/ptr_util.h"
/// This translation unit is linked in core TensorFlow and provides the
/// functionality needed for plugin registration to check ABI/API compatibility,
/// to ensure required methods are present, to ensure plugins are not allowed to
/// change functionality after being loaded and to register the filesystems
/// provided by a plugin. Consult the header file for more information about
/// how this is achieved.
namespace tensorflow {
namespace {
// Checks if the plugin and core ABI numbers match, filling in `status`.
//
// If the numbers don't match, plugin cannot be loaded.
static bool CheckABIHelper(int pluginABI, int coreABI, StringPiece where,
TF_Status* status) {
if (pluginABI != coreABI) {
TF_SetStatus(
status, TF_FAILED_PRECONDITION,
strings::StrCat("Plugin ABI (", pluginABI, ") for ", where,
" operations doesn't match expected core ABI (",
coreABI, "). Plugin cannot be loaded.")
.c_str());
return false;
}
return true;
}
// Checks if the plugin and core ABI numbers match, for all operations.
//
// If the numbers don't match, plugin cannot be loaded.
//
// Uses the simpler `CheckABIHelper(int, int, StringPiece, TF_Status*)`
static bool CheckABI(
int plugin_filesystem_ops_ABI,
const TF_RandomAccessFileOps* plugin_random_access_file_ops,
int plugin_random_access_file_ops_ABI,
const TF_WritableFileOps* plugin_writable_file_ops,
int plugin_writable_file_ops_ABI,
const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
int plugin_read_only_memory_region_ops_ABI, TF_Status* status) {
if (!CheckABIHelper(plugin_filesystem_ops_ABI, TF_FILESYSTEM_OPS_ABI,
"filesystem", status))
return false;
if (plugin_random_access_file_ops != nullptr &&
!CheckABIHelper(plugin_random_access_file_ops_ABI,
TF_RANDOM_ACCESS_FILE_OPS_ABI, "random access file",
status))
return false;
if (plugin_writable_file_ops != nullptr &&
!CheckABIHelper(plugin_writable_file_ops_ABI, TF_WRITABLE_FILE_OPS_ABI,
"writable file", status))
return false;
if (plugin_read_only_memory_region_ops != nullptr &&
!CheckABIHelper(plugin_read_only_memory_region_ops_ABI,
TF_READ_ONLY_MEMORY_REGION_OPS_ABI,
"read only memory region", status))
return false;
return true;
}
// Checks if the plugin and core API numbers match, logging mismatches.
static void CheckAPIHelper(int plugin_API, int core_API, StringPiece where) {
if (plugin_API != core_API) {
VLOG(0) << "Plugin API (" << plugin_API << ") for " << where
<< " operations doesn't match expected core API (" << core_API
<< "). Plugin will be loaded but functionality might be missing.";
}
}
// Checks if the plugin and core API numbers match, for all operations.
//
// Uses the simpler `CheckAPIHelper(int, int, StringPiece)`.
static void CheckAPI(
int plugin_filesystem_ops_API,
const TF_RandomAccessFileOps* plugin_random_access_file_ops,
int plugin_random_access_file_ops_API,
const TF_WritableFileOps* plugin_writable_file_ops,
int plugin_writable_file_ops_API,
const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
int plugin_read_only_memory_region_ops_API) {
CheckAPIHelper(plugin_filesystem_ops_API, TF_FILESYSTEM_OPS_API,
"filesystem");
if (plugin_random_access_file_ops != nullptr)
CheckAPIHelper(plugin_random_access_file_ops_API,
TF_RANDOM_ACCESS_FILE_OPS_API, "random access file");
if (plugin_writable_file_ops != nullptr)
CheckAPIHelper(plugin_writable_file_ops_API, TF_WRITABLE_FILE_OPS_API,
"writable file");
if (plugin_read_only_memory_region_ops != nullptr)
CheckAPIHelper(plugin_read_only_memory_region_ops_API,
TF_READ_ONLY_MEMORY_REGION_OPS_API,
"read only memory region");
}
// Validates the filesystem operations supplied by the plugin.
static bool ValidateHelper(const TF_FilesystemOps* ops, TF_Status* status) {
if (ops == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without operations");
return false;
}
if (ops->init == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `init` operation");
return false;
}
if (ops->cleanup == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `cleanup` operation");
return false;
}
return true;
}
// Validates the random access file operations supplied by the plugin.
static bool ValidateHelper(const TF_RandomAccessFileOps* ops,
TF_Status* status) {
if (ops == nullptr) {
// We allow filesystems where files can only be written to (from TF code)
return true;
}
if (ops->cleanup == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `cleanup` operation on "
"random access files");
return false;
}
return true;
}
// Validates the writable file operations supplied by the plugin.
static bool ValidateHelper(const TF_WritableFileOps* ops, TF_Status* status) {
if (ops == nullptr) {
// We allow read-only filesystems
return true;
}
if (ops->cleanup == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `cleanup` operation on "
"writable files");
return false;
}
return true;
}
// Validates the read only memory region operations given by the plugin.
static bool ValidateHelper(const TF_ReadOnlyMemoryRegionOps* ops,
TF_Status* status) {
if (ops == nullptr) {
// read only memory region support is always optional
return true;
}
if (ops->cleanup == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `cleanup` operation on "
"read only memory regions");
return false;
}
if (ops->data == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `data` operation on "
"read only memory regions");
return false;
}
if (ops->length == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Trying to register filesystem without `length` operation on "
"read only memory regions");
return false;
}
return true;
}
// Validates the operations supplied by the plugin.
//
// Uses the 4 simpler `ValidateHelper(const TF_..., TF_Status*)` to validate
// each individual function table and then checks that the function table for a
// specific file type exists if the plugin offers support for creating that
// type of files.
static bool Validate(
const TF_FilesystemOps* plugin_filesystem_ops,
const TF_RandomAccessFileOps* plugin_random_access_file_ops,
const TF_WritableFileOps* plugin_writable_file_ops,
const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
TF_Status* status) {
if (!ValidateHelper(plugin_filesystem_ops, status)) return false;
if (!ValidateHelper(plugin_random_access_file_ops, status)) return false;
if (!ValidateHelper(plugin_writable_file_ops, status)) return false;
if (!ValidateHelper(plugin_read_only_memory_region_ops, status)) return false;
if (plugin_filesystem_ops->new_random_access_file != nullptr &&
plugin_random_access_file_ops == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Filesystem allows creation of random access files but no "
"operations on them have been supplied.");
return false;
}
if ((plugin_filesystem_ops->new_writable_file != nullptr ||
plugin_filesystem_ops->new_appendable_file != nullptr) &&
plugin_writable_file_ops == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Filesystem allows creation of writable files but no "
"operations on them have been supplied.");
return false;
}
if (plugin_filesystem_ops->new_read_only_memory_region_from_file != nullptr &&
plugin_read_only_memory_region_ops == nullptr) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Filesystem allows creation of readonly memory regions but no "
"operations on them have been supplied.");
return false;
}
return true;
}
// Copies a function table from plugin memory space to core memory space.
//
// This has three benefits:
// * allows having newer plugins than the current core TensorFlow: the
// additional entries in the plugin's table are just discarded;
// * allows having older plugins than the current core TensorFlow (though
// we are still warning users): the entries that core TensorFlow expects
// but plugins didn't provide will be set to `nullptr` values and core
// TensorFlow will know to not call these on behalf of users;
// * increased security as plugins will not be able to alter function table
// after loading up. Thus, malicious plugins can't alter functionality to
// probe for gadgets inside core TensorFlow. We can even protect the area
// of memory where the copies reside to not allow any more writes to it
// after all copies are created.
template <typename T>
static std::unique_ptr<const T> CopyToCore(const T* plugin_ops,
size_t plugin_size) {
if (plugin_ops == nullptr) return nullptr;
size_t copy_size = sizeof(T);
if (plugin_size < copy_size) {
copy_size = plugin_size;
}
auto core_ops = tensorflow::MakeUnique<T>();
memcpy(const_cast<T*>(core_ops.get()), plugin_ops, copy_size);
return core_ops;
}
} // namespace
} // namespace tensorflow
void RegisterFilesystemPlugin(
int plugin_filesystem_ops_ABI, int plugin_filesystem_ops_API,
size_t plugin_filesystem_ops_size, int plugin_random_access_file_ops_ABI,
int plugin_random_access_file_ops_API,
size_t plugin_random_access_file_ops_size, int plugin_writable_file_ops_ABI,
int plugin_writable_file_ops_API, size_t plugin_writable_file_ops_size,
int plugin_read_only_memory_region_ops_ABI,
int plugin_read_only_memory_region_ops_API,
size_t plugin_read_only_memory_region_ops_size, const char* scheme,
const TF_FilesystemOps* plugin_filesystem_ops,
const TF_RandomAccessFileOps* plugin_random_access_file_ops,
const TF_WritableFileOps* plugin_writable_file_ops,
const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
TF_Status* status) {
if (scheme == nullptr) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"`scheme` argument must not be `nullptr`.");
return;
}
// ABI numbers must match exactly for plugin to be loaded
if (!tensorflow::CheckABI(
plugin_filesystem_ops_ABI, plugin_random_access_file_ops,
plugin_random_access_file_ops_ABI, plugin_writable_file_ops,
plugin_writable_file_ops_ABI, plugin_read_only_memory_region_ops,
plugin_read_only_memory_region_ops_ABI, status)) {
return;
}
// API numbers should match but mismatch doesn't block plugin load
tensorflow::CheckAPI(plugin_filesystem_ops_API, plugin_random_access_file_ops,
plugin_random_access_file_ops_API,
plugin_writable_file_ops, plugin_writable_file_ops_API,
plugin_read_only_memory_region_ops,
plugin_read_only_memory_region_ops_API);
// Plugin can only be loaded if all supplied ops are valid
if (!tensorflow::Validate(plugin_filesystem_ops,
plugin_random_access_file_ops,
plugin_writable_file_ops,
plugin_read_only_memory_region_ops, status)) {
return;
}
// Copy all the function tables to core TensorFlow memory space
auto core_filesystem_ops = tensorflow::CopyToCore<TF_FilesystemOps>(
plugin_filesystem_ops, plugin_filesystem_ops_size);
auto core_random_access_file_ops =
tensorflow::CopyToCore<TF_RandomAccessFileOps>(
plugin_random_access_file_ops, plugin_random_access_file_ops_size);
auto core_writable_file_ops = tensorflow::CopyToCore<TF_WritableFileOps>(
plugin_writable_file_ops, plugin_writable_file_ops_size);
auto core_read_only_memory_region_ops =
tensorflow::CopyToCore<TF_ReadOnlyMemoryRegionOps>(
plugin_read_only_memory_region_ops,
plugin_read_only_memory_region_ops_size);
// Initialize the opaque filesystem structure
auto filesystem = tensorflow::MakeUnique<TF_Filesystem>();
core_filesystem_ops->init(filesystem.get(), status);
if (!status->status.ok()) {
core_filesystem_ops->cleanup(filesystem.get());
return;
}
// Register new filesystem
status->status = tensorflow::Env::Default()->RegisterFileSystem(
scheme, tensorflow::MakeUnique<tensorflow::ModularFileSystem>(
std::move(filesystem), std::move(core_filesystem_ops),
std::move(core_random_access_file_ops),
std::move(core_writable_file_ops),
std::move(core_read_only_memory_region_ops)));
}

View File

@ -529,7 +529,7 @@ typedef struct TF_FilesystemOps {
/// If `statuses` is not null, plugins must fill each element with detailed /// If `statuses` is not null, plugins must fill each element with detailed
/// status for each file, as if calling `path_exists` on each one. Core /// status for each file, as if calling `path_exists` on each one. Core
/// TensorFlow initializes the `statuses` array and plugins must use /// TensorFlow initializes the `statuses` array and plugins must use
/// `TF_SetStatus` to set each element instead of dirrectly assigning. /// `TF_SetStatus` to set each element instead of directly assigning.
/// ///
/// DEFAULT IMPLEMENTATION: Checks existence of every file. Needs /// DEFAULT IMPLEMENTATION: Checks existence of every file. Needs
/// `path_exists`. /// `path_exists`.
@ -736,95 +736,108 @@ constexpr size_t TF_FILESYSTEM_OPS_SIZE = sizeof(TF_FilesystemOps);
/// SECTION 4. Plugin registration and initialization /// SECTION 4. Plugin registration and initialization
/// ---------------------------------------------------------------------------- /// ----------------------------------------------------------------------------
/// ///
/// In this section we define two functions: /// In this section we define the API used by core TensorFlow to initialize a
/// * `TF_InitPlugin`: must be present in the plugin shared object as it will /// filesystem provided by a plugin. That is, we define the following:
/// be called by core TensorFlow when the filesystem plugin is loaded; /// * `TF_InitPlugin` function: must be present in the plugin shared object as
/// * `RegisterFilesystemPlugin`: it is implemented by core TensorFlow but /// it will be called by core TensorFlow when the filesystem plugin is
/// plugins must call it in their `TF_InitPlugin`, usually using the macro /// loaded;
/// `TF_REGISTER_FILESYSTEM_PLUGIN`. /// * `TF_FilesystemPluginInfo` struct: used to transfer information between
/// plugins and core TensorFlow about the operations provided and metadata;
/// * `TF_SetFilesystemVersionMetadata` function: must be called by plugins in
/// their `TF_InitPlugin` to record the versioning information the plugins
/// are compiled against.
/// ///
/// The `TF_InitPlugin` function is used by plugins to set up the data /// The `TF_InitPlugin` function is used by plugins to set up the data
/// structures that implement this interface, as presented in Section 2. /// structures that implement this interface, as presented in Section 2. In
/// /// order to not have plugin shared objects call back symbols defined in core
/// The `RegisterFilesystemPlugin` is used by core TensorFlow to check that /// TensorFlow, `TF_InitPlugin` has a `TF_FilesystemPluginInfo` argument which
/// plugins satisfy the requirements expected by core TensorFlow, as follows: /// the plugin must fill (using the `TF_SetFilesystemVersionMetadata` for the
/// 1. If ABI numbers don't match we don't load the plugin, else we continue. /// metadata and setting up all the supported operations and the URI schemes
/// 2. If the API numbers are mismatched, we warn the user and continue /// that are supported).
/// loading the plugin.
/// 3. If any required operation is missing, we stop loading the plugin.
///
/// If all these checks succeed, we copy the plugin operations to a different
/// memory location so that core TensorFlow has the guarantee that they won't be
/// changed by plugins at a later time. Finally, we initialize the opaque
/// pointer of `TF_Filesystem` by calling the required `init` function of
/// `TF_FilesystemOps` and if that succeeds we register the filesystem.
// Initializes a TensorFlow plugin. /// This structure incorporates the operations defined in Section 2 and the
// /// metadata defined in section 3, allowing plugins to define different ops
// Must be implemented by the plugin DSO. It is called by TensorFlow runtime. /// for different URI schemes.
// ///
// Filesystem plugins can be loaded on demand by users via /// Every URI scheme is of the form "fs" for URIs of form "fs:///path/to/file".
// `Env::LoadLibrary` or during TensorFlow's startup if they are on certain /// For local filesystems (i.e., when the URI is "/path/to/file"), the scheme
// paths (although this has a security risk if two plugins register for the /// must be "". The scheme must never be `nullptr`.
// same filesystem and the malicious one loads before the legitimate one - ///
// but we consider this to be something that users should care about and /// Every plugin fills this in `TF_InitPlugin`, using the allocator passed as
// manage themselves). In both of these cases, core TensorFlow looks for /// argument to allocate memory. After `TF_InitPlugin` finishes, core
// the `TF_InitPlugin` symbol and calls that function. /// TensorFlow uses the information present in this to initialize filesystems
// /// for the URI schemes that the plugin requests.
// A plugin is loaded only if this `status` is `TF_OK` after the call. ///
TF_CAPI_EXPORT extern void TF_InitPlugin(TF_Status* status); /// All pointers defined in this structure point to memory allocated by the DSO
/// using an allocator provided by core TensorFlow when calling `TF_InitPlugin`.
///
/// IMPORTANT: To maintain binary compatibility, the layout of this structure
/// must not change! In the unlikely case that a new type of file needs to be
/// supported, add the new ops and metadata at the end of the structure.
typedef struct TF_FilesystemPluginInfo {
char* scheme;
int filesystem_ops_abi;
int filesystem_ops_api;
size_t filesystem_ops_size;
TF_FilesystemOps* filesystem_ops;
int random_access_file_ops_abi;
int random_access_file_ops_api;
size_t random_access_file_ops_size;
TF_RandomAccessFileOps* random_access_file_ops;
int writable_file_ops_abi;
int writable_file_ops_api;
size_t writable_file_ops_size;
TF_WritableFileOps* writable_file_ops;
int read_only_memory_region_ops_abi;
int read_only_memory_region_ops_api;
size_t read_only_memory_region_ops_size;
TF_ReadOnlyMemoryRegionOps* read_only_memory_region_ops;
} TF_FilesystemPluginInfo;
/// Registers a filesystem plugin so that core TensorFlow can use it. /// Convenience function for setting the versioning metadata.
/// ///
/// Must be called by the plugin during `TF_InitPlugin`, usually by using the /// The argument is guaranteed to not be `nullptr`.
/// convenience `TF_REGISTER_FILESYSTEM_PLUGIN` macro.
/// ///
/// Arguments (grouped by category): /// We want this to be defined in the plugin's memory space and we guarantee
/// * `..ABI`: ABI compatibility numbers (see Section 3.). /// that core TensorFlow will never call this.
/// * `..API`: API compatibility numbers (see Section 3.). static inline void TF_SetFilesystemVersionMetadata(
/// * `..Size`: Sizes of the operation tables (see Section 3.). TF_FilesystemPluginInfo* info) {
/// * `scheme`: The URI scheme that plugin is registering filesystems for. info->filesystem_ops_abi = TF_FILESYSTEM_OPS_ABI;
/// Must be of the form "fs" for URIs of form "fs:///path/to/file". For info->filesystem_ops_api = TF_FILESYSTEM_OPS_API;
/// local filesystems (i.e., when the URI is "/path/to/file"), `scheme` info->filesystem_ops_size = TF_FILESYSTEM_OPS_SIZE;
/// must be "". Must never be `nullptr`. info->random_access_file_ops_abi = TF_RANDOM_ACCESS_FILE_OPS_ABI;
/// * `..Ops`: The function tables provided by the plugin. Owned by the info->random_access_file_ops_api = TF_RANDOM_ACCESS_FILE_OPS_API;
/// plugin, but core TensorFlow makes a copy of these. info->random_access_file_ops_size = TF_RANDOM_ACCESS_FILE_OPS_SIZE;
/// * `status`: The output variable for representing success/failure. info->writable_file_ops_abi = TF_WRITABLE_FILE_OPS_ABI;
/// info->writable_file_ops_api = TF_WRITABLE_FILE_OPS_API;
/// Sets `status` to `TF_OK` if plugin was registered and filesystem operations info->writable_file_ops_size = TF_WRITABLE_FILE_OPS_SIZE;
/// can be invoked from anywhere during TensorFlow's runtime. Any other value of info->read_only_memory_region_ops_abi = TF_READ_ONLY_MEMORY_REGION_OPS_ABI;
/// `status` means that plugin failed to load properly and as such the info->read_only_memory_region_ops_api = TF_READ_ONLY_MEMORY_REGION_OPS_API;
/// operations it provides cannot be used at all (i.e., core TensorFlow will info->read_only_memory_region_ops_size = TF_READ_ONLY_MEMORY_REGION_OPS_SIZE;
/// never run them, returning early with `TF_UNIMPLEMENTED` or similar error }
/// values).
TF_CAPI_EXPORT extern void RegisterFilesystemPlugin(
int pluginFilesystemOpsABI, int pluginFilesystemOpsAPI,
size_t pluginFilesystemOpsSize, int pluginRandomAccessFileOpsABI,
int pluginRandomAccessFileOpsAPI, size_t pluginRandomAccessFileOpsSize,
int pluginWritableFileOpsABI, int pluginWritableFileOpsAPI,
size_t pluginWritableFileOpsSize, int pluginReadOnlyMemoryRegionOpsABI,
int pluginReadOnlyMemoryRegionOpsAPI,
size_t pluginReadOnlyMemoryRegionOpsSize, const char* scheme,
const TF_FilesystemOps* pluginFilesystemOps,
const TF_RandomAccessFileOps* pluginRandomAccessFileOps,
const TF_WritableFileOps* pluginWritableFileOps,
const TF_ReadOnlyMemoryRegionOps* pluginReadOnlyMemoryRegionOps,
TF_Status* status);
/// This macro is just a convenience wrapper around `RegisterFilesystemPlugin`. /// Initializes a TensorFlow plugin.
/// Plugins should prefer using this macro instead of a direct call. ///
#define TF_REGISTER_FILESYSTEM_PLUGIN( \ /// Must be implemented by the plugin DSO. It is called by TensorFlow runtime.
scheme, pluginFilesystemOps, pluginRandomAccessFileOps, \ ///
pluginWritableFileOps, pluginReadOnlyMemoryRegionOps, status) \ /// Filesystem plugins can be loaded on demand by users via
RegisterFilesystemPlugin( \ /// `Env::LoadLibrary` or during TensorFlow's startup if they are on certain
TF_FILESYSTEM_OPS_ABI, TF_FILESYSTEM_OPS_API, TF_FILESYSTEM_OPS_SIZE, \ /// paths (although this has a security risk if two plugins register for the
TF_RANDOM_ACCESS_FILE_OPS_ABI, TF_RANDOM_ACCESS_FILE_OPS_API, \ /// same filesystem and the malicious one loads before the legimitate one -
TF_RANDOM_ACCESS_FILE_OPS_SIZE, TF_WRITABLE_FILE_OPS_ABI, \ /// but we consider this to be something that users should care about and
TF_WRITABLE_FILE_OPS_API, TF_WRITABLE_FILE_OPS_SIZE, \ /// manage themselves). In both of these cases, core TensorFlow looks for
TF_READ_ONLY_MEMORY_REGION_OPS_ABI, TF_READ_ONLY_MEMORY_REGION_OPS_API, \ /// the `TF_InitPlugin` symbol and calls this function.
TF_READ_ONLY_MEMORY_REGION_OPS_SIZE, scheme, pluginFilesystemOps, \ ///
pluginRandomAccessFileOps, pluginWritableFileOps, \ /// All memory allocated by this function must be allocated via the `allocator`
pluginReadOnlyMemoryRegionOps, status) /// argument.
///
/// For every filesystem URI scheme that this plugin supports, the plugin must
/// add one `TF_FilesystemPluginInfo` entry in `plugin_info`.
///
/// Returns number of entries in `plugin_info` (i.e., number of URI schemes
/// supported).
TF_CAPI_EXPORT extern int TF_InitPlugin(void* (*allocator)(size_t size),
TF_FilesystemPluginInfo** plugin_info);
#ifdef __cplusplus #ifdef __cplusplus
} // end extern "C" } // end extern "C"
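To make the new registration flow concrete, here is a hedged sketch of what a plugin's TF_InitPlugin could look like for a single hypothetical URI scheme "demofs". The ops DemoInit and DemoCleanup are placeholders, only the two required operations are wired up, and every buffer is obtained from the allocator that core TensorFlow passes in, as the contract above requires.

#include <cstring>
// Hypothetical plugin translation unit; real plugins would also fill in file ops.
static void DemoInit(TF_Filesystem* fs, TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
}
static void DemoCleanup(TF_Filesystem* fs) {}

int TF_InitPlugin(void* (*allocator)(size_t size),
                  TF_FilesystemPluginInfo** plugin_info) {
  TF_FilesystemPluginInfo* info = static_cast<TF_FilesystemPluginInfo*>(
      allocator(sizeof(TF_FilesystemPluginInfo)));
  memset(info, 0, sizeof(TF_FilesystemPluginInfo));
  info->scheme = static_cast<char*>(allocator(sizeof("demofs")));
  memcpy(info->scheme, "demofs", sizeof("demofs"));
  TF_SetFilesystemVersionMetadata(info);  // record the ABI/API/size constants
  info->filesystem_ops =
      static_cast<TF_FilesystemOps*>(allocator(TF_FILESYSTEM_OPS_SIZE));
  memset(info->filesystem_ops, 0, TF_FILESYSTEM_OPS_SIZE);
  info->filesystem_ops->init = DemoInit;
  info->filesystem_ops->cleanup = DemoCleanup;
  *plugin_info = info;
  return 1;  // one URI scheme provided
}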

View File

@ -18,11 +18,10 @@ limitations under the License.
#include <string> #include <string>
#include <utility> #include <utility>
#include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h"
#include "tensorflow/c/tf_status_helper.h" #include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/file_system_helper.h" #include "tensorflow/core/platform/file_system_helper.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/util/ptr_util.h" #include "tensorflow/core/util/ptr_util.h"
// TODO(mihaimaruseac): After all filesystems are converted, all calls to // TODO(mihaimaruseac): After all filesystems are converted, all calls to
@ -435,4 +434,8 @@ Status ModularWritableFile::Tell(int64* position) {
return StatusFromTF_Status(plugin_status.get()); return StatusFromTF_Status(plugin_status.get());
} }
Status RegisterFilesystemPlugin(const std::string& dso_path) {
return filesystem_registration::RegisterFilesystemPluginImpl(dso_path);
}
} // namespace tensorflow } // namespace tensorflow

View File

@ -32,7 +32,7 @@ namespace tensorflow {
// TODO(b/143949615): After all filesystems are converted, this file will be // TODO(b/143949615): After all filesystems are converted, this file will be
// moved to core/platform, and this class can become a singleton and replace the // moved to core/platform, and this class can become a singleton and replace the
// need for `Env::Default()`. At that time, we might decide to remove the need // need for `Env::Default()`. At that time, we might decide to remove the need
// for `Env::Default()` altoghether, but that's a different project, not in // for `Env::Default()` altogether, but that's a different project, not in
// scope for now. I'm just mentioning this here as that transition will mean // scope for now. I'm just mentioning this here as that transition will mean
// removal of the registration part from `Env` and adding it here instead: we // removal of the registration part from `Env` and adding it here instead: we
// will need tables to hold for each scheme the function tables that implement // will need tables to hold for each scheme the function tables that implement
@ -156,6 +156,9 @@ class ModularReadOnlyMemoryRegion final : public ReadOnlyMemoryRegion {
TF_DISALLOW_COPY_AND_ASSIGN(ModularReadOnlyMemoryRegion); TF_DISALLOW_COPY_AND_ASSIGN(ModularReadOnlyMemoryRegion);
}; };
// Registers a filesystem plugin so that core TensorFlow can use it.
Status RegisterFilesystemPlugin(const std::string& dso_path);
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_H_ #endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_H_
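A short sketch of loading a plugin through this new entry point; the DSO path below is a placeholder.

// Somewhere in core TensorFlow or a test; the path is illustrative only.
tensorflow::Status s =
    tensorflow::RegisterFilesystemPlugin("/tmp/libdemo_filesystem_plugin.so");
if (!s.ok()) {
  LOG(ERROR) << "Failed to load filesystem plugin: " << s;
}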

View File

@ -0,0 +1,325 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
// Checks that all schemes provided by a plugin are valid.
// TODO(mihaimaruseac): More validation could be done here, based on supported
// charset, maximum length, etc. Punting it for later.
static Status ValidateScheme(const char* scheme) {
if (scheme == nullptr)
return errors::InvalidArgument(
"Attempted to register filesystem with `nullptr` URI scheme");
return Status::OK();
}
// Checks if the plugin and core ABI numbers match.
//
// If the numbers don't match, plugin cannot be loaded.
static Status CheckABI(int pluginABI, int coreABI, StringPiece where) {
if (pluginABI != coreABI)
return errors::FailedPrecondition(
strings::StrCat("Plugin ABI (", pluginABI, ") for ", where,
" operations doesn't match expected core ABI (",
coreABI, "). Plugin cannot be loaded."));
return Status::OK();
}
// Checks if the plugin and core ABI numbers match, for all operations.
//
// If the numbers don't match, plugin cannot be loaded.
//
// Uses the simpler `CheckABI(int, int, StringPiece)`.
static Status ValidateABI(const TF_FilesystemPluginInfo* info) {
TF_RETURN_IF_ERROR(
CheckABI(info->filesystem_ops_abi, TF_FILESYSTEM_OPS_ABI, "filesystem"));
if (info->random_access_file_ops != nullptr)
TF_RETURN_IF_ERROR(CheckABI(info->random_access_file_ops_abi,
TF_RANDOM_ACCESS_FILE_OPS_ABI,
"random access file"));
if (info->writable_file_ops != nullptr)
TF_RETURN_IF_ERROR(CheckABI(info->writable_file_ops_abi,
TF_WRITABLE_FILE_OPS_ABI, "writable file"));
if (info->read_only_memory_region_ops != nullptr)
TF_RETURN_IF_ERROR(CheckABI(info->read_only_memory_region_ops_abi,
TF_READ_ONLY_MEMORY_REGION_OPS_ABI,
"read only memory region"));
return Status::OK();
}
// Checks if the plugin and core API numbers match, logging mismatches.
static void CheckAPI(int plugin_API, int core_API, StringPiece where) {
if (plugin_API != core_API) {
VLOG(0) << "Plugin API (" << plugin_API << ") for " << where
<< " operations doesn't match expected core API (" << core_API
<< "). Plugin will be loaded but functionality might be missing.";
}
}
// Checks if the plugin and core API numbers match, for all operations.
//
// Uses the simpler `CheckAPI(int, int, StringPiece)`.
static void ValidateAPI(const TF_FilesystemPluginInfo* info) {
CheckAPI(info->filesystem_ops_api, TF_FILESYSTEM_OPS_API, "filesystem");
if (info->random_access_file_ops != nullptr)
CheckAPI(info->random_access_file_ops_api, TF_RANDOM_ACCESS_FILE_OPS_API,
"random access file");
if (info->writable_file_ops != nullptr)
CheckAPI(info->writable_file_ops_api, TF_WRITABLE_FILE_OPS_API,
"writable file");
if (info->read_only_memory_region_ops != nullptr)
CheckAPI(info->read_only_memory_region_ops_api,
TF_READ_ONLY_MEMORY_REGION_OPS_API, "read only memory region");
}
// Validates the filesystem operations supplied by the plugin.
static Status ValidateHelper(const TF_FilesystemOps* ops) {
if (ops == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without operations");
if (ops->init == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `init` operation");
if (ops->cleanup == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `cleanup` operation");
return Status::OK();
}
// Validates the random access file operations supplied by the plugin.
static Status ValidateHelper(const TF_RandomAccessFileOps* ops) {
if (ops == nullptr) {
// We allow filesystems where files can only be written to (from TF code)
return Status::OK();
}
if (ops->cleanup == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `cleanup` operation on random "
"access files");
return Status::OK();
}
// Validates the writable file operations supplied by the plugin.
static Status ValidateHelper(const TF_WritableFileOps* ops) {
if (ops == nullptr) {
// We allow read-only filesystems
return Status::OK();
}
if (ops->cleanup == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `cleanup` operation on writable "
"files");
return Status::OK();
}
// Validates the read only memory region operations given by the plugin.
static Status ValidateHelper(const TF_ReadOnlyMemoryRegionOps* ops) {
if (ops == nullptr) {
// read only memory region support is always optional
return Status::OK();
}
if (ops->cleanup == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `cleanup` operation on read "
"only memory regions");
if (ops->data == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `data` operation on read only "
"memory regions");
if (ops->length == nullptr)
return errors::FailedPrecondition(
"Trying to register filesystem without `length` operation on read only "
"memory regions");
return Status::OK();
}
// Validates the operations supplied by the plugin.
//
// Uses the four simpler `ValidateHelper(const TF_...*)` overloads to validate
// each individual function table and then checks that the function table for a
// specific file type exists if the plugin offers support for creating that
// type of file.
static Status ValidateOperations(const TF_FilesystemPluginInfo* info) {
TF_RETURN_IF_ERROR(ValidateHelper(info->filesystem_ops));
TF_RETURN_IF_ERROR(ValidateHelper(info->random_access_file_ops));
TF_RETURN_IF_ERROR(ValidateHelper(info->writable_file_ops));
TF_RETURN_IF_ERROR(ValidateHelper(info->read_only_memory_region_ops));
if (info->filesystem_ops->new_random_access_file != nullptr &&
info->random_access_file_ops == nullptr)
return errors::FailedPrecondition(
"Filesystem allows creation of random access files but no "
"operations on them have been supplied.");
if ((info->filesystem_ops->new_writable_file != nullptr ||
info->filesystem_ops->new_appendable_file != nullptr) &&
info->writable_file_ops == nullptr)
return errors::FailedPrecondition(
"Filesystem allows creation of writable files but no "
"operations on them have been supplied.");
if (info->filesystem_ops->new_read_only_memory_region_from_file != nullptr &&
info->read_only_memory_region_ops == nullptr)
return errors::FailedPrecondition(
"Filesystem allows creation of readonly memory regions but no "
"operations on them have been supplied.");
return Status::OK();
}
// Copies a function table from plugin memory space to core memory space.
//
// This has three benefits:
// * allows having newer plugins than the current core TensorFlow: the
// additional entries in the plugin's table are just discarded;
// * allows having older plugins than the current core TensorFlow (though
// we are still warning users): the entries that core TensorFlow expects
// but plugins didn't provide will be set to `nullptr` values and core
// TensorFlow will know to not call these on behalf of users;
// * increased security as plugins will not be able to alter the function
//   table after it is loaded. Thus, malicious plugins can't alter
//   functionality to probe for gadgets inside core TensorFlow. We can even
//   protect the area of memory where the copies reside to disallow any
//   further writes to it after all copies are created.
template <typename T>
static std::unique_ptr<const T> CopyToCore(const T* plugin_ops,
size_t plugin_size) {
if (plugin_ops == nullptr) return nullptr;
size_t copy_size = std::min(plugin_size, sizeof(T));
auto core_ops = tensorflow::MakeUnique<T>();
memset(core_ops.get(), 0, sizeof(T));
memcpy(core_ops.get(), plugin_ops, copy_size);
return core_ops;
}
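// Hypothetical illustration (not part of this commit): because `CopyToCore`
// zero-initializes the destination table, any entry an older plugin did not
// supply stays `nullptr`, so core code can guard each dispatch with a check
// along these lines before calling into the plugin:
template <typename Fn>
static Status CheckOpProvidedSketch(Fn fn, const char* op_name) {
  if (fn == nullptr)
    return errors::Unimplemented("Filesystem plugin does not provide `",
                                 op_name, "`");
  return Status::OK();
}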
// Registers one filesystem from the plugin.
static Status RegisterFileSystem(const TF_FilesystemPluginInfo* info) {
// Step 1: Copy all the function tables to core TensorFlow memory space
auto core_filesystem_ops = CopyToCore<TF_FilesystemOps>(
info->filesystem_ops, info->filesystem_ops_size);
auto core_random_access_file_ops = CopyToCore<TF_RandomAccessFileOps>(
info->random_access_file_ops, info->random_access_file_ops_size);
auto core_writable_file_ops = CopyToCore<TF_WritableFileOps>(
info->writable_file_ops, info->writable_file_ops_size);
auto core_read_only_memory_region_ops =
CopyToCore<TF_ReadOnlyMemoryRegionOps>(
info->read_only_memory_region_ops,
info->read_only_memory_region_ops_size);
// Step 2: Initialize the opaque filesystem structure
auto filesystem = tensorflow::MakeUnique<TF_Filesystem>();
TF_Status* c_status = TF_NewStatus();
Status status = Status::OK();
core_filesystem_ops->init(filesystem.get(), c_status);
status = Status(c_status->status);
TF_DeleteStatus(c_status);
if (!status.ok()) return status;
// Step 3: Actual registration
return Env::Default()->RegisterFileSystem(
info->scheme, tensorflow::MakeUnique<tensorflow::ModularFileSystem>(
std::move(filesystem), std::move(core_filesystem_ops),
std::move(core_random_access_file_ops),
std::move(core_writable_file_ops),
std::move(core_read_only_memory_region_ops)));
}
// Registers all filesystems, if plugin is providing valid information.
//
// Extracted to a separate function so that the pointers inside `info` are
// freed by the caller regardless of whether validation/registration succeeds.
static Status ValidateAndRegisterFilesystems(
const TF_FilesystemPluginInfo* info) {
TF_RETURN_IF_ERROR(ValidateScheme(info->scheme));
TF_RETURN_IF_ERROR(ValidateABI(info));
ValidateAPI(info); // we just warn on API number mismatch
TF_RETURN_IF_ERROR(ValidateOperations(info));
TF_RETURN_IF_ERROR(RegisterFileSystem(info));
return Status::OK();
}
// Allocates memory in the plugin DSO.
//
// Provided by core TensorFlow so that it can free this memory after the DSO is
// loaded and the filesystem information has been used to register the
// filesystems.
static void* basic_allocator(size_t size) { return calloc(1, size); }
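// A hypothetical plugin-side sketch (not part of this file): every table the
// plugin hands back through `TF_FilesystemPluginInfo` is expected to be
// allocated with the core-provided allocator, so the `free` calls in
// `RegisterFilesystemPluginImpl` below can reclaim it. `TF_FILESYSTEM_OPS_SIZE`
// comes from `filesystem_interface.h`.
static TF_FilesystemOps* AllocateFilesystemOpsSketch(
    void* (*plugin_allocator)(size_t)) {
  return static_cast<TF_FilesystemOps*>(
      plugin_allocator(TF_FILESYSTEM_OPS_SIZE));
}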
namespace filesystem_registration {
Status RegisterFilesystemPluginImpl(const std::string& dso_path) {
// Step 1: Load plugin
Env* env = Env::Default();
void* dso_handle;
TF_RETURN_IF_ERROR(env->LoadLibrary(dso_path.c_str(), &dso_handle));
// Step 2: Load symbol for `TF_InitPlugin`
void* dso_symbol;
TF_RETURN_IF_ERROR(
env->GetSymbolFromLibrary(dso_handle, "TF_InitPlugin", &dso_symbol));
// Step 3: Call `TF_InitPlugin`
TF_FilesystemPluginInfo* info = nullptr;
auto TF_InitPlugin = reinterpret_cast<int (*)(
decltype(&basic_allocator), TF_FilesystemPluginInfo**)>(dso_symbol);
int num_schemes = TF_InitPlugin(&basic_allocator, &info);
if (num_schemes < 0 || info == nullptr)
return errors::InvalidArgument("DSO returned invalid filesystem data");
  // Step 4: Validate and register all filesystems.
  // Try to register as many filesystems as possible; free each entry's memory
  // once it is no longer needed.
Status status;
for (int i = 0; i < num_schemes; i++) {
status.Update(ValidateAndRegisterFilesystems(&info[i]));
free(info[i].scheme);
free(info[i].filesystem_ops);
free(info[i].random_access_file_ops);
free(info[i].writable_file_ops);
free(info[i].read_only_memory_region_ops);
}
free(info);
return status;
}
} // namespace filesystem_registration
} // namespace tensorflow
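As a usage sketch (hypothetical caller, not part of this commit), registering every filesystem exported by a plugin DSO reduces to one call into this translation unit; the DSO path below is only an example:

#include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h"

// Hypothetical caller; the plugin path is illustrative only.
tensorflow::Status LoadFilesystemPluginSketch() {
  return tensorflow::filesystem_registration::RegisterFilesystemPluginImpl(
      "/usr/local/lib/libposix_filesystem.so");
}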

View File

@ -0,0 +1,28 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
namespace filesystem_registration {
Status RegisterFilesystemPluginImpl(const std::string& dso_path);
} // namespace filesystem_registration
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_

View File

@ -1,35 +1,47 @@
# Experimental posix filesystem plugin. # Experimental posix filesystem plugin.
load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
package( package(
default_visibility = ["//visibility:private"],
licenses = ["notice"], # Apache 2.0 licenses = ["notice"], # Apache 2.0
) )
# Although this target results in a shared object that will be loaded at # Filesystem implementation for POSIX environments: Linux, MacOS, Android, etc.
# runtime, this target must be a `cc_library` instead of a `cc_binary`. Making tf_cc_shared_object(
# it a `cc_binary` requires `linkshared = True`. In turn, this brings in several name = "libposix_filesystem.so",
# TensorFlow symbols under `tensorflow::` namespace, for which we have no ABI framework_so = [],
# guarantees. Hence, in order to maintain ABI compatibility, this is marked as a linkstatic = False,
# `cc_library` for now and we will revisit in the future. visibility = ["//visibility:public"],
# TODO(mihaimaruseac): Determine if `cc_binary` makes more sense (when all deps = [":posix_filesystem_impl"],
# filesystems are converted and BUILD files are refactored to be modular). )
# TODO(b/144585140): The helpers should be separated into a different BUILD target
# but doing that would result in symbols not being visible when loading plugin. # The real implementation of the filesystem.
# Revisit this once POSIX filesystem completely lands. See also the other TODO.
# This also has the unfortunate effect that both versions of copy_file get
# compiled, regardless of which one actually gets used!
cc_library( cc_library(
name = "posix_filesystem", name = "posix_filesystem_impl",
srcs = [ srcs = ["posix_filesystem.cc"],
"posix_filesystem.cc",
"posix_filesystem_helper.cc",
"posix_filesystem_helper.h",
"copy_file.h",
] + select({
"//tensorflow:linux_x86_64": ["copy_file_linux.cc"],
"//conditions:default": ["copy_file_portable.cc"],
}),
deps = [ deps = [
":posix_filesystem_helper",
"//tensorflow/c:tf_status", "//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface", "//tensorflow/c/experimental/filesystem:filesystem_interface",
], ],
) )
# Library implementing helper functionality, so that the above only contains
# the API implementation for modular filesystems.
cc_library(
name = "posix_filesystem_helper",
srcs = ["posix_filesystem_helper.cc"],
hdrs = ["posix_filesystem_helper.h"],
deps = [":copy_file"],
)
# On Linux, we can copy files faster using `sendfile`. But not elsewhere.
# Hence, this private library to select which implementation to use.
cc_library(
name = "copy_file",
srcs = select({
"//tensorflow:linux_x86_64": ["copy_file_linux.cc"],
"//conditions:default": ["copy_file_portable.cc"],
}),
hdrs = ["copy_file.h"],
)
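The `sendfile` note above refers to copying file contents inside the kernel rather than through a user-space buffer. A rough, hypothetical sketch of what the Linux-only `copy_file_linux.cc` path can look like (names and error handling simplified, not the actual implementation):

#include <sys/sendfile.h>
#include <sys/types.h>

// Hypothetical sketch: copy `size` bytes from src_fd to dst_fd using sendfile,
// returning -1 with errno set on failure (mirroring the POSIX helper style).
static ssize_t CopyFileContentsSketch(int dst_fd, int src_fd, off_t size) {
  off_t offset = 0;
  while (offset < size) {
    // sendfile advances `offset` by the number of bytes copied; a short copy
    // just means another loop iteration.
    ssize_t copied = sendfile(dst_fd, src_fd, &offset, size - offset);
    if (copied <= 0) return -1;
  }
  return static_cast<ssize_t>(size);
}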

View File

@ -24,8 +24,6 @@ limitations under the License.
#include <sys/stat.h> #include <sys/stat.h>
#include <unistd.h> #include <unistd.h>
#include <vector>
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h" #include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem_helper.h" #include "tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem_helper.h"
#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status.h"
@ -396,48 +394,65 @@ static int GetChildren(const TF_Filesystem* filesystem, const char* path,
} // namespace tf_posix_filesystem } // namespace tf_posix_filesystem
void TF_InitPlugin(TF_Status* status) { int TF_InitPlugin(void* (*allocator)(size_t), TF_FilesystemPluginInfo** info) {
TF_RandomAccessFileOps random_access_file_ops = { const int num_schemes = 2;
tf_random_access_file::Cleanup, *info = static_cast<TF_FilesystemPluginInfo*>(
tf_random_access_file::Read, allocator(num_schemes * sizeof((*info)[0])));
};
TF_WritableFileOps writable_file_ops = {
tf_writable_file::Cleanup, tf_writable_file::Append,
tf_writable_file::Tell, tf_writable_file::Flush,
tf_writable_file::Sync, tf_writable_file::Close,
};
TF_ReadOnlyMemoryRegionOps read_only_memory_region_ops = {
tf_read_only_memory_region::Cleanup,
tf_read_only_memory_region::Data,
tf_read_only_memory_region::Length,
};
TF_FilesystemOps filesystem_ops = {
tf_posix_filesystem::Init,
tf_posix_filesystem::Cleanup,
tf_posix_filesystem::NewRandomAccessFile,
tf_posix_filesystem::NewWritableFile,
tf_posix_filesystem::NewAppendableFile,
tf_posix_filesystem::NewReadOnlyMemoryRegionFromFile,
tf_posix_filesystem::CreateDir,
/*recursively_create_dir=*/nullptr,
tf_posix_filesystem::DeleteFile,
tf_posix_filesystem::DeleteDir,
/*delete_recursively=*/nullptr,
tf_posix_filesystem::RenameFile,
tf_posix_filesystem::CopyFile,
tf_posix_filesystem::PathExists,
/*paths_exist=*/nullptr,
tf_posix_filesystem::Stat,
/*is_directory=*/nullptr,
/*get_file_size=*/nullptr,
/*translate_name=*/nullptr,
tf_posix_filesystem::GetChildren,
/*get_matching_paths=*/nullptr,
/*flush_caches=*/nullptr,
};
for (const char* scheme : {"", "file"}) for (int i = 0; i < num_schemes; i++) {
TF_REGISTER_FILESYSTEM_PLUGIN(scheme, &filesystem_ops, TF_FilesystemPluginInfo* current_info = &((*info)[i]);
&random_access_file_ops, &writable_file_ops, TF_SetFilesystemVersionMetadata(current_info);
&read_only_memory_region_ops, status);
current_info->random_access_file_ops = static_cast<TF_RandomAccessFileOps*>(
allocator(TF_RANDOM_ACCESS_FILE_OPS_SIZE));
current_info->random_access_file_ops->cleanup =
tf_random_access_file::Cleanup;
current_info->random_access_file_ops->read = tf_random_access_file::Read;
current_info->writable_file_ops =
static_cast<TF_WritableFileOps*>(allocator(TF_WRITABLE_FILE_OPS_SIZE));
current_info->writable_file_ops->cleanup = tf_writable_file::Cleanup;
current_info->writable_file_ops->append = tf_writable_file::Append;
current_info->writable_file_ops->tell = tf_writable_file::Tell;
current_info->writable_file_ops->flush = tf_writable_file::Flush;
current_info->writable_file_ops->sync = tf_writable_file::Sync;
current_info->writable_file_ops->close = tf_writable_file::Close;
current_info->read_only_memory_region_ops =
static_cast<TF_ReadOnlyMemoryRegionOps*>(
allocator(TF_READ_ONLY_MEMORY_REGION_OPS_SIZE));
current_info->read_only_memory_region_ops->cleanup =
tf_read_only_memory_region::Cleanup;
current_info->read_only_memory_region_ops->data =
tf_read_only_memory_region::Data;
current_info->read_only_memory_region_ops->length =
tf_read_only_memory_region::Length;
current_info->filesystem_ops =
static_cast<TF_FilesystemOps*>(allocator(TF_FILESYSTEM_OPS_SIZE));
current_info->filesystem_ops->init = tf_posix_filesystem::Init;
current_info->filesystem_ops->cleanup = tf_posix_filesystem::Cleanup;
current_info->filesystem_ops->new_random_access_file =
tf_posix_filesystem::NewRandomAccessFile;
current_info->filesystem_ops->new_writable_file =
tf_posix_filesystem::NewWritableFile;
current_info->filesystem_ops->new_appendable_file =
tf_posix_filesystem::NewAppendableFile;
current_info->filesystem_ops->new_read_only_memory_region_from_file =
tf_posix_filesystem::NewReadOnlyMemoryRegionFromFile;
current_info->filesystem_ops->create_dir = tf_posix_filesystem::CreateDir;
current_info->filesystem_ops->delete_file = tf_posix_filesystem::DeleteFile;
current_info->filesystem_ops->delete_dir = tf_posix_filesystem::DeleteDir;
current_info->filesystem_ops->rename_file = tf_posix_filesystem::RenameFile;
current_info->filesystem_ops->copy_file = tf_posix_filesystem::CopyFile;
current_info->filesystem_ops->path_exists = tf_posix_filesystem::PathExists;
current_info->filesystem_ops->stat = tf_posix_filesystem::Stat;
current_info->filesystem_ops->get_children =
tf_posix_filesystem::GetChildren;
}
(*info)[0].scheme = strdup("");
(*info)[1].scheme = strdup("file");
return num_schemes;
} }

View File

@ -44,7 +44,7 @@ int TransferFileContents(const char* src, const char* dst, mode_t mode,
} }
// Both files have been opened, do the transfer. // Both files have been opened, do the transfer.
// Since errno would be overriden by `close` below, save it here. // Since errno would be overridden by `close` below, save it here.
int error_code = 0; int error_code = 0;
if (CopyFileContents(dst_fd, src_fd, size) < 0) error_code = errno; if (CopyFileContents(dst_fd, src_fd, size) < 0) error_code = errno;

View File

@ -0,0 +1,36 @@
# Experimental windows filesystem plugin.
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object")
package(
licenses = ["notice"], # Apache 2.0
)
# Filesystem implementation for Windows environments
tf_cc_shared_object(
name = "windows_filesystem.dll",
framework_so = [],
linkstatic = False,
tags = [
"manual",
"nobuilder",
"notap",
],
visibility = ["//visibility:public"],
deps = [":windows_filesystem_impl"],
)
# The real implementation of the filesystem.
cc_library(
name = "windows_filesystem_impl",
srcs = ["windows_filesystem.cc"],
copts = get_win_copts(),
tags = [
"manual",
"nobuilder",
"notap",
],
deps = [
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
],
)

View File

@ -0,0 +1,70 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdlib.h>
#include <string.h>
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h"
// Implementation of a filesystem for Windows environments.
// This filesystem will support `file://` and empty (local) URI schemes.
// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
// TODO(mihaimaruseac): Implement later
} // namespace tf_random_access_file
// SECTION 2. Implementation for `TF_WritableFile`
// ----------------------------------------------------------------------------
namespace tf_writable_file {
// TODO(mihaimaruseac): Implement later
} // namespace tf_writable_file
// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion`
// ----------------------------------------------------------------------------
namespace tf_read_only_memory_region {
// TODO(mihaimaruseac): Implement later
} // namespace tf_read_only_memory_region
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
// ----------------------------------------------------------------------------
namespace tf_windows_filesystem {
// TODO(mihaimaruseac): Implement later
} // namespace tf_windows_filesystem
int TF_InitPlugin(void* (*allocator)(size_t), TF_FilesystemPluginInfo** info) {
const int num_schemes = 2;
*info = static_cast<TF_FilesystemPluginInfo*>(
allocator(num_schemes * sizeof((*info)[0])));
for (int i = 0; i < num_schemes; i++) {
TF_FilesystemPluginInfo* current_info = &((*info)[i]);
TF_SetFilesystemVersionMetadata(current_info);
}
(*info)[0].scheme = strdup("");
(*info)[1].scheme = strdup("file");
return num_schemes;
}

View File

@ -181,7 +181,8 @@ void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor,
return; return;
} }
const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i)); const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i));
TF_Tensor* result = ::tensorflow::TF_TensorFromTensor(cc_tensor, status); TF_Tensor* result =
::tensorflow::TF_TensorFromTensor(cc_tensor, &status->status);
if (TF_GetCode(status) == TF_OK) { if (TF_GetCode(status) == TF_OK) {
*tensor = result; *tensor = result;
} }

View File

@ -133,7 +133,7 @@ TEST(OpsTest, TestShapeInference_VectorizeFunction) {
TEST(OpsTest, AttributeAccessors) { TEST(OpsTest, AttributeAccessors) {
TF_OpDefinitionBuilder* builder = TF_OpDefinitionBuilder* builder =
TF_NewOpDefinitionBuilder("AttributeAccesorsOp"); TF_NewOpDefinitionBuilder("AttributeAccessorsOp");
TF_OpDefinitionBuilderAddAttr(builder, "foo1: int >= 2"); TF_OpDefinitionBuilderAddAttr(builder, "foo1: int >= 2");
TF_OpDefinitionBuilderAddAttr(builder, "foo2: string=\"my string\""); TF_OpDefinitionBuilderAddAttr(builder, "foo2: string=\"my string\"");
TF_OpDefinitionBuilderSetIsCommutative(builder, true); TF_OpDefinitionBuilderSetIsCommutative(builder, true);
@ -151,7 +151,7 @@ TEST(OpsTest, AttributeAccessors) {
op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length); op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length);
bool found = false; bool found = false;
for (const auto& op : op_list.op()) { for (const auto& op : op_list.op()) {
if (op.name() == "AttributeAccesorsOp") { if (op.name() == "AttributeAccessorsOp") {
ASSERT_TRUE(op.is_commutative()); ASSERT_TRUE(op.is_commutative());
ASSERT_TRUE(op.is_aggregate()); ASSERT_TRUE(op.is_aggregate());
ASSERT_TRUE(op.allows_uninitialized_input()); ASSERT_TRUE(op.allows_uninitialized_input());

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/c/tf_tensor.h" #include "tensorflow/c/tf_tensor.h"
#include <memory>
#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h" #include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor_internal.h" #include "tensorflow/c/tf_tensor_internal.h"
@ -103,49 +105,35 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg); buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
} }
TF_Tensor* ret = // TODO(gjn): Make the choice of interface a compile-time configuration.
new TF_Tensor{Tensor(static_cast<tensorflow::DataType>(dtype), tensorflow::TensorInterface ret(
tensorflow::TensorShape(dimvec), buf)}; Tensor(static_cast<tensorflow::DataType>(dtype),
tensorflow::TensorShape(dimvec), buf));
buf->Unref(); buf->Unref();
size_t elem_size = TF_DataTypeSize(dtype); size_t elem_size = TF_DataTypeSize(dtype);
if (elem_size > 0 && len < (elem_size * ret->tensor.NumElements())) { if (elem_size > 0 && len < (elem_size * ret.NumElements())) {
delete ret;
return nullptr; return nullptr;
} }
return ret; return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(ret)};
} }
TF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor) { TF_Tensor* TF_TensorMaybeMove(TF_Tensor* t) {
// It is safe to move the Tensor if and only if we own the unique reference to return t->tensor->CanMove() ? t : nullptr;
// it. In that case, we might as well not delete and reallocate, but a future
// implementation might need to do so.
TensorBuffer* buf = tensorflow::TensorCApi::Buffer(tensor->tensor);
if (buf->RefCountIsOne() && buf->root_buffer()->RefCountIsOne() &&
buf->OwnsMemory()) {
return tensor;
}
return nullptr;
} }
void TF_DeleteTensor(TF_Tensor* t) { delete t; } void TF_DeleteTensor(TF_Tensor* t) { delete t; }
TF_DataType TF_TensorType(const TF_Tensor* t) { TF_DataType TF_TensorType(const TF_Tensor* t) { return t->tensor->Type(); }
return static_cast<TF_DataType>(t->tensor.dtype());
}
int TF_NumDims(const TF_Tensor* t) { return t->tensor.dims(); } int TF_NumDims(const TF_Tensor* t) { return t->tensor->NumDims(); }
int64_t TF_Dim(const TF_Tensor* t, int dim_index) { int64_t TF_Dim(const TF_Tensor* t, int dim_index) {
return static_cast<int64_t>(t->tensor.dim_size(dim_index)); return t->tensor->Dim(dim_index);
} }
size_t TF_TensorByteSize(const TF_Tensor* t) { size_t TF_TensorByteSize(const TF_Tensor* t) { return t->tensor->ByteSize(); }
return tensorflow::TensorCApi::Buffer(t->tensor)->size();
}
void* TF_TensorData(const TF_Tensor* t) { void* TF_TensorData(const TF_Tensor* t) { return t->tensor->Data(); }
return tensorflow::TensorCApi::Buffer(t->tensor)->data();
}
int64_t TF_TensorElementCount(const TF_Tensor* t) { int64_t TF_TensorElementCount(const TF_Tensor* t) {
int64_t result = 1; int64_t result = 1;
@ -160,16 +148,69 @@ void TF_TensorBitcastFrom(const TF_Tensor* from, TF_DataType type,
TF_Tensor* to, const int64_t* new_dims, TF_Tensor* to, const int64_t* new_dims,
int num_new_dims, TF_Status* status) { int num_new_dims, TF_Status* status) {
TF_SetStatus(status, TF_OK, ""); TF_SetStatus(status, TF_OK, "");
Status cc_status(
static_cast<tensorflow::TensorInterface*>(to->tensor.get())
->BitcastFrom(*static_cast<const tensorflow::TensorInterface*>(
from->tensor.get()),
type, new_dims, num_new_dims));
Set_TF_Status_from_Status(status, cc_status);
}
namespace tensorflow {
bool TensorInterface::CanMove() const {
// It is safe to move the Tensor if and only if we own the unique reference to
// it. In that case, we might as well not delete and reallocate, but a future
// implementation might need to do so.
TensorBuffer* buf = tensorflow::TensorCApi::Buffer(tensor_);
if (buf->RefCountIsOne() && buf->root_buffer()->RefCountIsOne() &&
buf->OwnsMemory()) {
return true;
}
return false;
}
TF_DataType TensorInterface::Type() const {
return static_cast<TF_DataType>(tensor_.dtype());
}
int TensorInterface::NumDims() const { return tensor_.dims(); }
int64_t TensorInterface::Dim(int dim_index) const {
return static_cast<int64_t>(tensor_.dim_size(dim_index));
}
int64_t TensorInterface::NumElements() const {
return static_cast<int64_t>(tensor_.NumElements());
}
size_t TensorInterface::ByteSize() const {
return tensorflow::TensorCApi::Buffer(tensor_)->size();
}
void* TensorInterface::Data() const {
return tensorflow::TensorCApi::Buffer(tensor_)->data();
}
Status TensorInterface::BitcastFrom(const TensorInterface& from,
TF_DataType type, const int64_t* new_dims,
int num_new_dims) {
tensorflow::TensorShape s; tensorflow::TensorShape s;
for (int i = 0; i < num_new_dims; ++i) { for (int i = 0; i < num_new_dims; ++i) {
s.AddDim(new_dims[i]); s.AddDim(new_dims[i]);
} }
Status cc_status(to->tensor.BitcastFrom( return tensor_.BitcastFrom(from.tensor_,
from->tensor, static_cast<tensorflow::DataType>(type), s)); static_cast<tensorflow::DataType>(type), s);
Set_TF_Status_from_Status(status, cc_status);
} }
} // namespace tensorflow
// -------------------------------------------------------------------------- // --------------------------------------------------------------------------
void StringEncode(const char* src, size_t src_len, char* dst) {
dst = tensorflow::core::EncodeVarint64(dst, src_len);
memcpy(dst, src, src_len);
}
size_t TF_StringEncode(const char* src, size_t src_len, char* dst, size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
size_t dst_len, TF_Status* status) { size_t dst_len, TF_Status* status) {
const size_t sz = TF_StringEncodedSize(src_len); const size_t sz = TF_StringEncodedSize(src_len);
@ -185,8 +226,7 @@ size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
src_len, "-byte string")); src_len, "-byte string"));
return 0; return 0;
} }
dst = tensorflow::core::EncodeVarint64(dst, src_len); StringEncode(src, src_len, dst);
memcpy(dst, src, src_len);
return sz; return sz;
} }
@ -245,13 +285,11 @@ static TF_Tensor* EmptyTensor(TF_DataType dtype,
namespace tensorflow { namespace tensorflow {
// Non-static for testing. // Non-static for testing.
TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status) {
TF_Status* status) { *status = tensorflow::Status::OK();
TF_SetStatus(status, TF_OK, "");
if (!src.IsInitialized()) { if (!src.IsInitialized()) {
Set_TF_Status_from_Status( *status = FailedPrecondition(
status, FailedPrecondition( "attempt to use a tensor with an uninitialized value");
"attempt to use a tensor with an uninitialized value"));
return nullptr; return nullptr;
} }
if (src.NumElements() == 0) { if (src.NumElements() == 0) {
@ -259,14 +297,13 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
} }
if (src.dtype() == tensorflow::DT_RESOURCE) { if (src.dtype() == tensorflow::DT_RESOURCE) {
if (src.shape().dims() != 0) { if (src.shape().dims() != 0) {
Set_TF_Status_from_Status( *status = InvalidArgument(
status, InvalidArgument( "Unexpected non-scalar DT_RESOURCE tensor seen (shape: ",
"Unexpected non-scalar DT_RESOURCE tensor seen (shape: ", src.shape().DebugString(),
src.shape().DebugString(), "). Please file a bug at "
"). Please file a bug at " "https://github.com/tensorflow/tensorflow/issues/new, "
"https://github.com/tensorflow/tensorflow/issues/new, " "ideally with a "
"ideally with a " "short code snippet that reproduces this error.");
"short code snippet that reproduces this error."));
return nullptr; return nullptr;
} }
const string str = const string str =
@ -276,12 +313,11 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
return t; return t;
} }
if (src.dtype() != tensorflow::DT_STRING) { if (src.dtype() != tensorflow::DT_STRING) {
auto* result = new TF_Tensor(); Tensor tensor;
if (!result->tensor.CopyFrom(src, src.shape())) { if (!tensor.CopyFrom(src, src.shape())) {
delete result;
return nullptr; return nullptr;
} }
return result; return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(tensor)};
} }
// DT_STRING tensors require a copying since TF_Tensor.buffer expects a flatly // DT_STRING tensors require a copying since TF_Tensor.buffer expects a flatly
// encoded sequence of strings. // encoded sequence of strings.
@ -305,23 +341,15 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
*offsets = (dst - data_start); *offsets = (dst - data_start);
offsets++; offsets++;
const string& s = srcarray(i); const string& s = srcarray(i);
size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status); const size_t consumed = TF_StringEncodedSize(s.size());
if (TF_GetCode(status) != TF_OK) { StringEncode(s.data(), s.size(), dst);
Set_TF_Status_from_Status(
status,
InvalidArgument("invalid string tensor encoding (string #", i, " of ",
srcarray.size(), "): ", TF_Message(status)));
delete[] base;
return nullptr;
}
dst += consumed; dst += consumed;
dst_len -= consumed; dst_len -= consumed;
} }
if (dst != base + size) { if (dst != base + size) {
Set_TF_Status_from_Status( *status = InvalidArgument(
status, InvalidArgument( "invalid string tensor encoding (decoded ", (dst - base),
"invalid string tensor encoding (decoded ", (dst - base), " bytes, but the tensor is encoded in ", size, " bytes");
" bytes, but the tensor is encoded in ", size, " bytes"));
delete[] base; delete[] base;
return nullptr; return nullptr;
} }
@ -339,31 +367,35 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
} }
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) { Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
if (src->tensor.dtype() == DT_RESOURCE) { return static_cast<const tensorflow::TensorInterface*>(src->tensor.get())
if (src->tensor.dims() != 0) { ->ToTensor(dst);
}
Status TensorInterface::ToTensor(Tensor* dst) const {
if (tensor_.dtype() == DT_RESOURCE) {
if (tensor_.dims() != 0) {
return InvalidArgument( return InvalidArgument(
"Malformed TF_RESOURCE tensor: expected a scalar, got a tensor with " "Malformed TF_RESOURCE tensor: expected a scalar, got a tensor with "
"shape ", "shape ",
src->tensor.shape().DebugString()); tensor_.shape().DebugString());
} }
*dst = Tensor(tensorflow::DT_RESOURCE, src->tensor.shape()); *dst = Tensor(tensorflow::DT_RESOURCE, tensor_.shape());
if (!dst->scalar<tensorflow::ResourceHandle>()().ParseFromString( if (!dst->scalar<tensorflow::ResourceHandle>()().ParseFromString(
string(static_cast<const char*>(TF_TensorData(src)), string(static_cast<const char*>(Data()), ByteSize()))) {
TF_TensorByteSize(src)))) {
return InvalidArgument( return InvalidArgument(
"Malformed TF_RESOUCE tensor: unable to parse resource handle"); "Malformed TF_RESOURCE tensor: unable to parse resource handle");
} }
return Status::OK(); return Status::OK();
} }
if (src->tensor.dtype() != DT_STRING) { if (tensor_.dtype() != DT_STRING) {
*dst = src->tensor; *dst = tensor_;
return Status::OK(); return Status::OK();
} }
// TF_STRING tensors require copying since Tensor class expects a sequence of // TF_STRING tensors require copying since Tensor class expects a sequence of
// string objects. // string objects.
const tensorflow::int64 num_elements = src->tensor.NumElements(); const tensorflow::int64 num_elements = tensor_.NumElements();
const char* input = reinterpret_cast<const char*>(TF_TensorData(src)); const char* input = reinterpret_cast<const char*>(Data());
const size_t src_size = TF_TensorByteSize(src); const size_t src_size = ByteSize();
if (static_cast<tensorflow::int64>(src_size / sizeof(tensorflow::uint64)) < if (static_cast<tensorflow::int64>(src_size / sizeof(tensorflow::uint64)) <
num_elements) { num_elements) {
return InvalidArgument( return InvalidArgument(
@ -372,7 +404,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
const char* data_start = input + sizeof(tensorflow::uint64) * num_elements; const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
const char* limit = input + src_size; const char* limit = input + src_size;
*dst = Tensor(src->tensor.dtype(), src->tensor.shape()); *dst = Tensor(tensor_.dtype(), tensor_.shape());
auto dstarray = dst->flat<tstring>(); auto dstarray = dst->flat<tstring>();
for (tensorflow::int64 i = 0; i < num_elements; ++i) { for (tensorflow::int64 i = 0; i < num_elements; ++i) {
tensorflow::uint64 offset = tensorflow::uint64 offset =
@ -391,8 +423,8 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
return Status::OK(); return Status::OK();
} }
bool TensorInterface::IsAligned() const { return tensor_.IsAligned(); }
} // namespace tensorflow } // namespace tensorflow
bool TF_TensorIsAligned(const TF_Tensor* tensor) { bool TF_TensorIsAligned(const TF_Tensor* t) { return t->tensor->IsAligned(); }
return tensor->tensor.IsAligned();
}

View File

@ -16,9 +16,12 @@ limitations under the License.
#ifndef TENSORFLOW_C_TF_TENSOR_INTERNAL_H_ #ifndef TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
#define TENSORFLOW_C_TF_TENSOR_INTERNAL_H_ #define TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
#include <memory>
#include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_datatype.h"
#include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_interface.h"
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.h"
// Internal structures used by the C API. These are likely to change and should // Internal structures used by the C API. These are likely to change and should
@ -28,7 +31,7 @@ limitations under the License.
// passed to or returned from C functions *by pointer*. Otherwise, changes to // passed to or returned from C functions *by pointer*. Otherwise, changes to
// its internal structure will break the C API's binary interface. // its internal structure will break the C API's binary interface.
typedef struct TF_Tensor { typedef struct TF_Tensor {
::tensorflow::Tensor tensor; std::unique_ptr<AbstractTensorInterface> tensor;
} TF_Tensor; } TF_Tensor;
class TF_ManagedBuffer : public tensorflow::TensorBuffer { class TF_ManagedBuffer : public tensorflow::TensorBuffer {
@ -83,4 +86,5 @@ void* allocate_tensor(const char* operation, size_t len, Allocator* allocator);
// a different Allocator as `arg`. // a different Allocator as `arg`.
void deallocate_buffer(void* data, size_t len, void* arg); void deallocate_buffer(void* data, size_t len, void* arg);
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_C_TF_TENSOR_INTERNAL_H_ #endif // TENSORFLOW_C_TF_TENSOR_INTERNAL_H_

View File

@ -96,7 +96,7 @@ class SymbolicGradientBuilder {
// Used to identify nodes at which to stop backprop. // Used to identify nodes at which to stop backprop.
std::unordered_set<int> GetStopBackpropNodes( std::unordered_set<int> GetStopBackpropNodes(
const std::vector<bool>& reachable_nodes, const std::vector<bool>& reachable_nodes,
const std::unordered_set<int>& output_nodes); const std::unordered_set<int>& output_nodes) const;
const Scope& scope_; const Scope& scope_;
const ops::GradOpRegistry* registry_; const ops::GradOpRegistry* registry_;
@ -190,7 +190,7 @@ std::vector<bool> SymbolicGradientBuilder::GetReachableNodes() {
std::unordered_set<int> SymbolicGradientBuilder::GetStopBackpropNodes( std::unordered_set<int> SymbolicGradientBuilder::GetStopBackpropNodes(
const std::vector<bool>& reachable_nodes, const std::vector<bool>& reachable_nodes,
const std::unordered_set<int>& output_nodes) { const std::unordered_set<int>& output_nodes) const {
// Output nodes that get transitively consumed by other `outputs_` are stored // Output nodes that get transitively consumed by other `outputs_` are stored
// in `internal_outputs`. // in `internal_outputs`.
std::unordered_set<int> internal_outputs; std::unordered_set<int> internal_outputs;
@ -346,8 +346,8 @@ Status SymbolicGradientBuilder::SumGradients(const Output& src, Output* grad) {
"Unable to find backprop list for node.id ", src.node()->name()); "Unable to find backprop list for node.id ", src.node()->name());
} }
const auto& grads = iter->second; const auto& grads = iter->second;
// Filter any backproped 'NoGradient' Outputs from 'grads' (if needed). // Filter any backpropped 'NoGradient' Outputs from 'grads' (if needed).
// Return any valid backproped gradients that remain after filtering, // Return any valid backpropped gradients that remain after filtering,
// or 'NoGradient' otherwise. // or 'NoGradient' otherwise.
std::vector<Output> grads_to_keep; std::vector<Output> grads_to_keep;
for (const Output& o : grads) { for (const Output& o : grads) {
@ -519,7 +519,7 @@ Status SymbolicGradientBuilder::AddGradients() {
// Backprop along the in edges. // Backprop along the in edges.
// TODO(andydavis) Find cleaner way to map each grad output returned by // TODO(andydavis) Find cleaner way to map each grad output returned by
// gradient function to the src node/output to which it should be // gradient function to the src node/output to which it should be
// backproped. Maybe grad functions can return a vector of Output pairs to // backpropped. Maybe grad functions can return a vector of Output pairs to
// make this association explicit. // make this association explicit.
size_t dx_index = 0; size_t dx_index = 0;
for (const Edge* e : n->in_edges()) { for (const Edge* e : n->in_edges()) {

View File

@ -64,7 +64,7 @@ bool IsZero(const Scope& scope, const Output& grad) {
// Multiply after broadcasting vec to match dimensions of mat. // Multiply after broadcasting vec to match dimensions of mat.
// Args: // Args:
// vec: A 1-D tensor of dimension [D0] // vec: A 1-D tensor of dimension [D0]
// mat: A 2-D tensor of dimesnion [D0, D1] // mat: A 2-D tensor of dimension [D0, D1]
// //
// Returns: // Returns:
// A tensor of dimension [D0, D1], the result of vec * mat. // A tensor of dimension [D0, D1], the result of vec * mat.

View File

@ -259,6 +259,9 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) {
RunTest(x, x_init_value, y, y_shape); RunTest(x, x_init_value, y, y_shape);
} }
// TODO(rocm):
// Re-enable this test once 3D pooling is supported on ROCm platform
#ifndef TENSORFLOW_USE_ROCM
TEST_F(NNGradTest, MaxPool3DGradHelper) { TEST_F(NNGradTest, MaxPool3DGradHelper) {
TensorShape x_shape({1, 3, 3, 3, 1}); TensorShape x_shape({1, 3, 3, 3, 1});
TensorShape y_shape({1, 1, 1, 1, 1}); TensorShape y_shape({1, 1, 1, 1, 1});
@ -271,6 +274,7 @@ TEST_F(NNGradTest, MaxPool3DGradHelper) {
SetRandomValuesForMaxPooling<float>(&x_init_value); SetRandomValuesForMaxPooling<float>(&x_init_value);
RunTest(x, x_init_value, y, y_shape); RunTest(x, x_init_value, y, y_shape);
} }
#endif
TEST_F(NNGradTest, AvgPoolGradHelper) { TEST_F(NNGradTest, AvgPoolGradHelper) {
TensorShape x_shape({1, 2, 2, 1}); TensorShape x_shape({1, 2, 2, 1});
@ -283,6 +287,9 @@ TEST_F(NNGradTest, AvgPoolGradHelper) {
RunTest(x, x_shape, y, y_shape); RunTest(x, x_shape, y, y_shape);
} }
// TODO(rocm):
// Re-enable this test once 3D pooling is supported on ROCm platform
#ifndef TENSORFLOW_USE_ROCM
TEST_F(NNGradTest, AvgPool3DGradHelper) { TEST_F(NNGradTest, AvgPool3DGradHelper) {
TensorShape x_shape({1, 3, 3, 3, 1}); TensorShape x_shape({1, 3, 3, 3, 1});
TensorShape y_shape({1, 1, 1, 1, 1}); TensorShape y_shape({1, 1, 1, 1, 1});
@ -293,6 +300,7 @@ TEST_F(NNGradTest, AvgPool3DGradHelper) {
auto y = AvgPool3D(scope_, x, ksize, strides, "SAME"); auto y = AvgPool3D(scope_, x, ksize, strides, "SAME");
RunTest(x, x_shape, y, y_shape); RunTest(x, x_shape, y, y_shape);
} }
#endif
TEST_F(NNGradTest, LRN) { TEST_F(NNGradTest, LRN) {
TensorShape x_shape({1, 1, 2, 1}); TensorShape x_shape({1, 1, 2, 1});

View File

@ -124,13 +124,12 @@ cc_library(
hdrs = ["bundle_v2.h"], hdrs = ["bundle_v2.h"],
deps = [ deps = [
":constants", ":constants",
"@com_google_absl//absl/container:flat_hash_set",
] + if_not_mobile([
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:strcat", "//tensorflow/core/platform:strcat",
"//tensorflow/core/util/tensor_bundle", "//tensorflow/core/util/tensor_bundle",
]), "@com_google_absl//absl/container:flat_hash_set",
],
) )
tf_cc_test( tf_cc_test(

View File

@ -1,5 +1,6 @@
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
load("//tensorflow/core/platform:build_config.bzl", "if_llvm_aarch64_available")
package( package(
default_visibility = ["//visibility:private"], default_visibility = ["//visibility:private"],
@ -27,9 +28,14 @@ cc_library(
"compile.h", "compile.h",
"flags.h", "flags.h",
], ],
defines = if_llvm_aarch64_available(["TF_LLVM_AARCH64_AVAILABLE=1"]),
visibility = ["//tensorflow/python:__pkg__"],
deps = [ deps = [
":aot_only_var_handle_op", ":aot_only_var_handle_op",
":embedded_protocol_buffers", ":embedded_protocol_buffers",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"//tensorflow/compiler/tf2xla", "//tensorflow/compiler/tf2xla",
"//tensorflow/compiler/tf2xla:mlir_tf2xla", "//tensorflow/compiler/tf2xla:mlir_tf2xla",
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc", "//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
@ -53,10 +59,13 @@ cc_library(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/memory", "@llvm-project//llvm:arm_code_gen", # fixdeps: keep
"@com_google_absl//absl/strings", "@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
"@com_google_absl//absl/types:span", "@llvm-project//llvm:target",
], "@llvm-project//llvm:x86_code_gen", # fixdeps: keep
] + if_llvm_aarch64_available([
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
]),
) )
tf_cc_test( tf_cc_test(
@ -86,6 +95,19 @@ tf_cc_binary(
deps = [":tfcompile_main"], deps = [":tfcompile_main"],
) )
cc_library(
name = "llvm_targets",
visibility = ["//tensorflow/python:__pkg__"],
deps = [
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
"@llvm-project//llvm:target",
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
] + if_llvm_aarch64_available([
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
]),
)
cc_library( cc_library(
name = "tfcompile_main", name = "tfcompile_main",
srcs = ["tfcompile_main.cc"], srcs = ["tfcompile_main.cc"],
@ -104,11 +126,6 @@ cc_library(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:aarch64_code_gen", # fixdeps: keep
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
"@llvm-project//llvm:target",
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
], ],
) )
@ -214,8 +231,13 @@ cc_library(
cc_library( cc_library(
name = "aot_only_var_handle_op", name = "aot_only_var_handle_op",
srcs = ["aot_only_var_handle_op.cc"], srcs = ["aot_only_var_handle_op.cc"],
hdrs = ["aot_only_var_handle_op.h"],
visibility = [
"//tensorflow/compiler/tf2xla:__pkg__",
],
deps = [ deps = [
"//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/core:framework",
], ],
alwayslink = 1, alwayslink = 1,
) )

View File

@ -13,9 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/compiler/aot/aot_only_var_handle_op.h"
#include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow { namespace tensorflow {
namespace { namespace {
@ -51,6 +54,31 @@ void XlaAotOnlyVarHandleOp::Compile(XlaOpKernelContext* context) {
} }
} // namespace } // namespace
REGISTER_XLA_OP(Name("VarHandleOp").CompilationOnly(), XlaAotOnlyVarHandleOp); REGISTER_OP(tfcompile::kXlaAotOnlyVarHandleOp)
.Doc(R"doc(
Internal VarHandleOp registration used for XLA AOT compilation.
)doc")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.Attr("dtype: type")
.Attr("shape: shape")
.Output("resource: resource")
.SetIsStateful()
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->Scalar());
DataType t;
TF_RETURN_IF_ERROR(c->GetAttr("dtype", &t));
PartialTensorShape p;
TF_RETURN_IF_ERROR(c->GetAttr("shape", &p));
shape_inference::ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s));
c->set_output_handle_shapes_and_types(
0, std::vector<shape_inference::ShapeAndType>{{s, t}});
return Status::OK();
});
REGISTER_XLA_OP(Name(tfcompile::kXlaAotOnlyVarHandleOp).CompilationOnly(),
XlaAotOnlyVarHandleOp);
} // namespace tensorflow } // namespace tensorflow

View File

@ -0,0 +1,27 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_
#define TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_
namespace tensorflow {
namespace tfcompile {
static constexpr const char* const kXlaAotOnlyVarHandleOp =
"_XlaAotOnlyVarHandleOp";
} // namespace tfcompile
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_

View File

@ -74,16 +74,16 @@ void DumpStatsToStdout(const Stats& stats) {
const int kBufSize = 1000; const int kBufSize = 1000;
char buf[kBufSize]; char buf[kBufSize];
snprintf(buf, kBufSize, "Mean with %2.0f%% trimmed:", trim_ratio * 100); snprintf(buf, kBufSize, "Mean with %2.0f%% trimmed:", trim_ratio * 100);
const string label_trimmed(buf); std::string label_trimmed(buf);
snprintf(buf, kBufSize, "Mean of %2.0f%% best:", best_ratio * 100); snprintf(buf, kBufSize, "Mean of %2.0f%% best:", best_ratio * 100);
const string label_best(buf); std::string label_best(buf);
std::vector<std::pair<string, double>> groups = { std::vector<std::pair<std::string, double>> groups = {
{"Best:", sorted_us.front()}, {"Best:", sorted_us.front()},
{"Worst:", sorted_us.back()}, {"Worst:", sorted_us.back()},
{"Median:", sorted_us[count_us / 2]}, {"Median:", sorted_us[count_us / 2]},
{"Mean:", sum_us / count_us}, {"Mean:", sum_us / count_us},
{label_trimmed, sum_us_trimmed / count_us_trimmed}, {std::move(label_trimmed), sum_us_trimmed / count_us_trimmed},
{label_best, sum_us_best / count_us_best}, {std::move(label_best), sum_us_best / count_us_best},
}; };
int max_label_size = 0; int max_label_size = 0;
double max_us = 0; double max_us = 0;
@ -102,7 +102,7 @@ void DumpStatsToStdout(const Stats& stats) {
} }
// Dump stats out. // Dump stats out.
printf("Benchmark ran %zu iterations over %lld us\n", count_us, printf("Benchmark ran %zu iterations over %lld us\n", count_us,
stats.total_us); static_cast<long long>(stats.total_us)); // NOLINT
for (const auto& g : groups) { for (const auto& g : groups) {
printf(" %-*s %*.3f us\n", max_label_size, g.first.c_str(), max_digits + 4, printf(" %-*s %*.3f us\n", max_label_size, g.first.c_str(), max_digits + 4,
g.second); g.second);
@ -114,7 +114,8 @@ void Benchmark(const Options& options, const BenchmarkFn& fn, Stats* stats) {
const int64 max_us = (options.max_micros <= 0 && options.max_iters <= 0) const int64 max_us = (options.max_micros <= 0 && options.max_iters <= 0)
? Options::kDefaultMicros ? Options::kDefaultMicros
: options.max_micros; : options.max_micros;
printf("Running benchmark for %lld us\n", max_us); // NOLINTNEXTLINE
printf("Running benchmark for %lld us\n", static_cast<long long>(max_us));
const int64 start_us = NowMicros(); const int64 start_us = NowMicros();
int64 iters = 0; int64 iters = 0;
while (true) { while (true) {

View File

@ -423,8 +423,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
GenNameToIndexCode(config.fetch(), opts.gen_name_to_index); GenNameToIndexCode(config.fetch(), opts.gen_name_to_index);
const string include_xla_data_proto = const string include_xla_data_proto =
opts.gen_program_shape opts.gen_program_shape
? ? R"(#include "tensorflow/compiler/xla/xla_data.pb.h")"
R"(#include "tensorflow/compiler/xla/xla_data.pb.h")"
: ""; : "";
const string include_hlo_profile_printer_data_proto = const string include_hlo_profile_printer_data_proto =

View File

@ -20,6 +20,8 @@ limitations under the License.
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "llvm-c/Target.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/flags.h" #include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h" #include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
@ -90,7 +92,7 @@ Status CompileXla(xla::CompileOnlyClient* client,
} // namespace } // namespace
Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config, Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
const MainFlags& flags, CompileResult* compile_result) { const MainFlags& flags, CompileResult* compile_result) {
// Converts the graph into an XLA computation, and compiles the // Converts the graph into an XLA computation, and compiles the
// computation. // computation.
@ -108,8 +110,8 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
if (!flags.mlir_components.empty()) { if (!flags.mlir_components.empty()) {
return errors::Unknown("Unknown mlir_components ", flags.mlir_components); return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
} }
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(ConvertGraphDefToXla(std::move(graph_def), config,
ConvertGraphDefToXla(graph_def, config, client, &computation)); client, &computation));
} }
if (!flags.out_session_module.empty()) { if (!flags.out_session_module.empty()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module, TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
@ -132,5 +134,96 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
return CompileXla(client, computation, aot_opts, compile_result); return CompileXla(client, computation, aot_opts, compile_result);
} }
static Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
if (absl::EndsWith(fname, ".pbtxt")) {
return ReadTextProto(Env::Default(), fname, proto);
} else {
return ReadBinaryProto(Env::Default(), fname, proto);
}
}
static std::once_flag targets_init;
static void InitializeTargets() {
// Initialize all LLVM targets so we can cross compile.
#if TF_LLVM_AARCH64_AVAILABLE
LLVMInitializeAArch64Target();
LLVMInitializeAArch64TargetInfo();
LLVMInitializeAArch64TargetMC();
LLVMInitializeAArch64AsmPrinter();
#endif
LLVMInitializeARMTarget();
LLVMInitializeARMTargetInfo();
LLVMInitializeARMTargetMC();
LLVMInitializeARMAsmPrinter();
LLVMInitializePowerPCTarget();
LLVMInitializePowerPCTargetInfo();
LLVMInitializePowerPCTargetMC();
LLVMInitializePowerPCAsmPrinter();
LLVMInitializeX86Target();
LLVMInitializeX86TargetInfo();
LLVMInitializeX86TargetMC();
LLVMInitializeX86AsmPrinter();
}
Status Main(const MainFlags& flags) {
std::call_once(targets_init, &InitializeTargets);
// Process config.
tf2xla::Config config;
if (flags.config.empty()) {
return errors::InvalidArgument("Must specify --config");
}
TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
TF_RETURN_IF_ERROR(ValidateConfig(config));
if (flags.dump_fetch_nodes) {
std::set<string> nodes;
for (const tf2xla::Fetch& fetch : config.fetch()) {
nodes.insert(fetch.id().node_name());
}
std::cout << absl::StrJoin(nodes, ",");
return Status::OK();
}
// Read and initialize the graph.
if (flags.graph.empty()) {
return errors::InvalidArgument("Must specify --graph");
}
GraphDef graph_def;
TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
CompileResult compile_result;
TF_RETURN_IF_ERROR(
CompileGraph(std::move(graph_def), config, flags, &compile_result));
// Write output files.
Env* env = Env::Default();
const std::vector<char>& obj = compile_result.aot->object_file_data();
TF_RETURN_IF_ERROR(
WriteStringToFile(env, flags.out_function_object,
absl::string_view(obj.data(), obj.size())));
CodegenOpts codegen_opts;
codegen_opts.gen_name_to_index = flags.gen_name_to_index;
codegen_opts.gen_program_shape = flags.gen_program_shape;
codegen_opts.target_triple = flags.target_triple;
if (flags.cpp_class.empty()) {
return errors::InvalidArgument("Must specify --cpp_class");
}
codegen_opts.gen_hlo_profile_printer_data =
xla::GetDebugOptionsFromFlags().xla_hlo_profile();
TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
&codegen_opts.namespaces));
MetadataResult metadata_result;
TF_RETURN_IF_ERROR(
GenerateMetadata(codegen_opts, compile_result, &metadata_result));
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_metadata_object,
metadata_result.object_file_data));
string header;
TF_RETURN_IF_ERROR(GenerateHeader(codegen_opts, config, compile_result,
metadata_result, &header));
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header));
return Status::OK();
}
} // namespace tfcompile } // namespace tfcompile
} // namespace tensorflow } // namespace tensorflow
View File
@ -42,9 +42,12 @@ struct CompileResult {
// that performs the graph operations. // that performs the graph operations.
// //
// The XLA compilation options are specified in the flags. // The XLA compilation options are specified in the flags.
Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config, Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
const MainFlags& flags, CompileResult* compile_result); const MainFlags& flags, CompileResult* compile_result);
// The full compilation method, for reuse in a library setting.
Status Main(const MainFlags& flags);
} // namespace tfcompile } // namespace tfcompile
} // namespace tensorflow } // namespace tensorflow
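With Main() now declared here, the full tfcompile pipeline can be driven from another binary instead of the command-line tool. A minimal sketch, assuming the usual MainFlags fields exercised by Main() (graph, config, cpp_class and the three output paths; the file names below are purely illustrative):

    #include "tensorflow/compiler/aot/compile.h"
    #include "tensorflow/compiler/aot/flags.h"
    #include "tensorflow/core/lib/core/status.h"

    void CompileAheadOfTime() {
      tensorflow::tfcompile::MainFlags flags;
      flags.graph = "/tmp/graph.pb";            // GraphDef, binary or .pbtxt
      flags.config = "/tmp/config.pbtxt";       // tf2xla::Config
      flags.cpp_class = "mynamespace::MyComputation";
      flags.out_function_object = "/tmp/out_model.o";
      flags.out_metadata_object = "/tmp/out_helper.o";
      flags.out_header = "/tmp/out.h";
      TF_CHECK_OK(tensorflow::tfcompile::Main(flags));
    }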
View File
@ -25,6 +25,7 @@ namespace tensorflow {
namespace tfcompile { namespace tfcompile {
// Flags for the tfcompile binary. See *.cc file for descriptions. // Flags for the tfcompile binary. See *.cc file for descriptions.
struct MainFlags { struct MainFlags {
string graph; string graph;
string config; string config;
View File
@ -25,6 +25,7 @@ test_suite(
":test_graph_tfmatmulandadd_test", ":test_graph_tfmatmulandadd_test",
":test_graph_tfsplits_test", ":test_graph_tfsplits_test",
":test_graph_tftop_k_test", ":test_graph_tftop_k_test",
":test_graph_tfvariable_readonly_test",
":test_graph_tfvariable_sequential_updates_test", ":test_graph_tfvariable_sequential_updates_test",
":test_graph_tfvariable_test", ":test_graph_tfvariable_test",
":tfcompile_test", ":tfcompile_test",
@ -73,6 +74,7 @@ genrule(
"test_graph_tfsplits.pb", "test_graph_tfsplits.pb",
"test_graph_tftop_k.pb", "test_graph_tftop_k.pb",
"test_graph_tfvariable.pb", "test_graph_tfvariable.pb",
"test_graph_tfvariable_readonly.pb",
"test_graph_tfvariable_sequential_updates.pb", "test_graph_tfvariable_sequential_updates.pb",
], ],
# Set CUDA_VISIBLE_DEVICES='' to prevent the code we launch from using any # Set CUDA_VISIBLE_DEVICES='' to prevent the code we launch from using any
@ -238,6 +240,17 @@ tf_library(
], ],
) )
tf_library(
name = "test_graph_tfvariable_readonly",
testonly = 1,
config = "test_graph_tfvariable_readonly.config.pbtxt",
cpp_class = "VariableReadonlyComp",
graph = "test_graph_tfvariable_readonly.pb",
tags = [
"manual",
],
)
tf_library( tf_library(
name = "test_graph_tfvariable_sequential_updates", name = "test_graph_tfvariable_sequential_updates",
testonly = 1, testonly = 1,
@ -269,6 +282,7 @@ tf_cc_test(
":test_graph_tfsplits", ":test_graph_tfsplits",
":test_graph_tftop_k", ":test_graph_tftop_k",
":test_graph_tfvariable", ":test_graph_tfvariable",
":test_graph_tfvariable_readonly",
":test_graph_tfvariable_sequential_updates", ":test_graph_tfvariable_sequential_updates",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test",
@ -323,6 +337,42 @@ tf_library(
], ],
) )
tf_library(
name = "test_graph_tfcond_mlir_bridge",
testonly = 1,
config = "test_graph_tfcond.config.pbtxt",
cpp_class = "CondComp",
graph = "test_graph_tfcond.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tfassert_eq_mlir_bridge",
testonly = 1,
config = "test_graph_tfassert_eq.config.pbtxt",
cpp_class = "AssertComp",
graph = "test_graph_tfassert_eq.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tfgather_mlir_bridge",
testonly = 1,
config = "test_graph_tfgather.config.pbtxt",
cpp_class = "GatherComp",
graph = "test_graph_tfgather.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_library( tf_library(
name = "test_graph_tfmatmul_mlir_bridge", name = "test_graph_tfmatmul_mlir_bridge",
testonly = 1, testonly = 1,
@ -361,6 +411,42 @@ tf_library(
], ],
) )
tf_library(
name = "test_graph_tfsplits_mlir_bridge",
testonly = 1,
config = "test_graph_tfsplits.config.pbtxt",
cpp_class = "SplitsComp",
graph = "test_graph_tfsplits.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tftop_k_mlir_bridge",
testonly = 1,
config = "test_graph_tftop_k.config.pbtxt",
cpp_class = "TopKComp",
graph = "test_graph_tftop_k.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tfvariable_readonly_mlir_bridge",
testonly = 1,
config = "test_graph_tfvariable_readonly.config.pbtxt",
cpp_class = "VariableReadonlyComp",
graph = "test_graph_tfvariable_readonly.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_cc_test( tf_cc_test(
name = "tfcompile_test_mlir_bridge", name = "tfcompile_test_mlir_bridge",
srcs = ["tfcompile_test.cc"], srcs = ["tfcompile_test.cc"],
@ -372,9 +458,15 @@ tf_cc_test(
":test_graph_tfadd_mlir_bridge", ":test_graph_tfadd_mlir_bridge",
":test_graph_tfadd_with_ckpt_mlir_bridge", ":test_graph_tfadd_with_ckpt_mlir_bridge",
":test_graph_tfadd_with_ckpt_saver_mlir_bridge", ":test_graph_tfadd_with_ckpt_saver_mlir_bridge",
":test_graph_tfassert_eq_mlir_bridge",
":test_graph_tfcond_mlir_bridge",
":test_graph_tfgather_mlir_bridge",
":test_graph_tfmatmul_mlir_bridge", ":test_graph_tfmatmul_mlir_bridge",
":test_graph_tfmatmulandadd_mlir_bridge", ":test_graph_tfmatmulandadd_mlir_bridge",
":test_graph_tfmatmulandadd_with_profiling_mlir_bridge", ":test_graph_tfmatmulandadd_with_profiling_mlir_bridge",
":test_graph_tfsplits_mlir_bridge",
":test_graph_tftop_k_mlir_bridge",
":test_graph_tfvariable_readonly_mlir_bridge",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla:xla_data_proto_cc",
View File
@ -34,6 +34,7 @@ from tensorflow.python.framework import function
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
@ -153,6 +154,14 @@ def tftop_k(_):
array_ops.identity(output[1], name='indices') array_ops.identity(output[1], name='indices')
def tfvariable_readonly(_):
x = variables.Variable(1000.0, name='x')
old_x = x.value()
with ops.control_dependencies([old_x]):
new_value = math_ops.add(old_x, 42.0)
array_ops.identity(new_value, name='result')
def tfvariable(_): def tfvariable(_):
x = variables.Variable(1000.0, name='x') x = variables.Variable(1000.0, name='x')
old_x = x.value() old_x = x.value()
@ -184,6 +193,7 @@ def write_graph(build_graph, out_dir):
def main(_): def main(_):
control_flow_util.enable_control_flow_v2()
write_graph(tfadd, FLAGS.out_dir) write_graph(tfadd, FLAGS.out_dir)
write_graph(tfadd_with_ckpt, FLAGS.out_dir) write_graph(tfadd_with_ckpt, FLAGS.out_dir)
write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir) write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir)
@ -196,6 +206,7 @@ def main(_):
write_graph(tfsplits, FLAGS.out_dir) write_graph(tfsplits, FLAGS.out_dir)
write_graph(tftop_k, FLAGS.out_dir) write_graph(tftop_k, FLAGS.out_dir)
write_graph(tfvariable, FLAGS.out_dir) write_graph(tfvariable, FLAGS.out_dir)
write_graph(tfvariable_readonly, FLAGS.out_dir)
write_graph(tfvariable_sequential_updates, FLAGS.out_dir) write_graph(tfvariable_sequential_updates, FLAGS.out_dir)
View File
@ -0,0 +1,12 @@
# Text form of tensorflow.tf2xla.Config proto.
fetch {
id { node_name: "result" }
}
variable {
node_name: "x"
shape {
}
type: DT_FLOAT
readonly: true
}
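For callers that assemble the configuration in code rather than in a .pbtxt file, the same structure can be written through the generated proto API; a sketch using only the field names visible in the text proto above:

    tensorflow::tf2xla::Config config;
    config.add_fetch()->mutable_id()->set_node_name("result");
    auto* var = config.add_variable();
    var->set_node_name("x");
    var->mutable_shape();                  // scalar: leave the shape empty
    var->set_type(tensorflow::DT_FLOAT);
    var->set_readonly(true);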
View File
@ -30,9 +30,15 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_mlir_bridge.h" #include "tensorflow/compiler/aot/tests/test_graph_tfadd_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_mlir_bridge.h" #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver_mlir_bridge.h" #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfcond_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfgather_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul_mlir_bridge.h" #include "tensorflow/compiler/aot/tests/test_graph_tfmatmul_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_mlir_bridge.h" #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling_mlir_bridge.h" #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tftop_k_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly_mlir_bridge.h"
#else #else
#include "tensorflow/compiler/aot/tests/test_graph_tfadd.h" #include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h" #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
@ -47,6 +53,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h" #include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h"
#include "tensorflow/compiler/aot/tests/test_graph_tftop_k.h" #include "tensorflow/compiler/aot/tests/test_graph_tftop_k.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable.h" #include "tensorflow/compiler/aot/tests/test_graph_tfvariable.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates.h" #include "tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates.h"
#endif #endif
@ -167,8 +174,6 @@ TEST(TFCompileTest, AddWithCkptSaver) {
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]); EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
} }
// TODO(bixia): the following tests failed with MLIR bridge.
#if !defined(ENABLE_MLIR_BRIDGE_TEST)
TEST(TFCompileTest, Cond) { TEST(TFCompileTest, Cond) {
CondComp cond; CondComp cond;
EXPECT_EQ(cond.arg0_data(), cond.arg_data(0)); EXPECT_EQ(cond.arg0_data(), cond.arg_data(0));
@ -233,7 +238,6 @@ TEST(TFCompileTest, Gather) {
EXPECT_EQ(gather_const.result0_data(), gather.results()[0]); EXPECT_EQ(gather_const.result0_data(), gather.results()[0]);
} }
} }
#endif
TEST(TFCompileTest, MatMul2) { TEST(TFCompileTest, MatMul2) {
Eigen::ThreadPool tp(2); Eigen::ThreadPool tp(2);
@ -439,6 +443,7 @@ TEST(TFCompileTest, Function) {
EXPECT_EQ(add_fn.result0_data()[0], 3); EXPECT_EQ(add_fn.result0_data()[0], 3);
EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]); EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]);
} }
#endif
TEST(TFCompileTest, Splits) { TEST(TFCompileTest, Splits) {
Eigen::ThreadPool tp(1); Eigen::ThreadPool tp(1);
@ -492,6 +497,22 @@ TEST(TFCompileTest, TopK) {
EXPECT_EQ(expected_indices[1], fn.result1(1)); EXPECT_EQ(expected_indices[1], fn.result1(1));
} }
TEST(TFCompileTest, VariableReadonly) {
Eigen::ThreadPool tp(1);
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
VariableReadonlyComp fn;
float x = 23;
fn.set_var_x_data(&x);
fn.set_thread_pool(&device);
fn.Run();
EXPECT_EQ(fn.result0(), 65);
EXPECT_EQ(fn.var_x(), 23);
}
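The expected values follow from the readonly graph defined earlier in this diff: the fetched result adds a constant 42.0 to the variable, so

    23.0 (initial x) + 42.0 = 65.0 (result0)

while x itself stays at 23 because the config marks the variable readonly.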
// TODO(bixia): the following tests failed with MLIR bridge.
#if !defined(ENABLE_MLIR_BRIDGE_TEST)
TEST(TFCompileTest, Variable) { TEST(TFCompileTest, Variable) {
Eigen::ThreadPool tp(1); Eigen::ThreadPool tp(1);
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
@ -564,6 +585,7 @@ TEST(TFCompileTest, VariableSequentialUpdatesNoAlloc) {
fn.Run(); fn.Run();
EXPECT_NEAR(x, 0.594322f, 1e-6); EXPECT_NEAR(x, 0.594322f, 1e-6);
} }
#endif
TEST(TFCompileTest, AssertEqAndReturnDiff) { TEST(TFCompileTest, AssertEqAndReturnDiff) {
// Assert is converted into a no-op in XLA, so there is no failure even if the // Assert is converted into a no-op in XLA, so there is no failure even if the
@ -665,6 +687,11 @@ TEST(TFCompileTest, HloProfiling) {
/*clock_rate_ghz=*/1.0); /*clock_rate_ghz=*/1.0);
VLOG(1) << "Original HLO profile string:\n" << hlo_profile_as_string; VLOG(1) << "Original HLO profile string:\n" << hlo_profile_as_string;
// Replace Arg_n with argn when the MLIR bridge is used.
#if defined(ENABLE_MLIR_BRIDGE_TEST)
RE2::GlobalReplace(&hlo_profile_as_string, "(Arg_)([0-9].)", "arg\\2");
#endif
// Strip away identifier details from the profile string to avoid this test // Strip away identifier details from the profile string to avoid this test
// being a change detector for xla internals. Identifiers such as '%dot.0.7' // being a change detector for xla internals. Identifiers such as '%dot.0.7'
// just become '%dot'. // just become '%dot'.
@ -690,7 +717,6 @@ TEST(TFCompileTest, HloProfiling) {
IsSupersetOf({header, total_cycles_profile_line, dot_profile_line, IsSupersetOf({header, total_cycles_profile_line, dot_profile_line,
add_profile_line, tuple_profile_line})); add_profile_line, tuple_profile_line}));
} }
#endif
} // namespace } // namespace
} // namespace tfcompile } // namespace tfcompile
View File
@ -21,7 +21,6 @@ limitations under the License.
#include "absl/strings/match.h" #include "absl/strings/match.h"
#include "absl/strings/str_join.h" #include "absl/strings/str_join.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "llvm-c/Target.h"
#include "tensorflow/compiler/aot/codegen.h" #include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/compile.h" #include "tensorflow/compiler/aot/compile.h"
#include "tensorflow/compiler/aot/flags.h" #include "tensorflow/compiler/aot/flags.h"
@ -56,88 +55,6 @@ const char kUsageHeader[] =
"--cpp_class=\"mynamespace::MyComputation\"\n" "--cpp_class=\"mynamespace::MyComputation\"\n"
"\n"; "\n";
Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
if (absl::EndsWith(fname, ".pbtxt")) {
return ReadTextProto(Env::Default(), fname, proto);
} else {
return ReadBinaryProto(Env::Default(), fname, proto);
}
}
Status Main(const MainFlags& flags) {
// Initialize all LLVM targets so we can cross compile.
LLVMInitializeAArch64Target();
LLVMInitializeAArch64TargetInfo();
LLVMInitializeAArch64TargetMC();
LLVMInitializeAArch64AsmPrinter();
LLVMInitializeARMTarget();
LLVMInitializeARMTargetInfo();
LLVMInitializeARMTargetMC();
LLVMInitializeARMAsmPrinter();
LLVMInitializePowerPCTarget();
LLVMInitializePowerPCTargetInfo();
LLVMInitializePowerPCTargetMC();
LLVMInitializePowerPCAsmPrinter();
LLVMInitializeX86Target();
LLVMInitializeX86TargetInfo();
LLVMInitializeX86TargetMC();
LLVMInitializeX86AsmPrinter();
// Process config.
tf2xla::Config config;
if (flags.config.empty()) {
return errors::InvalidArgument("Must specify --config");
}
TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
TF_RETURN_IF_ERROR(ValidateConfig(config));
if (flags.dump_fetch_nodes) {
std::set<string> nodes;
for (const tf2xla::Fetch& fetch : config.fetch()) {
nodes.insert(fetch.id().node_name());
}
std::cout << absl::StrJoin(nodes, ",");
return Status::OK();
}
// Read and initialize the graph.
if (flags.graph.empty()) {
return errors::InvalidArgument("Must specify --graph");
}
GraphDef graph_def;
TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
CompileResult compile_result;
TF_RETURN_IF_ERROR(CompileGraph(graph_def, config, flags, &compile_result));
// Write output files.
Env* env = Env::Default();
const std::vector<char>& obj = compile_result.aot->object_file_data();
TF_RETURN_IF_ERROR(
WriteStringToFile(env, flags.out_function_object,
absl::string_view(obj.data(), obj.size())));
CodegenOpts codegen_opts;
codegen_opts.gen_name_to_index = flags.gen_name_to_index;
codegen_opts.gen_program_shape = flags.gen_program_shape;
codegen_opts.target_triple = flags.target_triple;
if (flags.cpp_class.empty()) {
return errors::InvalidArgument("Must specify --cpp_class");
}
codegen_opts.gen_hlo_profile_printer_data =
xla::GetDebugOptionsFromFlags().xla_hlo_profile();
TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
&codegen_opts.namespaces));
MetadataResult metadata_result;
TF_RETURN_IF_ERROR(
GenerateMetadata(codegen_opts, compile_result, &metadata_result));
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_metadata_object,
metadata_result.object_file_data));
string header;
TF_RETURN_IF_ERROR(GenerateHeader(codegen_opts, config, compile_result,
metadata_result, &header));
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header));
return Status::OK();
}
} // end namespace tfcompile } // end namespace tfcompile
} // end namespace tensorflow } // end namespace tensorflow
View File
@ -4,12 +4,7 @@ load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library", "tf_jit_compilati
load("//tensorflow/core/platform:build_config.bzl", "tf_additional_all_protos", "tf_proto_library") load("//tensorflow/core/platform:build_config.bzl", "tf_additional_all_protos", "tf_proto_library")
package( package(
default_visibility = [ default_visibility = [":internal"],
":internal",
# BEGIN-GOOGLE-INTERNAL
"//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__",
# END-GOOGLE-INTERNAL
],
licenses = ["notice"], # Apache 2.0 licenses = ["notice"], # Apache 2.0
) )
@ -82,19 +77,6 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_library(
name = "xla_mlir_gpu_jit",
visibility = ["//visibility:public"],
deps = if_cuda_or_rocm([
":jit_compilation_passes",
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/xla/service:mlir_gpu_plugin",
]),
alwayslink = 1,
)
cc_library( cc_library(
name = "xla_cpu_device", name = "xla_cpu_device",
srcs = ["xla_cpu_device.cc"], srcs = ["xla_cpu_device.cc"],
@ -120,6 +102,7 @@ cc_library(
srcs = ["xla_gpu_device.cc"], srcs = ["xla_gpu_device.cc"],
visibility = [":friends"], visibility = [":friends"],
deps = [ deps = [
":flags",
":jit_compilation_passes", ":jit_compilation_passes",
":xla_device", ":xla_device",
":xla_kernel_creator", # buildcleaner: keep ":xla_kernel_creator", # buildcleaner: keep
@ -128,6 +111,7 @@ cc_library(
"//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep "//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep
"//tensorflow/core:core_cpu_internal", "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:gpu_init",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
View File
@ -1584,7 +1584,6 @@ DeadnessAnalysis::~DeadnessAnalysis() {}
absl::flat_hash_map<TensorId, string, TensorId::Hasher> absl::flat_hash_map<TensorId, string, TensorId::Hasher>
DeadnessAnalysisImpl::PredicateMapAsString() const { DeadnessAnalysisImpl::PredicateMapAsString() const {
absl::flat_hash_map<TensorId, string, TensorId::Hasher> result; absl::flat_hash_map<TensorId, string, TensorId::Hasher> result;
std::vector<TensorId> tensor_ids;
for (const auto& kv_pair : predicate_map_) { for (const auto& kv_pair : predicate_map_) {
CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second); CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second);
} }
View File
@ -374,39 +374,6 @@ xla::StatusOr<NodeDef> BuildXlaHostComputeNodeDef(
return new_def; return new_def;
} }
TF_ATTRIBUTE_NOINLINE Status
ValidateOutsideCompilationCallNode(Node* call_node) {
// DT_INT64 as input/output for outside compilation is not supported yet:
// b/120809951.
for (const Edge* e : call_node->in_edges()) {
if (e->IsControlEdge()) {
continue;
}
DataType dtype = e->src()->output_type(e->src_output());
if (dtype == DT_INT64) {
return errors::Unimplemented(
"int64 input for outside compilation is not supported yet: "
"b/120809951. Please cast output of node ",
e->src()->DebugString(),
" to int32 before feeding it into outside compilation.");
}
}
for (const Edge* e : call_node->out_edges()) {
if (e->IsControlEdge()) {
continue;
}
DataType dtype = e->dst()->input_type(e->dst_input());
if (dtype == DT_INT64) {
return errors::Unimplemented(
"int64 output for outside compilation is not supported yet: "
"b/120809951. Please cast input of node ",
e->dst()->DebugString(),
" to int32 before returning it from outside compilation.");
}
}
return Status::OK();
}
// Replace outside compilation function call node with XlaHostCompute node. // Replace outside compilation function call node with XlaHostCompute node.
TF_ATTRIBUTE_NOINLINE xla::StatusOr<Node*> ReplaceOutsideCompilationCallNode( TF_ATTRIBUTE_NOINLINE xla::StatusOr<Node*> ReplaceOutsideCompilationCallNode(
Graph* g, Node* call_node, const std::map<string, int>& host_compute_core, Graph* g, Node* call_node, const std::map<string, int>& host_compute_core,
@ -2384,7 +2351,6 @@ Status ExtractOutsideCompilationForFunction(
} }
std::map<string, Node*> host_compute_nodes; std::map<string, Node*> host_compute_nodes;
for (Node* n : outside_compilation_nodes) { for (Node* n : outside_compilation_nodes) {
TF_RETURN_IF_ERROR(ValidateOutsideCompilationCallNode(n));
auto host_compute_node_or = ReplaceOutsideCompilationCallNode( auto host_compute_node_or = ReplaceOutsideCompilationCallNode(
graph_out.get(), n, host_compute_core, *cluster_deps); graph_out.get(), n, host_compute_core, *cluster_deps);
TF_RETURN_IF_ERROR(host_compute_node_or.status()); TF_RETURN_IF_ERROR(host_compute_node_or.status());
View File
@ -155,6 +155,7 @@ void AllocateAndParseFlags() {
device_flags = new XlaDeviceFlags; device_flags = new XlaDeviceFlags;
device_flags->tf_xla_compile_on_demand = false; device_flags->tf_xla_compile_on_demand = false;
device_flags->tf_xla_enable_xla_devices = true;
ops_flags = new XlaOpsCommonFlags; ops_flags = new XlaOpsCommonFlags;
ops_flags->tf_xla_always_defer_compilation = false; ops_flags->tf_xla_always_defer_compilation = false;
@ -187,6 +188,12 @@ void AllocateAndParseFlags() {
"Switch a device into 'on-demand' mode, where instead of " "Switch a device into 'on-demand' mode, where instead of "
"autoclustering ops are compiled one by one just-in-time."), "autoclustering ops are compiled one by one just-in-time."),
Flag("tf_xla_enable_xla_devices",
&device_flags->tf_xla_enable_xla_devices,
"Generate XLA_* devices, where placing a computation on such a "
"device"
"forces compilation by XLA. Deprecated."),
Flag("tf_xla_always_defer_compilation", Flag("tf_xla_always_defer_compilation",
&ops_flags->tf_xla_always_defer_compilation, ""), &ops_flags->tf_xla_always_defer_compilation, ""),
View File
@ -87,6 +87,9 @@ struct XlaDeviceFlags {
// Enabling this mode by a legacy flag is a temporary mechanism. When this // Enabling this mode by a legacy flag is a temporary mechanism. When this
// feature is battle-tested, we will switch this to be a session option. // feature is battle-tested, we will switch this to be a session option.
bool tf_xla_compile_on_demand; bool tf_xla_compile_on_demand;
// Enables "XLA" devices if this flag is set.
bool tf_xla_enable_xla_devices;
}; };
// Flags common to the _Xla* ops and their kernels. // Flags common to the _Xla* ops and their kernels.

View File
"Lgamma", "Digamma", "Lgamma", "Digamma",
// Binary // Binary
"Add", "AddV2", "Sub", "Mul", "Div", "Atan2", "Complex", "DivNoNan", "Add", "AddV2", "Sub", "Mul", "Div", "Atan2", "Complex", "DivNoNan",
"MulNoNan", "FloorDiv", "Xlogy", "Xdivy", "FloorMod", "BitwiseAnd", "MulNoNan", "FloorDiv", "Xlogy", "Xlog1py", "Xdivy", "FloorMod",
"BitwiseOr", "BitwiseXor", "LeftShift", "RightShift", "LogicalAnd", "BitwiseAnd", "BitwiseOr", "BitwiseXor", "LeftShift", "RightShift",
"LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv", "LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
"ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "TruncateDiv", "ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "TruncateDiv",
"TruncateMod", "Equal", "NotEqual", "Greater", "GreaterEqual", "TruncateMod", "Equal", "NotEqual", "Greater", "GreaterEqual",
"Less", "LessEqual", "SigmoidGrad", "SoftplusGrad", "SoftsignGrad", "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad", "SoftsignGrad",
@ -1872,6 +1872,8 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"Einsum", "Einsum",
"EmptyTensorList", "EmptyTensorList",
"ExtractImagePatches", "ExtractImagePatches",
"Igamma",
"Igammac",
"FFT", "FFT",
"FFT2D", "FFT2D",
"FFT3D", "FFT3D",
View File
@ -36,8 +36,13 @@ class XlaCpuDeviceFactory : public DeviceFactory {
}; };
Status XlaCpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) { Status XlaCpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
devices->push_back(absl::StrCat("/physical_device:", DEVICE_XLA_CPU, ":0")); XlaDeviceFlags* flags = GetXlaDeviceFlags();
if (!flags->tf_xla_enable_xla_devices) {
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
return Status::OK();
}
devices->push_back(absl::StrCat("/physical_device:", DEVICE_XLA_CPU, ":0"));
return Status::OK(); return Status::OK();
} }
@ -45,6 +50,10 @@ Status XlaCpuDeviceFactory::CreateDevices(
const SessionOptions& session_options, const string& name_prefix, const SessionOptions& session_options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) { std::vector<std::unique_ptr<Device>>* devices) {
XlaDeviceFlags* flags = GetXlaDeviceFlags(); XlaDeviceFlags* flags = GetXlaDeviceFlags();
if (!flags->tf_xla_enable_xla_devices) {
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
return Status::OK();
}
bool compile_on_demand = flags->tf_xla_compile_on_demand; bool compile_on_demand = flags->tf_xla_compile_on_demand;
XlaOpRegistry::DeviceRegistration registration; XlaOpRegistry::DeviceRegistration registration;
View File
@ -140,7 +140,6 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
// The device tensor should always be fresh. // The device tensor should always be fresh.
TF_RET_CHECK(!xla_tensor->has_shaped_buffer()); TF_RET_CHECK(!xla_tensor->has_shaped_buffer());
xla_tensor->set_host_tensor(*cpu_tensor);
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_, xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
stream_->parent()->device_ordinal())); stream_->parent()->device_ordinal()));
View File
@ -14,17 +14,20 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
// Registers the XLA_GPU device, which is an XlaDevice instantiation that runs // Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
// operators using XLA via the XLA "CUDA" (GPU) backend. // operators using XLA via the XLA "CUDA" or "ROCM" (GPU) backend.
#include <set> #include <set>
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/strings/numbers.h" #include "absl/strings/numbers.h"
#include "absl/strings/str_split.h" #include "absl/strings/str_split.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h" #include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h" #include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
namespace tensorflow { namespace tensorflow {
@ -61,7 +64,14 @@ class XlaGpuDeviceFactory : public DeviceFactory {
}; };
Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) { Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
auto platform = se::MultiPlatformManager::PlatformWithName("CUDA"); XlaDeviceFlags* flags = GetXlaDeviceFlags();
if (!flags->tf_xla_enable_xla_devices) {
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
return Status::OK();
}
auto platform =
se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
if (!platform.ok()) { if (!platform.ok()) {
// Treat failures as non-fatal; there might not be a GPU in the machine. // Treat failures as non-fatal; there might not be a GPU in the machine.
VLOG(1) << "Failed to create XLA_GPU device: " << platform.status(); VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
@ -84,6 +94,12 @@ Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
Status XlaGpuDeviceFactory::CreateDevices( Status XlaGpuDeviceFactory::CreateDevices(
const SessionOptions& session_options, const string& name_prefix, const SessionOptions& session_options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) { std::vector<std::unique_ptr<Device>>* devices) {
XlaDeviceFlags* flags = GetXlaDeviceFlags();
if (!flags->tf_xla_enable_xla_devices) {
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
return Status::OK();
}
XlaOpRegistry::DeviceRegistration registration; XlaOpRegistry::DeviceRegistration registration;
registration.compilation_device_name = DEVICE_GPU_XLA_JIT; registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
registration.autoclustering_policy = registration.autoclustering_policy =
@ -103,7 +119,8 @@ Status XlaGpuDeviceFactory::CreateDevices(
RegisterXlaDeviceKernels(DEVICE_XLA_GPU, DEVICE_GPU_XLA_JIT); RegisterXlaDeviceKernels(DEVICE_XLA_GPU, DEVICE_GPU_XLA_JIT);
(void)registrations; (void)registrations;
auto platform = se::MultiPlatformManager::PlatformWithName("CUDA"); auto platform =
se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
if (!platform.ok()) { if (!platform.ok()) {
// Treat failures as non-fatal; there might not be a GPU in the machine. // Treat failures as non-fatal; there might not be a GPU in the machine.
VLOG(1) << "Failed to create XLA_GPU device: " << platform.status(); VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
View File
@ -222,8 +222,9 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
OpKernelConstruction construction( OpKernelConstruction construction(
DeviceType(dev->device_type()), dev, DeviceType(dev->device_type()), dev,
dev->GetAllocator(AllocatorAttributes()), &node_def, dev->GetAllocator(AllocatorAttributes()), &node_def,
&fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types, &fbody->fdef.signature(), flr, dev->resource_manager(), fbody->arg_types,
fbody->ret_types, output_memory_types, flr->graph_def_version(), &s); input_memory_types, fbody->ret_types, output_memory_types,
flr->graph_def_version(), &s);
*kernel = absl::make_unique<XlaLocalLaunchBase>( *kernel = absl::make_unique<XlaLocalLaunchBase>(
&construction, constant_arg_indices, resource_arg_indices, function); &construction, constant_arg_indices, resource_arg_indices, function);
View File
@ -44,8 +44,11 @@ cc_library(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core/platform:logging", "//tensorflow/core/platform:logging",
"@llvm-project//llvm:support", "@llvm-project//llvm:support",
"@llvm-project//mlir:AffineDialectRegistration",
"@llvm-project//mlir:LoopDialectRegistration",
"@llvm-project//mlir:MlirOptLib", "@llvm-project//mlir:MlirOptLib",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOpsDialectRegistration",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
"@llvm-project//mlir/test:TestTransforms", "@llvm-project//mlir/test:TestTransforms",
], ],
@ -80,9 +83,10 @@ cc_library(
"//tensorflow/compiler/mlir/xla:xla_legalize_tf", "//tensorflow/compiler/mlir/xla:xla_legalize_tf",
"//tensorflow/compiler/mlir/xla:xla_legalize_to_standard", "//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
"//tensorflow/compiler/mlir/xla:xla_lower", "//tensorflow/compiler/mlir/xla:xla_lower",
"@llvm-project//mlir:AffineDialectRegistration", "//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts",
"//tensorflow/compiler/mlir/xla:xla_test_passes",
"@llvm-project//mlir:AffineOps",
"@llvm-project//mlir:QuantOps", "@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:QuantOpsDialectRegistration",
], ],
) )
View File
@ -47,6 +47,14 @@ gentbl(
"-gen-op-doc", "-gen-op-doc",
"g3doc/tfl_ops.md", "g3doc/tfl_ops.md",
), ),
(
"-gen-op-interface-decls",
"ir/tfl_ops_interface.h.inc",
),
(
"-gen-op-interface-defs",
"ir/tfl_ops_interface.cc.inc",
),
], ],
tblgen = "@llvm-project//mlir:mlir-tblgen", tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "ir/tfl_ops.td", td_file = "ir/tfl_ops.td",
@ -177,11 +185,12 @@ cc_library(
"ir/tfl_ops.cc", "ir/tfl_ops.cc",
"ir/tfl_ops.cc.inc", "ir/tfl_ops.cc.inc",
"ir/tfl_ops.h.inc", "ir/tfl_ops.h.inc",
"ir/tfl_ops_interface.cc.inc",
"ir/tfl_ops_interface.h.inc",
"utils/attribute_utils.cc", "utils/attribute_utils.cc",
], ],
hdrs = [ hdrs = [
"ir/tfl_ops.h", "ir/tfl_ops.h",
"ir/tfl_traits.h",
"transforms/passes.h", "transforms/passes.h",
"utils/attribute_utils.h", "utils/attribute_utils.h",
"//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h", "//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h",
@ -330,6 +339,7 @@ cc_library(
cc_library( cc_library(
name = "tensorflow_lite_quantize", name = "tensorflow_lite_quantize",
srcs = [ srcs = [
"transforms/default_quant_params.cc",
"transforms/generated_post_quantize.inc", "transforms/generated_post_quantize.inc",
"transforms/generated_quantize.inc", "transforms/generated_quantize.inc",
"transforms/load_quantization_recipe.cc", "transforms/load_quantization_recipe.cc",
@ -506,6 +516,7 @@ cc_library(
"//tensorflow/lite:schema_fbs_version", "//tensorflow/lite:schema_fbs_version",
"//tensorflow/lite:string_util", "//tensorflow/lite:string_util",
"//tensorflow/lite/delegates/flex:whitelisted_flex_ops_lib", "//tensorflow/lite/delegates/flex:whitelisted_flex_ops_lib",
"//tensorflow/lite/kernels/internal:kernel_utils",
"//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/tools/versioning:op_version", "//tensorflow/lite/tools/versioning:op_version",
"@com_google_absl//absl/base", "@com_google_absl//absl/base",
@ -671,12 +682,16 @@ cc_library(
], ],
) )
exports_files( cc_library(
["transforms/passes.h"], name = "empty_passes",
hdrs = ["transforms/passes.h"],
visibility = [ visibility = [
"//configs/devtools/hawkeye/tflite:__subpackages__", "//configs/devtools/hawkeye/tflite:__subpackages__",
"//learning/brain/models/app_benchmarks:__subpackages__", "//learning/brain/models/app_benchmarks:__subpackages__",
"//tensorflow/compiler/mlir/lite:friends", "//tensorflow/compiler/mlir/lite:friends",
"//tensorflow/lite/experimental/mlir:__subpackages__", "//tensorflow/lite/experimental/mlir:__subpackages__",
], ],
deps = [
"@llvm-project//llvm:support",
],
) )
View File
@ -31,10 +31,11 @@ struct PassConfig {
: emit_builtin_tflite_ops(true), : emit_builtin_tflite_ops(true),
lower_tensor_list_ops(false), lower_tensor_list_ops(false),
trim_functions_whitelist({}), trim_functions_whitelist({}),
quant_specs(specs), quant_specs(std::move(specs)),
skip_control_dialect(false), skip_control_dialect(false),
form_clusters(false), form_clusters(false),
inline_functions(false) {} inline_functions(false),
unfold_batch_matmul(true) {}
// If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be // If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
// added, which produces TF Lite ops. // added, which produces TF Lite ops.
@ -57,6 +58,9 @@ struct PassConfig {
// Inline function calls within the main function in the MLIR module, prior // Inline function calls within the main function in the MLIR module, prior
// to legalization to TFLite. // to legalization to TFLite.
bool inline_functions; bool inline_functions;
// If `unfold_batch_matmul` is true, tf.BatchMatMul ops are unfolded into a
// set of tfl.fully_connected ops.
bool unfold_batch_matmul;
}; };
} // namespace TFL } // namespace TFL
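A short usage sketch of the new knob, assuming the PassConfig constructor shown above takes a QuantizationSpecs (from quantization_config.h) by value:

    mlir::TFL::QuantizationSpecs quant_specs;       // defaults: no quantization
    mlir::TFL::PassConfig pass_config(quant_specs);
    pass_config.lower_tensor_list_ops = true;
    pass_config.unfold_batch_matmul = false;        // keep tf.BatchMatMul ops intact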
View File
@ -389,7 +389,6 @@ StatusOr<mlir::ElementsAttr> ConvertIntBuffer(
mlir::RankedTensorType shaped_type, mlir::Type elem_type, mlir::RankedTensorType shaped_type, mlir::Type elem_type,
const std::vector<uint8_t>& buffer) { const std::vector<uint8_t>& buffer) {
unsigned bit_width; unsigned bit_width;
mlir::RankedTensorType buffer_type;
if (auto itype = elem_type.dyn_cast<mlir::IntegerType>()) { if (auto itype = elem_type.dyn_cast<mlir::IntegerType>()) {
bit_width = itype.getWidth(); bit_width = itype.getWidth();
} else if (auto qtype = elem_type.dyn_cast<QuantizedType>()) { } else if (auto qtype = elem_type.dyn_cast<QuantizedType>()) {
@ -920,15 +919,13 @@ StatusOr<FuncOp> ConvertSubgraph(
// represents TFLite, this entry point must be called "main" // represents TFLite, this entry point must be called "main"
// TODO(b/131175224,b/132239787) Support multiple entry points // TODO(b/131175224,b/132239787) Support multiple entry points
std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) { std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) {
if (subgraph.name.empty()) { if (index == 0) {
if (index == 0) { return "main";
return "main";
} else {
return llvm::formatv("fn_{0}", index).str();
}
} else {
return subgraph.name;
} }
if (subgraph.name.empty()) {
return llvm::formatv("fn_{0}", index).str();
}
return subgraph.name;
} }
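The reordered checks change one edge case: the entry subgraph is now always called "main", even when the FlatBuffer supplies a name for it. Illustrative (hypothetical) inputs:

    // SubgraphName(0, subgraph named "encoder")  -> "main"   (previously "encoder")
    // SubgraphName(2, unnamed subgraph)          -> "fn_2"
    // SubgraphName(3, subgraph named "decoder")  -> "decoder"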
} // namespace } // namespace
View File
@ -259,9 +259,9 @@ Status mlir::CustomOptionsToAttributes(
attributes->emplace_back(builder.getNamedAttr( attributes->emplace_back(builder.getNamedAttr(
"stride_w", builder.getI32IntegerAttr(pool_params->stride_width))); "stride_w", builder.getI32IntegerAttr(pool_params->stride_width)));
attributes->emplace_back(builder.getNamedAttr( attributes->emplace_back(builder.getNamedAttr(
"filter_w", builder.getI32IntegerAttr(pool_params->filter_height))); "filter_h", builder.getI32IntegerAttr(pool_params->filter_height)));
attributes->emplace_back(builder.getNamedAttr( attributes->emplace_back(builder.getNamedAttr(
"filter_h", builder.getI32IntegerAttr(pool_params->filter_width))); "filter_w", builder.getI32IntegerAttr(pool_params->filter_width)));
return Status::OK(); return Status::OK();
} else if (op_name == "tfl.convolution_2d_transpose_bias") { } else if (op_name == "tfl.convolution_2d_transpose_bias") {
View File
@ -71,6 +71,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h" #include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/string_util.h" #include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/tools/versioning/op_version.h" #include "tensorflow/lite/tools/versioning/op_version.h"
@ -218,6 +219,13 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
auto qtype = type.cast<mlir::quant::UniformQuantizedPerAxisType>(); auto qtype = type.cast<mlir::quant::UniformQuantizedPerAxisType>();
return GetTFLiteType(qtype.getStorageType(), qtype.isSigned()); return GetTFLiteType(qtype.getStorageType(), qtype.isSigned());
} }
case mlir::TF::TensorFlowTypes::RESOURCE: {
// Treat tf.resource values as integer values in flatbuffer.
// TODO(b/146131919): We may need a detailed design for supporting resource
// types beyond hash-table resources and resource variables.
return tflite::TensorType_INT32;
}
default: default:
// TFLite export fills FLOAT32 for unknown data types. Returning an error // TFLite export fills FLOAT32 for unknown data types. Returning an error
// for now for safety and this could be revisited when required. // for now for safety and this could be revisited when required.
@ -317,6 +325,48 @@ static std::unique_ptr<::tensorflow::NodeDef> getTensorFlowNodeDef(
return std::move(status_or_node_def.ValueOrDie()); return std::move(status_or_node_def.ValueOrDie());
} }
// Converts an MLIR padding StringRef to TfLitePadding.
// Returns llvm::None if conversion fails.
static Optional<TfLitePadding> GetTflitePadding(Operation* inst,
llvm::StringRef padding) {
const tflite::Padding padding_attr =
std::move(llvm::StringSwitch<tflite::Padding>(padding)
.Case("SAME", tflite::Padding_SAME)
.Case("VALID", tflite::Padding_VALID));
if (padding_attr == tflite::Padding_SAME) {
return kTfLitePaddingSame;
}
if (padding_attr == tflite::Padding_VALID) {
return kTfLitePaddingValid;
}
return inst->emitOpError() << "Invalid padding attribute: " << padding,
llvm::None;
}
// Extracts TfLitePoolParams from a TFL custom op.
// Template parameter, TFLOp, should be a TFL custom op containing attributes
// generated from TfLitePoolParams.
// Returns llvm::None if conversion fails.
template <typename TFLOp>
static Optional<TfLitePoolParams> GetTflitePoolParams(Operation* inst,
TFLOp op) {
TfLitePoolParams pool_params;
pool_params.stride_height = op.stride_h().getSExtValue();
pool_params.stride_width = op.stride_w().getSExtValue();
pool_params.filter_height = op.filter_h().getSExtValue();
pool_params.filter_width = op.filter_w().getSExtValue();
const auto padding = GetTflitePadding(inst, op.padding());
if (padding) {
pool_params.padding = *padding;
pool_params.activation = kTfLiteActNone;
pool_params.computed.padding = TfLitePaddingValues{0, 0, 0, 0};
return pool_params;
}
return llvm::None;
}
namespace { namespace {
// Translates an MLIR module in TFLite dialect to TFLite FlatBuffer. // Translates an MLIR module in TFLite dialect to TFLite FlatBuffer.
@ -375,9 +425,31 @@ class Translator {
mlir::TF::WhileOp op, const std::vector<int32_t>& operands, mlir::TF::WhileOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results); const std::vector<int32_t>& results);
// Builds custom operators.
// Templated on a) the data type of the custom option to be stored in the
// flatbuffer, and b) the TFL custom op type.
template <typename CustomOptionType, typename TFLOp>
BufferOffset<tflite::Operator> BuildCustomOperator(
const CustomOptionType& custom_option, const std::string& opcode_name,
TFLOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
BufferOffset<tflite::Operator> BuildNumericVerifyOperator( BufferOffset<tflite::Operator> BuildNumericVerifyOperator(
mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands, mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results); const std::vector<int32_t>& results);
Optional<BufferOffset<tflite::Operator>>
BuildConvolution2DTransposeBiasOperator(
Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op,
const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
Optional<BufferOffset<tflite::Operator>> BuildMaxPoolingWithArgMax2DOperator(
Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op,
const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
Optional<BufferOffset<tflite::Operator>> BuildMaxUnpooling2DOperator(
Operation* inst, mlir::TFL::MaxUnpooling2DOp op,
const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
Optional<CustomOptionsOffset> CreateFlexOpCustomOptions( Optional<CustomOptionsOffset> CreateFlexOpCustomOptions(
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc); const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);
@ -615,19 +687,72 @@ BufferOffset<tflite::Operator> Translator::BuildWhileOperator(
builtin_options); builtin_options);
} }
template <typename CustomOptionType, typename TFLOp>
BufferOffset<tflite::Operator> Translator::BuildCustomOperator(
const CustomOptionType& custom_option, const std::string& opcode_name,
TFLOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results) {
std::vector<uint8_t> custom_option_vector(sizeof(CustomOptionType));
memcpy(custom_option_vector.data(), &custom_option, sizeof(CustomOptionType));
auto opcode_index =
GetOpcodeIndex(opcode_name, tflite::BuiltinOperator_CUSTOM);
return tflite::CreateOperator(
builder_, opcode_index, builder_.CreateVector(operands),
builder_.CreateVector(results), tflite::BuiltinOptions_NONE,
/*builtin_options=*/0,
builder_.CreateVector<uint8_t>(custom_option_vector),
tflite::CustomOptionsFormat_FLEXBUFFERS);
}
BufferOffset<tflite::Operator> Translator::BuildNumericVerifyOperator( BufferOffset<tflite::Operator> Translator::BuildNumericVerifyOperator(
mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands, mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results) { const std::vector<int32_t>& results) {
float tolerance = op.tolerance().convertToFloat(); float tolerance = op.tolerance().convertToFloat();
std::vector<uint8_t> custom_options(sizeof(float)); return BuildCustomOperator(tolerance, "NumericVerify", op, operands, results);
memcpy(custom_options.data(), &tolerance, sizeof(float)); }
auto opcode_index =
GetOpcodeIndex("NumericVerify", tflite::BuiltinOperator_CUSTOM); Optional<BufferOffset<tflite::Operator>>
return tflite::CreateOperator( Translator::BuildConvolution2DTransposeBiasOperator(
builder_, opcode_index, builder_.CreateVector(operands), Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op,
builder_.CreateVector(results), tflite::BuiltinOptions_NONE, const std::vector<int32_t>& operands, const std::vector<int32_t>& results) {
/*builtin_options=*/0, builder_.CreateVector<uint8_t>(custom_options), TfLiteTransposeConvParams conv_params;
tflite::CustomOptionsFormat_FLEXBUFFERS); conv_params.stride_height = op.stride_h().getSExtValue();
conv_params.stride_width = op.stride_w().getSExtValue();
const auto padding = GetTflitePadding(inst, op.padding());
if (padding) {
conv_params.padding = *padding;
return BuildCustomOperator(conv_params, "Convolution2DTransposeBias", op,
operands, results);
}
return llvm::None;
}
Optional<BufferOffset<tflite::Operator>>
Translator::BuildMaxPoolingWithArgMax2DOperator(
Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op,
const std::vector<int32_t>& operands, const std::vector<int32_t>& results) {
const auto pool_params = GetTflitePoolParams(inst, op);
if (pool_params) {
return BuildCustomOperator(*pool_params, "MaxPoolingWithArgmax2D", op,
operands, results);
}
return llvm::None;
}
Optional<BufferOffset<tflite::Operator>>
Translator::BuildMaxUnpooling2DOperator(Operation* inst,
mlir::TFL::MaxUnpooling2DOp op,
const std::vector<int32_t>& operands,
const std::vector<int32_t>& results) {
const auto pool_params = GetTflitePoolParams(inst, op);
if (pool_params) {
return BuildCustomOperator(*pool_params, "MaxUnpooling2D", op, operands,
results);
}
return llvm::None;
} }
Optional<CustomOptionsOffset> Translator::CreateFlexOpCustomOptions( Optional<CustomOptionsOffset> Translator::CreateFlexOpCustomOptions(
@ -769,6 +894,20 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
if (auto verify_op = dyn_cast<mlir::TFL::NumericVerifyOp>(inst)) { if (auto verify_op = dyn_cast<mlir::TFL::NumericVerifyOp>(inst)) {
return BuildNumericVerifyOperator(verify_op, operands, results); return BuildNumericVerifyOperator(verify_op, operands, results);
} }
if (auto conv_transpose_bias_op =
dyn_cast<mlir::TFL::Convolution2DTransposeBiasOp>(inst)) {
return BuildConvolution2DTransposeBiasOperator(
inst, conv_transpose_bias_op, operands, results);
}
if (auto max_pooling_with_arg_max_op =
dyn_cast<mlir::TFL::MaxPoolingWithArgMax2DOp>(inst)) {
return BuildMaxPoolingWithArgMax2DOperator(
inst, max_pooling_with_arg_max_op, operands, results);
}
if (auto max_unpooling_op = dyn_cast<mlir::TFL::MaxUnpooling2DOp>(inst)) {
return BuildMaxUnpooling2DOperator(inst, max_unpooling_op, operands,
results);
}
inst->emitOpError("is not a supported TFLite op"); inst->emitOpError("is not a supported TFLite op");
return llvm::None; return llvm::None;
} }
@ -904,11 +1043,6 @@ void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) {
bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) { bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) {
std::vector<int> operand_indices; std::vector<int> operand_indices;
// TODO(b/138254427): When the bug is addressed, we'll be able to inspect
// for the presence of a specific OpTrait using mlir::Operation, without
// having to cast it to specific ops like below.
// Until then, when a new RNN/LSTM op is added to TFLite and has stateful
// tensors as operands, they will need to be added here as well.
if (!mlir::TFL::IsStatefulOp(op, &operand_indices)) return false; if (!mlir::TFL::IsStatefulOp(op, &operand_indices)) return false;
return absl::c_find(operand_indices, operand_index) != operand_indices.end(); return absl::c_find(operand_indices, operand_index) != operand_indices.end();
} }
View File
@ -1728,6 +1728,7 @@ static LogicalResult Verify(TransposeOp op) {
// TableGen'd op method definitions // TableGen'd op method definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.cc.inc"
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
View File
@ -27,7 +27,6 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // TF:llvm-project #include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project #include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project #include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_generated.h"
@ -44,6 +43,7 @@ class TensorFlowLiteDialect : public Dialect {
Location loc) override; Location loc) override;
}; };
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.h.inc"
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc"
View File
@ -249,14 +249,39 @@ def TFL_ComparisonBinaryBuilder : OpBuilder<
}]>; }]>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TFL native op trait for stateful operands and channel indices. // TFL op interface for stateful operands.
class StatefulOperands<list<int> operands> def TFL_StatefulOp : OpInterface<"StatefulOpInterface"> {
: ParamNativeOpTrait<"TFL::StatefulOperands", StrJoinInt<operands>.result>; let description = [{
Interface for ops that are stateful and need to identify stateful operands.
Stateful operands correspond to TF's variable semantics. An op with one or
more stateful operands is a stateful op.
}];
class ChannelDimIndex<int index> let methods = [
: ParamNativeOpTrait<"TFL::ChannelDimIndex", !cast<string>(index)>; InterfaceMethod<
[{Returns the indices of stateful operands.}],
"std::vector<int>", "GetStatefulOperands", (ins)
>,
];
}
//===----------------------------------------------------------------------===//
// TFL op interface for output channel index.
def TFL_ChannelDimIndexInterface : OpInterface<"ChannelDimIndexInterface"> {
let description = [{
Interface for defining the index of the output channel dimension.
}];
let methods = [
InterfaceMethod<
[{Returns the dimension index of the output channels.}],
"int", "GetChannelDimIndex", (ins)
>,
];
}
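Both interfaces are generated into tfl_ops_interface.h.inc (wired into the build earlier in this commit), so passes can query an arbitrary Operation without casting to a concrete op. A hedged C++ sketch, assuming the generated class names match the OpInterface declarations above:

    #include "llvm/Support/Casting.h"
    #include "llvm/Support/raw_ostream.h"
    #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

    void InspectOp(mlir::Operation* op) {
      if (auto stateful = llvm::dyn_cast<mlir::TFL::StatefulOpInterface>(op)) {
        for (int idx : stateful.GetStatefulOperands())
          llvm::outs() << "stateful operand #" << idx << "\n";
      }
      if (auto channel = llvm::dyn_cast<mlir::TFL::ChannelDimIndexInterface>(op)) {
        llvm::outs() << "output channel dim: " << channel.GetChannelDimIndex() << "\n";
      }
    }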
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TFL op base class. // TFL op base class.
@ -285,7 +310,7 @@ class TFL_Op<string mnemonic, list<OpTrait> traits = []> :
class TFL_ConvOp<string mnemonic, string opSummary, int index> : class TFL_ConvOp<string mnemonic, string opSummary, int index> :
TFL_Op<mnemonic, [NoSideEffect, AccumulatorUniformScale<2, 0, 1>, TFL_Op<mnemonic, [NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
ChannelDimIndex<index>, AffineOpCoefficient<index, 1>]> { TFL_ChannelDimIndexInterface, AffineOpCoefficient<index, 1>]> {
let summary = opSummary # " operator"; let summary = opSummary # " operator";
let description = [{ let description = [{
@ -486,8 +511,7 @@ def TFL_ArgMaxOp : TFL_Op<"arg_max", [NoSideEffect]> {
}]; }];
let arguments = ( let arguments = (
// TODO: Add support for uint8. ins TensorOf<[F32, I32, I8, TFL_Uint8, QI8, QUI8]>:$input,
ins TensorOf<[F32, I32, I8]>:$input,
TFL_I32OrI64Tensor:$dim TFL_I32OrI64Tensor:$dim
); );
@ -515,8 +539,7 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> {
}]; }];
let arguments = ( let arguments = (
// TODO(pkanwar): Add support for uint8. ins TensorOf<[F32, I32, I8, TFL_Uint8, QI8, QUI8]>:$input,
ins TensorOf<[F32, I32, I8]>:$input,
TFL_I32OrI64Tensor:$dim TFL_I32OrI64Tensor:$dim
); );
@ -617,7 +640,12 @@ def TFL_ExternalConstOp : Op<TFL_Dialect, "external_const", [NoSideEffect]> {
let results = (outs AnyTensor:$output); let results = (outs AnyTensor:$output);
} }
def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0>; def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0> {
let extraClassDeclaration = [{
// ChannelDimIndexInterface:
int GetChannelDimIndex() { return 0; }
}];
}
def TFL_CosOp: TFL_Op<"cos", [ def TFL_CosOp: TFL_Op<"cos", [
NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> { NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
@ -637,6 +665,11 @@ def TFL_CosOp: TFL_Op<"cos", [
def TFL_DepthwiseConv2DOp : def TFL_DepthwiseConv2DOp :
TFL_ConvOp<"depthwise_conv_2d", "Depthwise-separable convolution", 3> { TFL_ConvOp<"depthwise_conv_2d", "Depthwise-separable convolution", 3> {
let arguments = !con(TFL_Conv2DOp.arguments, (ins I32Attr:$depth_multiplier)); let arguments = !con(TFL_Conv2DOp.arguments, (ins I32Attr:$depth_multiplier));
let extraClassDeclaration = [{
// ChannelDimIndexInterface:
int GetChannelDimIndex() { return 3; }
}];
} }
def TFL_FCWO_Default : StrEnumAttrCase<"DEFAULT">; def TFL_FCWO_Default : StrEnumAttrCase<"DEFAULT">;
@ -650,7 +683,8 @@ def TFL_FullyConnectedOptionsWeightFormatAttr :
// TODO(jpienaar): Update post discussion on semantics of FC OP. // TODO(jpienaar): Update post discussion on semantics of FC OP.
def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
NoSideEffect, AccumulatorUniformScale<2, 0, 1>, ChannelDimIndex<0>, NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
TFL_ChannelDimIndexInterface,
AffineOpCoefficient<-1, 1>]> { AffineOpCoefficient<-1, 1>]> {
let summary = "Fully connected op"; let summary = "Fully connected op";
@ -672,6 +706,11 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
let verifier = [{ return Verify(*this); }]; let verifier = [{ return Verify(*this); }];
let hasOptions = 1; let hasOptions = 1;
let extraClassDeclaration = [{
// ChannelDimIndexInterface:
int GetChannelDimIndex() { return 0; }
}];
} }
def TFL_GatherOp : TFL_Op<"gather", [ def TFL_GatherOp : TFL_Op<"gather", [
@ -1208,7 +1247,8 @@ def TFL_FloorModOp : TFL_Op<"floor_mod", [Broadcastable, NoSideEffect]> {
let builders = [TFL_BroadcastableBinaryBuilder]; let builders = [TFL_BroadcastableBinaryBuilder];
} }
def TFL_GreaterOp : TFL_Op<"greater", [NoSideEffect, NoQuantizableResult]> { def TFL_GreaterOp : TFL_Op<"greater", [
Broadcastable, NoSideEffect, NoQuantizableResult]> {
let summary = "Greater operator"; let summary = "Greater operator";
let description = [{ let description = [{
@ -1221,6 +1261,8 @@ def TFL_GreaterOp : TFL_Op<"greater", [NoSideEffect, NoQuantizableResult]> {
let results = (outs AnyTensor:$output); let results = (outs AnyTensor:$output);
let builders = [TFL_ComparisonBinaryBuilder];
let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }]; let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }];
let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }]; let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
@ -1287,7 +1329,8 @@ def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [NoSideEffect, SameOperandsAndResultTy
let hasOptions = 0b1; let hasOptions = 0b1;
} }
def TFL_LessOp : TFL_Op<"less", [NoSideEffect, NoQuantizableResult]> { def TFL_LessOp : TFL_Op<"less", [
Broadcastable, NoSideEffect, NoQuantizableResult]> {
let summary = "Less operator"; let summary = "Less operator";
let description = [{ let description = [{
@ -2123,7 +2166,7 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2",
Args: Args:
tensor: A Tensor. Must be one of the following types: tensor: A Tensor. Must be one of the following types:
int16, int32, int64, float32 Up to 8-D. uint8, int16, int32, int64, float32, bool Up to 8-D.
axis: A Tensor. Must be one of the following types: int32, int64. axis: A Tensor. Must be one of the following types: int32, int64.
with only 1 element which is the axis index. with only 1 element which is the axis index.
@ -2132,12 +2175,12 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2",
let arguments = ( let arguments = (
ins ins
TensorOf<[F32, I16, I32, I64]>:$input, TensorOf<[F32, I16, I32, I64, TFL_Uint8, I1]>:$input,
TensorOf<[I32, I64]>:$axis TensorOf<[I32, I64]>:$axis
); );
let results = (outs let results = (outs
TensorOf<[F32, I16, I32, I64, I8]>:$output TensorOf<[F32, I16, I32, I64, TFL_Uint8, I1]>:$output
); );
} }
@ -2341,9 +2384,9 @@ def TFL_TanhOp: TFL_Op<"tanh", [
let results = (outs TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$y); let results = (outs TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$y);
} }
def TFL_TileOp: TFL_Op<"tile", [NoSideEffect, def TFL_TileOp: TFL_Op<"tile", [NoSideEffect, SameOperandsAndResultsScale,
PredOpTrait<"resultant element type needs to match first operand type", PredOpTrait<"resultant element type needs to match first operand type",
TCresVTEtIsSameAsOp<0,0>>]> { TFL_TCresVTEtIsSameAsOp<0,0>>]> {
let summary = "Tile operator."; let summary = "Tile operator.";
let description = [{ let description = [{
Constructs a tensor by tiling a given tensor. Constructs a tensor by tiling a given tensor.
@ -2356,10 +2399,11 @@ def TFL_TileOp: TFL_Op<"tile", [NoSideEffect,
}]; }];
let arguments = (ins let arguments = (ins
TensorOf<[F32, I1, I32, I64, TFL_Uint8]>:$input, TensorOf<[F32, I1, I32, I64, TFL_Uint8, QUI8]>:$input,
TFL_I32OrI64Tensor:$multiples); TFL_I32OrI64Tensor:$multiples);
let results = (outs TensorOf<[F32, I1, I32, I64, TFL_Uint8]>:$output); let results = (outs
TensorOf<[F32, I1, I32, I64, TFL_Uint8, QUI8]>:$output);
let hasOptions = 0; let hasOptions = 0;
} }
@ -2369,7 +2413,7 @@ def TFL_TileOp: TFL_Op<"tile", [NoSideEffect,
// TODO(jpienaar): Check that k is less or equal the internal dimension // TODO(jpienaar): Check that k is less or equal the internal dimension
def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>, def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>,
PredOpTrait<"result and input element type match", PredOpTrait<"result and input element type match",
TCresVTEtIsSameAsOp<0,0>>]> { TCresVTEtIsSameAsOp<0,0>>, SameOperandsAndResultsScale]> {
let summary = "TopK operator"; let summary = "TopK operator";
let description = [{ let description = [{
@ -2379,11 +2423,11 @@ def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>,
}]; }];
let arguments = (ins let arguments = (ins
TensorOf<[F32, I8, I32, I64, TFL_Uint8]>:$input, TensorOf<[F32, I8, I32, I64, TFL_Uint8, QI8, QUI8]>:$input,
I32Tensor:$k); I32Tensor:$k);
let results = (outs let results = (outs
AnyTensor:$values, TensorOf<[F32, I8, I32, I64, TFL_Uint8, QI8, QUI8]>:$values,
I32Tensor:$indices); I32Tensor:$indices);
let builders = [OpBuilder<"Builder *builder, OperationState &result, " let builders = [OpBuilder<"Builder *builder, OperationState &result, "
@ -2907,6 +2951,20 @@ def TFL_QuantizeOp: TFL_Op<"quantize", [
let results = (outs AnyTensor:$output); let results = (outs AnyTensor:$output);
} }
def TFL_DensifyOp: TFL_Op<"densify", [NoSideEffect,
SameOperandsAndResultType,
NoQuantizableResult]> {
let summary = "Densify operator";
let description = [{
Converts sparse tensor to dense format.
}];
let arguments = (ins AnyTensor:$input);
let results = (outs AnyTensor:$output);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// LSTM Ops // LSTM Ops
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -2996,7 +3054,7 @@ def TFL_LSTMOp :
LstmOptionalPeepholeWeightConstraint, LstmOptionalPeepholeWeightConstraint,
LstmProjectionWeightBiasConstraint, LstmProjectionWeightBiasConstraint,
LstmResultConstraint, LstmResultConstraint,
StatefulOperands<[18, 19]>]> { TFL_StatefulOp]> {
let summary = "The full lstm operator"; let summary = "The full lstm operator";
let description = [{ let description = [{
@ -3080,6 +3138,11 @@ Ba et al. “Layer Normalization”
let hasOptions = 1; let hasOptions = 1;
let verifier = [{ return Verify(*this); }]; let verifier = [{ return Verify(*this); }];
let extraClassDeclaration = [{
// StatefulOpInterface:
std::vector<int> GetStatefulOperands() { return {18, 19}; }
}];
} }
// UnidirectionalSequenceLstm op. // UnidirectionalSequenceLstm op.
@ -3091,7 +3154,7 @@ def TFL_UnidirectionalSequenceLSTMOp :
LstmOptionalPeepholeWeightConstraint, LstmOptionalPeepholeWeightConstraint,
LstmProjectionWeightBiasConstraint, LstmProjectionWeightBiasConstraint,
LstmResultConstraint, LstmResultConstraint,
StatefulOperands<[18, 19]>]> { TFL_StatefulOp]> {
let summary = "Unidirectional sequence lstm operator"; let summary = "Unidirectional sequence lstm operator";
let description = [{ let description = [{
@ -3160,6 +3223,11 @@ def TFL_UnidirectionalSequenceLSTMOp :
let hasOptions = 1; let hasOptions = 1;
let verifier = [{ return Verify(*this); }]; let verifier = [{ return Verify(*this); }];
let extraClassDeclaration = [{
// StatefulOpInterface:
std::vector<int> GetStatefulOperands() { return {18, 19}; }
}];
} }
def RnnResultConstraint : PredOpTrait< def RnnResultConstraint : PredOpTrait<
@ -3169,7 +3237,7 @@ def RnnResultConstraint : PredOpTrait<
// UnidirectionalSequenceRNN op. // UnidirectionalSequenceRNN op.
def TFL_UnidirectionalSequenceRNNOp : def TFL_UnidirectionalSequenceRNNOp :
TFL_Op<"unidirectional_sequence_rnn", TFL_Op<"unidirectional_sequence_rnn",
[RnnResultConstraint, StatefulOperands<[4]>]> { [RnnResultConstraint, TFL_StatefulOp]> {
let summary = "Unidirectional sequence rnn operator"; let summary = "Unidirectional sequence rnn operator";
@ -3213,6 +3281,11 @@ def TFL_UnidirectionalSequenceRNNOp :
let customOption = "SequenceRNNOptions"; let customOption = "SequenceRNNOptions";
let verifier = [{ return Verify(*this); }]; let verifier = [{ return Verify(*this); }];
let extraClassDeclaration = [{
// StatefulOpInterface:
std::vector<int> GetStatefulOperands() { return {4}; }
}];
} }
def TFL_WhereOp : TFL_Op<"where", [NoSideEffect]> { def TFL_WhereOp : TFL_Op<"where", [NoSideEffect]> {
@ -3264,7 +3337,7 @@ def SVDFResultConstraint: PredOpTrait<
// SVDF op. // SVDF op.
def TFL_SVDFOp : def TFL_SVDFOp :
TFL_Op<"svdf", TFL_Op<"svdf",
[SVDFResultConstraint, StatefulOperands<[4]>]> { [SVDFResultConstraint, TFL_StatefulOp]> {
let summary = "Single value decomposition filter operator"; let summary = "Single value decomposition filter operator";
@ -3300,6 +3373,25 @@ def TFL_SVDFOp :
let hasOptions = 1; let hasOptions = 1;
let verifier = [{ return Verify(*this); }]; let verifier = [{ return Verify(*this); }];
let extraClassDeclaration = [{
// StatefulOpInterface:
std::vector<int> GetStatefulOperands() { return {4}; }
}];
}
def TFL_SegmentSumOp: TFL_Op<"segment_sum", [NoSideEffect]> {
let summary = "SegmentSum operator";
let description = [{
Computes the sum along segments of a tensor.
}];
let arguments = (ins
TensorOf<[F32, I32]>:$data,
I32Tensor:$segment_ids
);
let results = (outs TensorOf<[F32, I32]>:$output);
} }
#endif // TFL_OPS #endif // TFL_OPS
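The new tfl.segment_sum op above only states that it computes the sum along segments of a tensor. As a rough illustration of those semantics (standalone C++ with made-up values, not TFLite kernel code): rows of the data tensor that share a segment id are accumulated into a single output row, and segment ids are expected to be sorted and zero-based.

#include <array>
#include <vector>

// Illustration only: segment_sum over a 4x2 "tensor" with ids {0, 0, 1, 1}.
int main() {
  std::vector<std::array<float, 2>> data = {{1, 2}, {3, 4}, {5, 6}, {7, 8}};
  std::vector<int> segment_ids = {0, 0, 1, 1};  // one id per row, sorted
  int num_segments = segment_ids.back() + 1;
  std::vector<std::array<float, 2>> out(num_segments, {0, 0});
  for (int i = 0; i < static_cast<int>(data.size()); ++i)
    for (int j = 0; j < 2; ++j)
      out[segment_ids[i]][j] += data[i][j];
  // out == {{4, 6}, {12, 14}}
  return 0;
}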
View File
@ -1,67 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines the op traits used in the MLIR TensorFlow Lite dialect.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h" // TF:llvm-project
namespace mlir {
namespace OpTrait {
namespace TFL {
// The trait to specify that the specified operands of the TFL op are stateful.
// This is used as a trait like this:
//
// class LSTMOp
// : public Op<LSTMOp, OpTrait::TFL::StatefulOperands<18, 19>::Impl> {
//
template <int... Operands>
class StatefulOperands {
public:
template <typename ConcreteType>
class Impl
: public TraitBase<ConcreteType, StatefulOperands<Operands...>::Impl> {
public:
static std::vector<int> GetStatefulOperands() {
return std::vector<int>({Operands...});
}
};
};
// The trait to specify the channel dimension index of the input (first operand)
// of an affine TFL op (Conv2D, DepthwiseConv2D, FullyConnected).
//
// class Conv2DOp
// : public Op<Conv2DOp, OpTrait::TFL::ChannelDimIndex<0>::Impl> {
//
template <int Index>
class ChannelDimIndex {
public:
template <typename ConcreteType>
class Impl : public TraitBase<ConcreteType, ChannelDimIndex<Index>::Impl> {
public:
static int GetChannelDimIndex() { return Index; }
};
};
} // namespace TFL
} // namespace OpTrait
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
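The deleted traits above are superseded by op interfaces: the operand indices and channel dimensions that used to be template arguments now live in the extraClassDeclaration blocks added to tfl_ops.td. A rough sketch of how a pass might query the replacement follows; the C++ interface names (StatefulOpInterface, ChannelDimIndexInterface) are inferred from the TableGen changes and are assumptions, not verified declarations.

#include <vector>

#include "llvm/Support/Casting.h"
#include "mlir/IR/Operation.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

// Sketch only: query interfaces instead of the removed trait templates.
void InspectTFLOp(mlir::Operation *op) {
  // Stateful operands, e.g. {18, 19} for tfl.lstm or {4} for tfl.svdf.
  if (auto stateful = llvm::dyn_cast<mlir::TFL::StatefulOpInterface>(op)) {
    std::vector<int> indices = stateful.GetStatefulOperands();
    (void)indices;
  }
  // Channel dimension of the filter, e.g. 0 for conv_2d, 3 for depthwise_conv_2d.
  if (auto affine = llvm::dyn_cast<mlir::TFL::ChannelDimIndexInterface>(op)) {
    int channel_dim = affine.GetChannelDimIndex();
    (void)channel_dim;
  }
}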
View File
@ -32,6 +32,6 @@ cc_library(
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
"@llvm-project//mlir:ViewOpGraph", "@llvm-project//mlir:Transforms",
], ],
) )
View File
@ -107,9 +107,6 @@ void WarningUnusedFlags(const toco::ModelFlags& model_flags,
if (toco_flags.output_format()) { if (toco_flags.output_format()) {
LOG(WARNING) << "Ignored output_format."; LOG(WARNING) << "Ignored output_format.";
} }
if (toco_flags.default_ranges_min() || toco_flags.default_ranges_max()) {
LOG(WARNING) << "Ignored default_ranges_stats.";
}
if (toco_flags.drop_control_dependency()) { if (toco_flags.drop_control_dependency()) {
LOG(WARNING) << "Ignored drop_control_dependency."; LOG(WARNING) << "Ignored drop_control_dependency.";
} }
@ -242,6 +239,13 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
tensorflow::ParseOutputArrayInfo(output_arrays, &specs.outputs)); tensorflow::ParseOutputArrayInfo(output_arrays, &specs.outputs));
// Other flags. // Other flags.
if (toco_flags.has_default_ranges_min()) {
quant_specs.default_ranges.first = toco_flags.default_ranges_min();
}
if (toco_flags.has_default_ranges_max()) {
quant_specs.default_ranges.second = toco_flags.default_ranges_max();
}
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops(); bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
bool emit_select_tf_ops = toco_flags.enable_select_tf_ops(); bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
bool emit_custom_ops = toco_flags.allow_custom_ops(); bool emit_custom_ops = toco_flags.allow_custom_ops();
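In place of the removed warning, the default range flags are now honored: the min/max values are copied into quant_specs.default_ranges here, and (in the pass pipeline change further below) they enable the default-quant-params and quantize passes. A rough sketch of a caller supplying those flags, assuming the usual generated proto setters on toco::TocoFlags and purely illustrative values:

#include "tensorflow/lite/toco/toco_flags.pb.h"

// Sketch only: give tensors that lack quantization stats a fallback range.
toco::TocoFlags MakeFlagsWithDefaultRanges() {
  toco::TocoFlags flags;
  flags.set_default_ranges_min(-6.0);  // illustrative fallback range
  flags.set_default_ranges_max(6.0);
  return flags;
}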
View File
@ -206,10 +206,17 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
std::unique_ptr<OpPassBase<FuncOp>> std::unique_ptr<OpPassBase<FuncOp>>
CreateImportQuantStatsPassForTFControlDialect(const std::string &stats_str) { CreateImportQuantStatsPassForTFControlDialect(const std::string &stats_str) {
auto get_name_func = [](Operation *op) { auto get_name_func = [](Operation *op) {
if (auto name = op->getAttrOfType<StringAttr>("name")) Location loc = op->getLoc();
return name.getValue(); if (auto name = loc.dyn_cast<NameLoc>()) {
else return name.getName().strref();
return llvm::StringRef(""); } else if (auto fused_name = loc.dyn_cast<FusedLoc>()) {
for (auto sub_loc : fused_name.getLocations()) {
if (auto named_sub_loc = sub_loc.dyn_cast<NameLoc>()) {
return named_sub_loc.getName().strref();
}
}
}
return llvm::StringRef("");
}; };
return CreateImportQuantStatsPass(get_name_func, stats_str); return CreateImportQuantStatsPass(get_name_func, stats_str);
View File
@ -23,7 +23,6 @@ cc_library(
], ],
hdrs = [ hdrs = [
"quantize_model.h", "quantize_model.h",
"//tensorflow/compiler/mlir/lite:transforms/passes.h",
], ],
deps = [ deps = [
"//tensorflow/compiler/mlir/lite:common", "//tensorflow/compiler/mlir/lite:common",
View File
@ -73,19 +73,19 @@ TfLiteStatus QuantizeModel(
// Apply quantization passes // Apply quantization passes
PassManager pm(module->getContext()); PassManager pm(module->getContext());
TFL::QuantizationSpecs pass_config; TFL::QuantizationSpecs quant_specs;
pass_config.inference_type = tensorflow::DT_QINT8; quant_specs.inference_type = tensorflow::DT_QINT8;
pass_config.post_training_quantization = true; quant_specs.post_training_quantization = true;
bool emit_adaptor = false; bool emit_adaptor = false;
auto input_tf_type = tflite::TflTypeToTfType(input_type); auto input_tf_type = tflite::TflTypeToTfType(input_type);
if (input_tf_type == tensorflow::DT_FLOAT) { if (input_tf_type == tensorflow::DT_FLOAT) {
emit_adaptor = true; emit_adaptor = true;
} else if (input_tf_type == tensorflow::DT_UINT8) { } else if (input_tf_type == tensorflow::DT_UINT8) {
pass_config.inference_type = tensorflow::DT_QUINT8; quant_specs.inference_type = tensorflow::DT_QUINT8;
} }
pm.addPass(TFL::CreatePrepareQuantizePass(pass_config)); pm.addPass(TFL::CreatePrepareQuantizePass(quant_specs));
pm.addPass(TFL::CreateQuantizePass()); pm.addPass(TFL::CreateQuantizePass());
pm.addPass(TFL::CreatePostQuantizePass(emit_adaptor)); pm.addPass(TFL::CreatePostQuantizePass(emit_adaptor));
View File
@ -23,6 +23,7 @@ limitations under the License.
#include <vector> #include <vector>
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/types.pb.h"
@ -64,6 +65,10 @@ struct QuantizationSpecs {
// quantization aware training or calibration, for the remaining tensors. // quantization aware training or calibration, for the remaining tensors.
std::vector<std::pair<double, double>> input_ranges; std::vector<std::pair<double, double>> input_ranges;
// The default ranges can be used when a tensor doesn't have quantization
// parameters and couldn't be quantized. Used only for latency tests.
std::pair<llvm::Optional<double>, llvm::Optional<double>> default_ranges;
// A serialized "QuantizationInfo" object to specify value ranges for some of // A serialized "QuantizationInfo" object to specify value ranges for some of
// the tensors with known names. // the tensors with known names.
std::string serialized_quant_stats = ""; std::string serialized_quant_stats = "";
View File
@ -35,7 +35,6 @@ limitations under the License.
#include "mlir/IR/Value.h" // TF:llvm-project #include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project #include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
View File
@ -3,7 +3,8 @@
// CHECK-LABEL: import_stats_skip // CHECK-LABEL: import_stats_skip
func @import_stats_skip(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) { func @import_stats_skip(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "skip"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>) %0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
loc(fused["skip1", "skip2.cc":10:8, callsite("op" at "skip3.cc":10:8)])
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32> return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
// CHECK-NEXT: "tfl.split" // CHECK-NEXT: "tfl.split"
@ -12,7 +13,8 @@ func @import_stats_skip(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf3
// CHECK-LABEL: import_stats_name // CHECK-LABEL: import_stats_name
func @import_stats_name(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) { func @import_stats_name(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "op"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>) %0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
loc(fused["skip1.cc":10:8, "op", callsite("skip2" at "skip3.cc":10:8)])
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32> return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
// CHECK-NEXT: %[[split:.*]]:2 = "tfl.split" // CHECK-NEXT: %[[split:.*]]:2 = "tfl.split"
@ -23,7 +25,8 @@ func @import_stats_name(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf3
// CHECK-LABEL: import_stats_name_port // CHECK-LABEL: import_stats_name_port
func @import_stats_name_port(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) { func @import_stats_name_port(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "op_0"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>) %0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
loc(fused["skip1.cc":10:8, "op_0", callsite("skip2" at "skip3.cc":10:8)])
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32> return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
// CHECK-NEXT: %[[split:.*]]:2 = "tfl.split" // CHECK-NEXT: %[[split:.*]]:2 = "tfl.split"
@ -34,6 +37,7 @@ func @import_stats_name_port(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor
// CHECK-LABEL: import_stats_name_regex // CHECK-LABEL: import_stats_name_regex
func @import_stats_name_regex(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) { func @import_stats_name_regex(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "op_regex"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>) %0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "op_regex"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
loc(fused["skip1.cc":10:8, "op_regex", callsite("skip2" at "skip3.cc":10:8)])
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32> return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
// CHECK-NEXT: %[[split:.*]]:2 = "tfl.split" // CHECK-NEXT: %[[split:.*]]:2 = "tfl.split"
View File
@ -0,0 +1,89 @@
// RUN: tf-opt %s --tfl-default-quant --tfl-quantize | FileCheck %s
// CHECK-LABEL: hardcode_all
func @hardcode_all(%arg0: tensor<2x2xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x2xf32> {
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function="NONE"}: (tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<2x1x!quant.uniform<u8:f32, 0.0078431372549019607:128>>}
// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>}
// Quantized tfl.add
// CHECK: %[[add:.*]] = "tfl.add"(%[[q1]], %[[q0]]) {fused_activation_function = "NONE"} : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[add]]) : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>)
// CHECK: return %[[dq]]
}
// CHECK-LABEL: hardcode_input
func @hardcode_input(%arg0: tensor<2x2xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x2xf32> {
%0 = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>}: (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>
%1 = "tfl.dequantize"(%0) : (tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>) -> tensor<2x2xf32>
%4 = "tfl.add"(%1, %arg1) {fused_activation_function="NONE"}: (tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32>
return %4 : tensor<2x2xf32>
// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<2x1x!quant.uniform<u8:f32, 0.0078431372549019607:128>>}
// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e+00:128>>}
// CHECK: %[[add:.*]] = "tfl.add"(%[[q1]], %[[q0]]) {fused_activation_function = "NONE"} : (tensor<2x2x!quant.uniform<u8:f32, 1.000000e+00:128>>
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[add]]) : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>)
// CHECK: return %[[dq]]
}
// CHECK-LABEL: hardcode_input_deq
func @hardcode_input_deq(%arg0: tensor<2x2x!quant.uniform<u8:f32, 1.0>>, %arg1: tensor<2x1xf32>) -> tensor<2x2xf32> {
%1 = "tfl.dequantize"(%arg0) : (tensor<2x2x!quant.uniform<u8:f32, 1.0>>) -> tensor<2x2xf32>
%4 = "tfl.add"(%1, %arg1) {fused_activation_function="NONE"}: (tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32>
return %4 : tensor<2x2xf32>
// CHECK: %[[q:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<2x1x!quant.uniform<u8:f32, 0.0078431372549019607:128>>}
// CHECK: %[[add:.*]] = "tfl.add"(%arg0, %[[q]]) {fused_activation_function = "NONE"} : (tensor<2x2x!quant.uniform<u8:f32, 1.000000e+00>>
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[add]]) : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>)
// CHECK: return %[[dq]]
}
// CHECK-LABEL: hardcode_output
func @hardcode_output(%arg0: tensor<2x2xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x2xf32> {
%0 = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>}: (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>
%1 = "tfl.quantize"(%arg1) {qtype = tensor<2x1x!quant.uniform<u8:f32, 1.0:128>>}: (tensor<2x1xf32>) -> tensor<2x1x!quant.uniform<u8:f32, 1.0:128>>
%2 = "tfl.dequantize"(%0) : (tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>) -> tensor<2x2xf32>
%3 = "tfl.dequantize"(%1) : (tensor<2x1x!quant.uniform<u8:f32, 1.0:128>>) -> tensor<2x1xf32>
%4 = "tfl.add"(%2, %3) {fused_activation_function="NONE"}: (tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32>
return %4 : tensor<2x2xf32>
// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e+00:128>>}
// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<2x1x!quant.uniform<u8:f32, 1.000000e+00:128>>}
// CHECK: %[[add:.*]] = "tfl.add"(%[[q0]], %[[q1]]) {fused_activation_function = "NONE"} : (tensor<2x2x!quant.uniform<u8:f32, 1.000000e+00:128>>
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[add]]) : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>)
// CHECK: return %[[dq]]
}
// CHECK-LABEL: test_conv_2d_add
func @test_conv_2d_add(%arg0: tensor<1x224x224x3x!quant.uniform<u8:f32, 1.0>>, %arg1: tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 1.0>>, %arg2: tensor<32x!quant.uniform<i32:f32, 1.0>>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>> {
%0 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 1.0>>) -> tensor<1x224x224x3xf32>
%1 = "tfl.dequantize"(%arg1) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 1.0>>) -> tensor<32x3x3x3xf32>
%2 = "tfl.dequantize"(%arg2) : (tensor<32x!quant.uniform<i32:f32, 1.0>>) -> tensor<32xf32>
%3 = "tfl.conv_2d"(%0, %1, %2) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
%4 = "tfl.pseudo_qconst"() {qtype = tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>, value = dense<1> : tensor<1x112x112x32xi8>} : () -> tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>
%5 = "tfl.dequantize"(%4) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>) -> tensor<1x112x112x32xf32>
%6 = "tfl.add"(%3, %5) {fused_activation_function="NONE"}: (tensor<1x112x112x32xf32>, tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32>
%7 = "tfl.quantize"(%6) {qtype = tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>
return %7 : tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>
// CHECK: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %arg1, %arg2)
// CHECK-SAME: -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"()
// CHECK: %[[add:.*]] = "tfl.add"(%[[conv]], %[[cst]])
// CHECK-SAME: -> tensor<1x112x112x32x!quant.uniform<u8:f32, 1.000000e+00>>
// CHECK: return %[[add]]
}
// CHECK-LABEL: test_conv_2d_activation_and_bias
func @test_conv_2d_activation_and_bias(%arg0: tensor<1x224x224x3xf32>, %arg1: tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 1.0>>, %arg2: tensor<32xf32>) -> tensor<1x112x112x32xf32> {
%0 = "tfl.dequantize"(%arg1) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 1.0>>) -> tensor<32x3x3x3xf32>
%1 = "tfl.conv_2d"(%arg0, %0, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
return %1 : tensor<1x112x112x32xf32>
// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg2) {qtype = tensor<32x!quant.uniform<i32:f32, 0.0078431372549019607>>}
// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x224x224x3x!quant.uniform<u8:f32, 0.0078431372549019607:128>>}
// CHECK: %[[conv:.*]] = "tfl.conv_2d"(%[[q1]], %arg1, %[[q0]])
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[conv]]) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 0.0078431372549019607:128>>)
// CHECK: return %[[dq]]
}
View File
@ -0,0 +1,76 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --check-prefix=MLIR %s
func @main(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, %arg2: tensor<4xi32>) -> tensor<1x64x84x32xf32> {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: CUSTOM,
// CHECK-NEXT: custom_code: "Convolution2DTransposeBias"
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 32, 4, 4, 128 ],
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 32, 42, 128 ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "arg1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "arg2",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 64, 84, 32 ],
// CHECK-NEXT: buffer: 4,
// CHECK-NEXT: name: "tfl.convolution_2d_transpose_bias",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0, 1, 2 ],
// CHECK-NEXT: outputs: [ 3 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0, 1, 2 ],
// CHECK-NEXT: outputs: [ 3 ],
// CHECK-NEXT: custom_options: [ 1, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: } ]
// CHECK-NEXT:}
// MLIR-LABEL: func @main(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, %arg2: tensor<4xi32>)
// MLIR-SAME: -> tensor<1x64x84x32xf32>
// MLIR: %0 = "tfl.convolution_2d_transpose_bias"(%arg0, %arg1, %arg2)
// MLIR-SAME: {padding = "SAME", stride_h = 1 : i32, stride_w = 2 : i32}
// MLIR-SAME: (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32>
// MLIR-NEXT: return %0 : tensor<1x64x84x32xf32>
%0 = "tfl.convolution_2d_transpose_bias"(%arg0, %arg1, %arg2) {padding = "SAME", stride_h = 1 : i32, stride_w = 2 : i32} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32>
return %0 : tensor<1x64x84x32xf32>
}
View File
@ -0,0 +1,39 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -emit-builtin-tflite-ops=false -o - | flatbuffer_to_string - | FileCheck %s
// CHECK: {
// CHECK: version: 3,
// CHECK: operator_codes: [ {
// CHECK: builtin_code: CUSTOM,
// CHECK: custom_code: "HashTableV2"
// CHECK: } ],
// CHECK: subgraphs: [ {
// CHECK: tensors: [ {
// CHECK: shape: [ ],
// CHECK: type: INT32,
// CHECK: buffer: 1,
// CHECK: name: "tf.HashTableV2",
// CHECK: quantization: {
// CHECK-EMPTY:
// CHECK: }
// CHECK: } ],
// CHECK: inputs: [ ],
// CHECK: outputs: [ 0 ],
// CHECK: operators: [ {
// CHECK: inputs: [ ],
// CHECK: outputs: [ 0 ],
// CHECK: custom_options:
// CHECK: name: "main"
// CHECK: } ],
// CHECK: description: "MLIR Converted.",
// CHECK: buffers: [ {
// CHECK-EMPTY:
// CHECK: }, {
// CHECK-EMPTY:
// CHECK: } ]
// CHECK: }
func @main() -> tensor<*x!tf.resource> {
%0 = "tf.HashTableV2"() {container = "" , shared_name= "table", use_node_name_sharing = false, key_dtype = i32, value_dtype = i32 } : () -> tensor<*x!tf.resource>
return %0 : tensor<*x!tf.resource>
}
View File
@ -0,0 +1,65 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --check-prefix=MLIR %s
func @main(%arg0: tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: CUSTOM,
// CHECK-NEXT: custom_code: "MaxPoolingWithArgmax2D"
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 1, 64, 64, 32 ],
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 32, 32, 32 ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "tfl.max_pooling_with_argmax_2d",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 32, 32, 32 ],
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "tfl.max_pooling_with_argmax_2d:1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0 ],
// CHECK-NEXT: outputs: [ 1, 2 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0 ],
// CHECK-NEXT: outputs: [ 1, 2 ],
// CHECK-NEXT: custom_options: [ 1, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: } ]
// CHECK-NEXT:}
// MLIR-LABEL: func @main(%arg0: tensor<1x64x64x32xf32>)
// MLIR-SAME: -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
// MLIR: %value, %indices = "tfl.max_pooling_with_argmax_2d"(%arg0)
// MLIR-SAME: {filter_h = 4 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 1 : i32}
// MLIR-SAME: (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
// MLIR-NEXT: return %value, %indices : tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>
%0, %1 = "tfl.max_pooling_with_argmax_2d"(%arg0) {filter_h = 4 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 1 : i32} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
return %0, %1 : tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>
}
View File
@ -0,0 +1,65 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --check-prefix=MLIR %s
func @main(%arg0: tensor<1x8x8x128xf32>, %arg1: tensor<1x8x8x128xf32>) -> tensor<1x8x8x128xf32> {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: CUSTOM,
// CHECK-NEXT: custom_code: "MaxUnpooling2D"
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 1, 8, 8, 128 ],
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 8, 8, 128 ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "arg1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 8, 8, 128 ],
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "tfl.max_unpooling_2d",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 2 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 2 ],
// CHECK-NEXT: custom_options: [ 1, 0, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: } ]
// CHECK-NEXT:}
// MLIR-LABEL: func @main(%arg0: tensor<1x8x8x128xf32>, %arg1: tensor<1x8x8x128xf32>)
// MLIR-SAME: -> tensor<1x8x8x128xf32>
// MLIR: %0 = "tfl.max_unpooling_2d"(%arg0, %arg1)
// MLIR-SAME: {filter_h = 1 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 4 : i32, stride_w = 2 : i32}
// MLIR-SAME: (tensor<1x8x8x128xf32>, tensor<1x8x8x128xf32>) -> tensor<1x8x8x128xf32>
// MLIR-NEXT: return %0 : tensor<1x8x8x128xf32>
%0 = "tfl.max_unpooling_2d"(%arg0, %arg1) {filter_h = 1 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 4 : i32, stride_w = 2 : i32} : (tensor<1x8x8x128xf32>, tensor<1x8x8x128xf32>) -> (tensor<1x8x8x128xf32>)
return %0 : tensor<1x8x8x128xf32>
}
View File
@ -1977,3 +1977,12 @@ func @testTransposeConvBadOutputShape(%arg1: tensor<32x4x4x128xf32>, %arg2: tens
%0 = "tfl.transpose_conv"(%cst, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>) -> tensor<1x64x84x31xf32> %0 = "tfl.transpose_conv"(%cst, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>) -> tensor<1x64x84x31xf32>
return %0 : tensor<1x64x84x31xf32> return %0 : tensor<1x64x84x31xf32>
} }
// -----
// CHECK-LABEL: testDensify
func @testDensify(%arg0: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.densify"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.densify"(%arg0): (tensor<? x f32>) -> tensor<? x f32>
return %0 : tensor<? x f32>
}
View File
@ -1,4 +1,7 @@
// Run optimize pass only and check the results.
// RUN: tf-opt %s -tfl-optimize | FileCheck %s // RUN: tf-opt %s -tfl-optimize | FileCheck %s
// Run optimize pass and then canonicalize pass, and make sure some folding is applied.
// RUN: tf-opt %s -tfl-optimize -canonicalize | FileCheck --check-prefix=FOLD %s
// CHECK-LABEL: fusedConv2dRelu // CHECK-LABEL: fusedConv2dRelu
func @fusedConv2dRelu(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<256x30x30x16xf32> { func @fusedConv2dRelu(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<256x30x30x16xf32> {
@ -75,10 +78,10 @@ func @fuseSubIntoFollowingConv2d(%arg0: tensor<256x32x32x3xf32>) -> tensor<256x3
} }
// CHECK-LABEL: @fuseAddIntoDepthwiseConv2d // CHECK-LABEL: @fuseAddIntoDepthwiseConv2d
func @fuseAddIntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x30x30x16xf32> { func @fuseAddIntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32> {
%cst = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32> %cst = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
%cst_0 = constant dense<1.5> : tensor<16xf32> %cst_0 = constant dense<1.5> : tensor<16xf32>
%0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
%1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> %1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
return %1 : tensor<256x30x30x16xf32> return %1 : tensor<256x30x30x16xf32>
@ -87,10 +90,10 @@ func @fuseAddIntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<1
} }
// CHECK-LABEL: fuseSubIntoDepthwiseConv2d // CHECK-LABEL: fuseSubIntoDepthwiseConv2d
func @fuseSubIntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x30x30x16xf32> { func @fuseSubIntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32> {
%cst = constant dense<0.5> : tensor<16xf32> %cst = constant dense<0.5> : tensor<16xf32>
%cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32> %cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
%0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
%1 = "tfl.sub"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> %1 = "tfl.sub"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
return %1 : tensor<256x30x30x16xf32> return %1 : tensor<256x30x30x16xf32>
@ -128,10 +131,10 @@ func @fuseAddWithRelu6IntoConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<1
} }
// CHECK-LABEL: @fuseAddWithRelu6IntoDepthwiseConv2d // CHECK-LABEL: @fuseAddWithRelu6IntoDepthwiseConv2d
func @fuseAddWithRelu6IntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x30x30x16xf32> { func @fuseAddWithRelu6IntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32> {
%cst = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32> %cst = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
%cst_0 = constant dense<1.5> : tensor<16xf32> %cst_0 = constant dense<1.5> : tensor<16xf32>
%0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
%1 = "tfl.add"(%0, %cst) {fused_activation_function = "RELU6"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> %1 = "tfl.add"(%0, %cst) {fused_activation_function = "RELU6"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
return %1 : tensor<256x30x30x16xf32> return %1 : tensor<256x30x30x16xf32>
@ -302,6 +305,58 @@ func @FuseFullyConnectedAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf
// CHECK: return %[[fc]] // CHECK: return %[[fc]]
} }
// CHECK-LABEL: @FuseFullyConnectedReshapeAddConst
// FOLD-LABEL: @FuseFullyConnectedReshapeAddConst
func @FuseFullyConnectedReshapeAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
%cst = constant dense<3.0> : tensor<40x40xf32>
%cst2 = constant dense<2.0> : tensor<40xf32>
%shape1 = constant dense<[1, 40, 40]> : tensor<3xi32>
%shape2 = constant dense<[40, 40]> : tensor<2xi32>
%0 = "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, tensor<40x40xf32>) -> (tensor<40x40xf32>)
%1 = "tfl.reshape"(%0, %shape1) : (tensor<40x40xf32>, tensor<3xi32>) -> tensor<1x40x40xf32>
%2 = "tfl.add"(%1, %cst2) {fused_activation_function = "NONE"} : (tensor<1x40x40xf32>, tensor<40xf32>) -> tensor<1x40x40xf32>
%3 = "tfl.reshape"(%2, %shape2) : (tensor<1x40x40xf32>, tensor<2xi32>) -> tensor<40x40xf32>
return %3 : tensor<40x40xf32>
// CHECK: %[[cst:.*]] = constant dense<5.000000e+00> : tensor<40x40xf32>
// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]])
// CHECK: %[[rs1:.*]] = "tfl.reshape"(%[[fc]]
// CHECK: %[[rs2:.*]] = "tfl.reshape"(%[[rs1]]
// CHECK: return %[[rs2]]
// FOLD: %[[cst:.*]] = constant dense<5.000000e+00> : tensor<40x40xf32>
// FOLD: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]])
// FOLD: return %[[fc]]
}
// CHECK-LABEL: @NotReorderReshapeAddIfNotBroadcastable
func @NotReorderReshapeAddIfNotBroadcastable(%arg0: tensor<40x10x4xf32>) -> tensor<40x40xf32> {
%cst = constant dense<2.0> : tensor<40xf32>
%shape = constant dense<[40, 40]> : tensor<2xi32>
%1 = "tfl.reshape"(%arg0, %shape) : (tensor<40x10x4xf32>, tensor<2xi32>) -> tensor<40x40xf32>
%2 = "tfl.add"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<40xf32>) -> tensor<40x40xf32>
return %2 : tensor<40x40xf32>
// CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0
// CHECK: %[[rs2:.*]] = "tfl.add"(%[[rs1]]
// CHECK: return %[[rs2]]
}
// CHECK-LABEL: @NotReorderReshapeAddIfNotTailingDim
func @NotReorderReshapeAddIfNotTailingDim(%arg0: tensor<40x40x1xf32>) -> tensor<40x40xf32> {
%cst = constant dense<2.0> : tensor<1x40xf32>
%shape = constant dense<[40, 40]> : tensor<2xi32>
%1 = "tfl.reshape"(%arg0, %shape) : (tensor<40x40x1xf32>, tensor<2xi32>) -> tensor<40x40xf32>
%2 = "tfl.add"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<1x40xf32>) -> tensor<40x40xf32>
return %2 : tensor<40x40xf32>
// CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0
// CHECK: %[[rs2:.*]] = "tfl.add"(%[[rs1]]
// CHECK: return %[[rs2]]
}
// CHECK-LABEL: @FuseFullyConnectedRelu // CHECK-LABEL: @FuseFullyConnectedRelu
func @FuseFullyConnectedRelu(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> { func @FuseFullyConnectedRelu(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> {
%0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32> %0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32>
View File
@ -1,5 +1,6 @@
// RUN: tf-opt -tfl-prepare-composite-funcs-tf %s | FileCheck %s --dump-input-on-failure // RUN: tf-opt -tfl-prepare-composite-funcs-tf %s -split-input-file | FileCheck %s --dump-input-on-failure
module{
func @embedding(%arg0: tensor<*xf32>, %arg1: tensor<*xi32>) -> tensor<*xf32> attributes {tf._implements = "embedding_matmul", tf._reference = "mlir"} { func @embedding(%arg0: tensor<*xf32>, %arg1: tensor<*xi32>) -> tensor<*xf32> attributes {tf._implements = "embedding_matmul", tf._reference = "mlir"} {
%0 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32> %0 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%1 = "tf.ExpandDims"(%arg1, %0) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32> %1 = "tf.ExpandDims"(%arg1, %0) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
@ -148,3 +149,39 @@ func @layernormalizedlstmcellsimple(%arg0: tensor<1x?xf32>, %arg1: tensor<3x4xf3
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x?xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<3x1xf32>, tensor<3xf32>, tensor<1x3xf32>, tensor<1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x3xf32> // CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x?xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<3x1xf32>, tensor<3xf32>, tensor<1x3xf32>, tensor<1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x3xf32>
// CHECK: [[VAL_104:%.*]] = tensor_cast [[VAL_105:%.*]] : tensor<1x3xf32> to tensor<1x?xf32> // CHECK: [[VAL_104:%.*]] = tensor_cast [[VAL_105:%.*]] : tensor<1x3xf32> to tensor<1x?xf32>
// CHECK: return [[VAL_104]] : tensor<1x?xf32> // CHECK: return [[VAL_104]] : tensor<1x?xf32>
}
// -----
module {
func @inference_standard_lstm_7410(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.signature.is_stateful} {
%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
%1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
%3 = "tf.Add"(%2, %arg1) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
%4 = "tf.Add"(%2, %arg2) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x?x10xf32>
%5 = "tf.Add"(%arg1, %arg2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
%6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
}
// CHECK: func @inference_standard_lstm_7410([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<?x8x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.signature.is_stateful} {
// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32>
// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<10x40xf32>, tensor<2xi64>) -> tensor<40x10xf32>
// CHECK: [[VAL_10:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: [[VAL_11:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: [[VAL_12:%.*]]:4 = "tf.SplitV"([[VAL_7]], [[VAL_10]], [[VAL_11]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
// CHECK: [[VAL_13:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: [[VAL_14:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: [[VAL_15:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_13]], [[VAL_14]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
// CHECK: [[VAL_16:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: [[VAL_17:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
// CHECK: [[VAL_19:%.*]] = constant unit
// CHECK: [[VAL_20:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) ( {
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<?x8x10xf32>
// CHECK: return [[VAL_21:%.*]] : tensor<?x8x10xf32>
}
View File
@ -414,6 +414,14 @@ func @CheckNumerics(%arg0: tensor<3xf32>) -> tensor<3xf32> {
// CHECK: return %arg0 : tensor<3xf32> // CHECK: return %arg0 : tensor<3xf32>
} }
func @placeholder_with_default(%arg0: tensor<3xf32>) -> tensor<3xf32> {
%0 = "tf.PlaceholderWithDefault"(%arg0): (tensor<3xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
// Should be converted to Identity and then from Identity to value
// CHECK-LABEL: placeholder_with_default
// CHECK: return %arg0 : tensor<3xf32>
}
// CHECK-LABEL: @NoPadStridedSliceNonNewAxisMask // CHECK-LABEL: @NoPadStridedSliceNonNewAxisMask
func @NoPadStridedSliceNonNewAxisMask(%arg0: tensor<1x2x3x1xf32>) -> tensor<1x2x3x1xf32> { func @NoPadStridedSliceNonNewAxisMask(%arg0: tensor<1x2x3x1xf32>) -> tensor<1x2x3x1xf32> {
%cst = constant dense<0> : tensor<4xi32> %cst = constant dense<0> : tensor<4xi32>
@ -426,8 +434,8 @@ func @NoPadStridedSliceNonNewAxisMask(%arg0: tensor<1x2x3x1xf32>) -> tensor<1x2x
// CHECK: %0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) {begin_mask = 15 : i64, ellipsis_mask = 0 : i64, end_mask = 15 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1x2x3x1xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32> // CHECK: %0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) {begin_mask = 15 : i64, ellipsis_mask = 0 : i64, end_mask = 15 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1x2x3x1xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
} }
// CHECK-LABEL: @PadStridedSliceNewAxisMask // CHECK-LABEL: @PadStridedSliceNewAxisMask1
func @PadStridedSliceNewAxisMask(%arg0: tensor<2x3xf32>) -> tensor<1x2x3x1xf32> { func @PadStridedSliceNewAxisMask1(%arg0: tensor<2x3xf32>) -> tensor<1x2x3x1xf32> {
%cst = constant dense<0> : tensor<4xi32> %cst = constant dense<0> : tensor<4xi32>
%cst_0 = constant dense<1> : tensor<4xi32> %cst_0 = constant dense<1> : tensor<4xi32>
%0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 9 : i64, shrink_axis_mask = 0 : i64} : (tensor<2x3xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32> %0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 9 : i64, shrink_axis_mask = 0 : i64} : (tensor<2x3xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
@ -439,3 +447,12 @@ func @PadStridedSliceNewAxisMask(%arg0: tensor<2x3xf32>) -> tensor<1x2x3x1xf32>
// CHECK: %0 = "tf.Reshape"(%arg0, %[[cst_1]]) : (tensor<2x3xf32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
// CHECK: %1 = "tf.StridedSlice"(%0, %cst, %cst, %cst_0) {begin_mask = 15 : i64, ellipsis_mask = 0 : i64, end_mask = 15 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1x2x3x1xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
}
// CHECK-LABEL: @PadStridedSliceNewAxisMask2
func @PadStridedSliceNewAxisMask2(%arg0: tensor<4x64x64x1xf32>) -> tensor<1x4x64x64xf32> {
%cst = constant dense<0> : tensor<3xi32>
%cst_0 = constant dense<1> : tensor<3xi32>
%0 = "tf.Squeeze"(%arg0) {T = f32, _output_shapes = ["tfshape$dim { size: 4 } dim { size: 64 } dim { size: 64 }"], device = "", squeeze_dims = []} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32>
%1 = "tf.StridedSlice"(%0, %cst, %cst, %cst_0) {Index = i32, T = f32, _output_shapes = ["tfshape$dim { size: 1 } dim { size: 4 } dim { size: 64 } dim { size: 64 }"], begin_mask = 6 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 1 : i64, shrink_axis_mask = 0 : i64} : (tensor<4x64x64xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x4x64x64xf32>
return %1 : tensor<1x4x64x64xf32>
}
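Both PadStridedSliceNewAxisMask tests above check the same rewrite: a tf.StridedSlice whose new_axis_mask is non-zero becomes a tf.Reshape that materializes the inserted unit dimensions, followed by a StridedSlice with new_axis_mask = 0. The sketch below (plain C++, not converter code; the helper name is illustrative and the low-bit-first reading of the mask is an assumption consistent with these tests) shows how the mask bits map to the padded shape: mask 9 (0b1001) turns 2x3 into 1x2x3x1, and mask 1 turns 4x64x64 into 1x4x64x64, matching the CHECK lines.

#include <cstddef>
#include <cstdint>
#include <vector>

// Inserts a size-1 dimension at every output position whose bit is set in
// `new_axis_mask` (bit i corresponds to output dimension i); the remaining
// output positions take the next input dimension in order.
std::vector<int64_t> PadShapeForNewAxisMask(const std::vector<int64_t>& input_shape,
                                            uint32_t new_axis_mask) {
  std::vector<int64_t> padded;
  size_t next_input_dim = 0;
  for (int out_dim = 0;
       next_input_dim < input_shape.size() || (new_axis_mask >> out_dim) != 0;
       ++out_dim) {
    if (new_axis_mask & (1u << out_dim)) {
      padded.push_back(1);  // A new axis of size 1 is requested at this position.
    } else if (next_input_dim < input_shape.size()) {
      padded.push_back(input_shape[next_input_dim++]);
    }
  }
  // {2, 3} with mask 9 -> {1, 2, 3, 1}; {4, 64, 64} with mask 1 -> {1, 4, 64, 64}.
  return padded;
}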

View File

@ -43,6 +43,16 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
quant_specs.inference_type != quant_specs.inference_input_type;
pass_manager->addPass(
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
if (quant_specs.default_ranges.first.hasValue() ||
quant_specs.default_ranges.second.hasValue()) {
pass_manager->addPass(mlir::TFL::CreateDefaultQuantParamsPass(
quant_specs.default_ranges.first.getValueOr(0.0),
quant_specs.default_ranges.second.getValueOr(0.0)));
pass_manager->addPass(mlir::TFL::CreateQuantizePass());
pass_manager->addPass(
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
}
}
void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
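The branch added above only fires when the conversion was configured with fallback quantization ranges. A minimal sketch of how a caller might request that behaviour; apart from `default_ranges`, `AddQuantizationPasses`, and the pass-manager pointer visible in the diff, everything else here (the `context` variable, the exact optional-pair type, namespace qualification) is an assumption.

// Hypothetical driver snippet: supply a fallback [min, max] range so the
// default-quant, quantize, and post-quantize passes above are appended.
mlir::TFL::QuantizationSpecs quant_specs;
quant_specs.default_ranges = {-1.0, 1.0};  // assumed pair of optional doubles, per hasValue()/getValueOr() above

mlir::PassManager pass_manager(&context);  // `context` is an assumed MLIRContext
AddQuantizationPasses(quant_specs, &pass_manager);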
@ -115,7 +125,8 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
if (pass_config.emit_builtin_tflite_ops) {
// Prepare for TFLite dialect, rerun canonicalization, and then legalize to
// the TFLite dialect.
pass_manager->addPass(
mlir::TFL::CreatePrepareTFPass(pass_config.unfold_batch_matmul));
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
pass_manager->addPass(mlir::TFL::CreateLegalizeTFPass());
pass_manager->addPass(mlir::TFL::CreateOptimizePass());

View File

@ -86,15 +86,15 @@ StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
if (use_splatted_constant) {
return tensorflow::GraphdefToSplattedMlirTranslateFunction(
file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
input_shapes, output_arrays, /*control_output_arrays=*/"",
prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
/*graph_as_function=*/false, /*upgrade_legacy=*/true, context);
}
return tensorflow::GraphdefToMlirTranslateFunction(
file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
input_shapes, output_arrays, /*control_output_arrays=*/"",
prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
/*graph_as_function=*/false, /*upgrade_legacy=*/true, context);
}
Status ConvertTFExecutorToTFLOrFlatbuffer(

View File

@ -0,0 +1,234 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/LLVM.h"
#include "absl/memory/memory.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
//===----------------------------------------------------------------------===//
// A pass that adds default quantization parameters to activations which
// don't have quantization information. These default parameters are usually
// not derived from real measurements, so this pass is only for test purposes.
namespace mlir {
namespace TFL {
// Includes an auto-generated function, which can retrieve the quantization
// specification for a TFL operation. The signature of the function is
// std::unique_ptr<OpQuantSpec> TFL::GetOpQuantSpec(Operation *)
#include "tensorflow/compiler/mlir/lite/utils/generated_op_quant_spec_getters.inc"
namespace {
class DefaultQuantParamsPass : public FunctionPass<DefaultQuantParamsPass> {
public:
explicit DefaultQuantParamsPass(double default_min, double default_max)
: default_min_(default_min), default_max_(default_max) {}
void runOnFunction() override;
private:
// Whether the value is used as a bias input of another op. Here we assume
// bias is used immediately by the user. This assumption is always correct
// after constant folding.
bool UsedAsBias(Value value) {
for (auto &use : value.getUses()) {
auto biases = TFL::GetOpQuantSpec(use.getOwner())->biases_params;
if (biases.find(use.getOperandNumber()) != biases.end()) return true;
}
return false;
}
// Uses `quant_params` to quantize `value` by inserting a pair of
// tfl.quantize and tfl.dequantize ops for this `value`.
void QuantizeValue(OpBuilder builder, Value value,
TFL::QuantParams quant_params);
// If the value hasn't been quantized, the function adds it to `values`.
void AddToWorkListIfUnquantized(Value value, std::vector<Value> *values);
// Converts the default min/max to the default quantization parameters.
TFL::QuantParams GetDefaultQuantParams(Builder builder);
// Gets the quantization parameters for the bias of an operation by using the
// quantization parameters from the non-bias operands.
TFL::QuantParams GetQuantParamsForBias(Operation *op, int bias,
const std::vector<int> &non_biases,
TFL::AccumulatorScaleFunc func);
double default_min_;
double default_max_;
TFL::QuantParams default_quant_params_;
};
} // namespace
void DefaultQuantParamsPass::runOnFunction() {
FuncOp func = getFunction();
OpBuilder builder(func);
std::vector<Value> activation_values;
std::vector<Value> bias_values;
// First of all, collect all the values (block arguments and op results) which
// are required to be quantized.
for (auto arg : func.getBody().begin()->getArguments()) {
if (UsedAsBias(arg)) {
AddToWorkListIfUnquantized(arg, &bias_values);
} else {
AddToWorkListIfUnquantized(arg, &activation_values);
}
}
func.walk([&](Operation *op) {
if (op->isKnownTerminator() ||
op->hasTrait<OpTrait::quant::NoQuantizableResult>())
return;
for (auto res : op->getResults()) {
if (UsedAsBias(res)) {
AddToWorkListIfUnquantized(res, &bias_values);
} else {
AddToWorkListIfUnquantized(res, &activation_values);
}
}
});
// Apply the default quantization parameters for these activation values.
TFL::QuantParams default_params = GetDefaultQuantParams(builder);
for (Value value : activation_values) {
QuantizeValue(builder, value, default_params);
}
// Since all the non-bias operands have quantization parameters now, we
// should be able to propagate them to the bias operand.
for (Value bias : bias_values) {
Operation *op = *bias.user_begin();
auto spec = TFL::GetOpQuantSpec(op);
for (auto &it : spec->biases_params) {
TFL::QuantParams bias_params = GetQuantParamsForBias(
op, it.first, it.second.first, it.second.second);
if (!bias_params) continue;
QuantizeValue(builder, bias, bias_params);
}
}
}
void DefaultQuantParamsPass::AddToWorkListIfUnquantized(
Value value, std::vector<Value> *values) {
// If the result isn't a float type, it is an integer tensor and doesn't
// require quantization.
auto tensor_type = value.getType().dyn_cast<TensorType>();
if (!tensor_type) {
// Non-tensor values (e.g. none type) don't need quantization.
return;
}
if (!tensor_type.getElementType().isF32()) return;
// If the result is consumed by a quantize op, it has been quantized.
if (value.hasOneUse() &&
llvm::isa<TFL::QuantizeOp>(*value.getUsers().begin()))
return;
// Add this result to the list to apply the default value.
values->push_back(value);
}
void DefaultQuantParamsPass::QuantizeValue(OpBuilder builder, Value value,
TFL::QuantParams quant_params) {
Type expressed_type = value.getType();
Type new_type = quant_params.castFromExpressedType(expressed_type);
// This value isn't an expressed type (float), skip.
if (!new_type) return;
Block &block = value.getParentRegion()->front();
Operation *op = value.getDefiningOp();
if (op) {
builder.setInsertionPoint(&block, ++Block::iterator(op));
} else {
builder.setInsertionPointToStart(&block);
}
TypeAttr type_attr = TypeAttr::get(new_type);
auto quantize = builder.create<TFL::QuantizeOp>(value.getLoc(), new_type,
value, type_attr);
auto dequantize = builder.create<TFL::DequantizeOp>(
value.getLoc(), expressed_type, quantize.output());
value.replaceAllUsesWith(dequantize);
// `quantize` is using `dequantize` now, so we should set its operand to
// `value`.
quantize.getOperation()->replaceUsesOfWith(dequantize, value);
}
TFL::QuantParams DefaultQuantParamsPass::GetQuantParamsForBias(
Operation *op, int bias, const std::vector<int> &non_biases,
TFL::AccumulatorScaleFunc func) {
std::vector<quant::QuantizedType> non_bias_types;
non_bias_types.reserve(non_biases.size());
for (int non_bias : non_biases) {
Operation *non_bias_define = op->getOperand(non_bias).getDefiningOp();
if (auto dequant = llvm::dyn_cast<TFL::DequantizeOp>(non_bias_define)) {
auto non_bias_type = dequant.input().getType().cast<TensorType>();
auto non_bias_ele_type =
non_bias_type.getElementType().cast<quant::QuantizedType>();
non_bias_types.push_back(non_bias_ele_type);
} else {
// The non-bias hasn't been quantized, let's skip this bias.
break;
}
}
// The non-bias hasn't been quantized, let's skip this bias.
if (non_bias_types.size() != non_biases.size()) return {};
return func(non_bias_types);
}
TFL::QuantParams DefaultQuantParamsPass::GetDefaultQuantParams(
Builder builder) {
if (!default_quant_params_) {
default_quant_params_ = quant::fakeQuantAttrsToType(
builder.getUnknownLoc(),
/*numBits=*/8, default_min_, default_max_, /*narrowRange=*/false,
builder.getF32Type());
}
return default_quant_params_;
}
// Creates an instance of the default quant parameters pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateDefaultQuantParamsPass(
double default_min, double default_max) {
return absl::make_unique<DefaultQuantParamsPass>(default_min, default_max);
}
// Registers this pass with default values, for test purposes only.
static PassRegistration<DefaultQuantParamsPass> pass(
"tfl-default-quant",
"Apply quantization with default quantization parameter", [] {
return CreateDefaultQuantParamsPass(/*default_min=*/-1.0,
/*default_max=*/1.0);
});
} // namespace TFL
} // namespace mlir
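Because the file registers the pass under the `tfl-default-quant` flag and exposes `CreateDefaultQuantParamsPass`, the same behaviour can also be wired into a pipeline directly. The sketch below mirrors the calls added to tf_tfl_passes.cc earlier in this diff; the surrounding pass-manager setup and the chosen [-1, 1] range are assumptions.

// Append the default-quant pass with an assumed fallback range of [-1, 1],
// then quantize and clean up, as the pipeline change above does.
pass_manager->addPass(mlir::TFL::CreateDefaultQuantParamsPass(
    /*default_min=*/-1.0, /*default_max=*/1.0));
pass_manager->addPass(mlir::TFL::CreateQuantizePass());
pass_manager->addPass(
    mlir::TFL::CreatePostQuantizePass(/*emit_quant_adaptor_ops=*/false));

For standalone testing, the registration above also makes the pass selectable in the project's opt-style tools via -tfl-default-quant (exact tool invocation assumed).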

View File

@ -150,6 +150,7 @@ def : Pat<(TF_RoundOp $arg), (TFL_RoundOp $arg)>;
def : Pat<(TF_RsqrtOp $arg), (TFL_RsqrtOp $arg)>;
def : Pat<(TF_SqrtOp $arg), (TFL_SqrtOp $arg)>;
def : Pat<(TF_SquareOp $arg), (TFL_SquareOp $arg)>;
def : Pat<(TF_SegmentSumOp $data, I32Tensor:$segment_ids), (TFL_SegmentSumOp $data, $segment_ids)>;
def : Pat<(TF_SelectOp $cond, $x, $y), (TFL_SelectOp $cond, $x, $y)>;
def : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y), (TFL_SelectOp $cond, $x, $y), [(HasSameStaticShapes $src_op)]>;
def : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y), (TFL_SelectV2Op $cond, $x, $y), [(HasNotSameStaticShapes $src_op)]>;
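The new pattern forwards tf.SegmentSum straight to tfl.segment_sum with the same data and i32 segment-id operands. As a reminder of the op's semantics, segment sum adds together the entries of `data` that share a segment id; the toy C++ sketch below (not the TFLite kernel) shows the 1-D case with the usual sorted, zero-based ids.

#include <cstddef>
#include <cstdio>
#include <vector>

// Sums entries of `data` that share the same id in `segment_ids`. Ids are
// assumed sorted and starting at 0, as tf.SegmentSum requires.
std::vector<float> SegmentSum(const std::vector<float>& data,
                              const std::vector<int>& segment_ids) {
  std::vector<float> out(segment_ids.empty() ? 0 : segment_ids.back() + 1, 0.0f);
  for (size_t i = 0; i < data.size(); ++i) out[segment_ids[i]] += data[i];
  return out;
}

int main() {
  // SegmentSum([1, 2, 3, 4], [0, 0, 1, 1]) == [3, 7].
  for (float v : SegmentSum({1.f, 2.f, 3.f, 4.f}, {0, 0, 1, 1})) std::printf("%g ", v);
  std::printf("\n");
  return 0;
}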

Some files were not shown because too many files have changed in this diff.