Merge branch 'master' into google_upstream_rocblas_complex
This commit is contained in:
commit
b64dde60e8
2
.bazelrc
2
.bazelrc
@ -279,7 +279,6 @@ build:windows --host_linkopt=/OPT:REF
|
||||
build:windows --linkopt=/OPT:ICF
|
||||
build:windows --host_linkopt=/OPT:ICF
|
||||
build:windows --experimental_strict_action_env=true
|
||||
build:windows --incompatible_windows_native_test_wrapper
|
||||
|
||||
# Verbose failure logs when something goes wrong
|
||||
build:windows --verbose_failures
|
||||
@ -344,6 +343,7 @@ build:rbe_linux --config=avx_linux
|
||||
build:rbe_linux --config=short_logs
|
||||
# TODO(gunan): Check why we need this specified in rbe, but not in other builds.
|
||||
build:rbe_linux --linkopt=-lrt
|
||||
build:rbe_linux --linkopt=-lm
|
||||
|
||||
build:rbe_cpu_linux --config=rbe_linux
|
||||
build:rbe_cpu_linux --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain"
|
||||
|
@ -1 +1 @@
|
||||
1.1.0
|
||||
1.2.1
|
||||
|
178
RELEASE.md
178
RELEASE.md
File diff suppressed because one or more lines are too long
38
WORKSPACE
38
WORKSPACE
@ -1,11 +1,13 @@
|
||||
workspace(name = "org_tensorflow")
|
||||
|
||||
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_file")
|
||||
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
||||
load("//third_party:repo.bzl", "tf_http_archive")
|
||||
|
||||
http_archive(
|
||||
tf_http_archive(
|
||||
name = "io_bazel_rules_closure",
|
||||
sha256 = "5b00383d08dd71f28503736db0500b6fb4dda47489ff5fc6bed42557c07c6ba9",
|
||||
strip_prefix = "rules_closure-308b05b2419edb5c8ee0471b67a40403df940149",
|
||||
patch_file = "@org_tensorflow//third_party:rules_closure.patch",
|
||||
urls = [
|
||||
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz",
|
||||
"https://github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz", # 2019-06-13
|
||||
@ -48,38 +50,6 @@ load("//third_party/toolchains/preconfig/generate:workspace.bzl",
|
||||
|
||||
remote_config_workspace()
|
||||
|
||||
# Apple and Swift rules.
|
||||
http_archive(
|
||||
name = "build_bazel_rules_apple",
|
||||
sha256 = "a045a436b642c70fb0c10ca84ff0fd2dcbd59cc89100d597a61e8374afafb366",
|
||||
urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.18.0/rules_apple.0.18.0.tar.gz"],
|
||||
) # https://github.com/bazelbuild/rules_apple/releases
|
||||
http_archive(
|
||||
name = "build_bazel_rules_swift",
|
||||
sha256 = "18cd4df4e410b0439a4935f9ca035bd979993d42372ba79e7f2d4fafe9596ef0",
|
||||
urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.12.1/rules_swift.0.12.1.tar.gz"],
|
||||
) # https://github.com/bazelbuild/rules_swift/releases
|
||||
http_archive(
|
||||
name = "build_bazel_apple_support",
|
||||
sha256 = "122ebf7fe7d1c8e938af6aeaee0efe788a3a2449ece5a8d6a428cb18d6f88033",
|
||||
urls = ["https://github.com/bazelbuild/apple_support/releases/download/0.7.1/apple_support.0.7.1.tar.gz"],
|
||||
) # https://github.com/bazelbuild/apple_support/releases
|
||||
http_archive(
|
||||
name = "bazel_skylib",
|
||||
sha256 = "1dde365491125a3db70731e25658dfdd3bc5dbdfd11b840b3e987ecf043c7ca0",
|
||||
urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/0.9.0/bazel-skylib.0.9.0.tar.gz"],
|
||||
) # https://github.com/bazelbuild/bazel-skylib/releases
|
||||
http_archive(
|
||||
name = "com_github_apple_swift_swift_protobuf",
|
||||
type = "zip",
|
||||
strip_prefix = "swift-protobuf-1.6.0/",
|
||||
urls = ["https://github.com/apple/swift-protobuf/archive/1.6.0.zip"],
|
||||
) # https://github.com/apple/swift-protobuf/releases
|
||||
http_file(
|
||||
name = "xctestrunner",
|
||||
executable = 1,
|
||||
urls = ["https://github.com/google/xctestrunner/releases/download/0.2.9/ios_test_runner.par"],
|
||||
) # https://github.com/google/xctestrunner/releases
|
||||
# Use `swift_rules_dependencies` to fetch the toolchains. With the
|
||||
# `git_repository` rules above, the following call will skip redefining them.
|
||||
load("@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies")
|
||||
|
@ -49,8 +49,8 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
|
||||
_TF_WORKSPACE_ROOT = ''
|
||||
_TF_BAZELRC = ''
|
||||
_TF_CURRENT_BAZEL_VERSION = None
|
||||
_TF_MIN_BAZEL_VERSION = '1.0.0'
|
||||
_TF_MAX_BAZEL_VERSION = '1.1.0'
|
||||
_TF_MIN_BAZEL_VERSION = '1.2.1'
|
||||
_TF_MAX_BAZEL_VERSION = '1.2.1'
|
||||
|
||||
NCCL_LIB_PATHS = [
|
||||
'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''
|
||||
|
@ -2,6 +2,7 @@
|
||||
# TensorFlow is a computational framework, primarily for use in machine
|
||||
# learning applications.
|
||||
|
||||
load("@bazel_skylib//lib:selects.bzl", "selects")
|
||||
load("//tensorflow:tensorflow.bzl", "VERSION", "tf_cc_shared_object", "tf_custom_op_library_additional_deps_impl", "tf_native_cc_binary")
|
||||
load(
|
||||
"//tensorflow/core/platform:build_config.bzl",
|
||||
@ -478,6 +479,7 @@ bzl_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/core/platform:build_config_root_bzl",
|
||||
"//tensorflow/core/platform:rules_cc_bzl",
|
||||
"//tensorflow/core/platform/default:cuda_build_defs_bzl",
|
||||
"//third_party/mkl:build_defs_bzl",
|
||||
"//third_party/mkl_dnn:build_defs_bzl",
|
||||
|
@ -23,10 +23,6 @@ from __future__ import print_function
|
||||
# pylint: disable=g-bad-import-order
|
||||
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
|
||||
|
||||
from tensorflow.python.util.lazy_loader import LazyLoader
|
||||
contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')
|
||||
del LazyLoader
|
||||
|
||||
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
|
||||
from tensorflow.python.platform import app # pylint: disable=g-import-not-at-top
|
||||
app.flags = flags
|
||||
|
@ -302,6 +302,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/common_runtime/eager:attr_builder",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
"//tensorflow/core/platform",
|
||||
"@com_google_absl//absl/strings",
|
||||
|
@ -458,7 +458,7 @@ static void TF_Run_Helper(
|
||||
EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
|
||||
continue;
|
||||
}
|
||||
c_outputs[i] = TF_TensorFromTensor(src, status);
|
||||
c_outputs[i] = TF_TensorFromTensor(src, &status->status);
|
||||
if (!status->status.ok()) return;
|
||||
}
|
||||
}
|
||||
@ -1493,7 +1493,7 @@ void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
|
||||
Tensor t;
|
||||
status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
|
||||
if (!status->status.ok()) return;
|
||||
*value = TF_TensorFromTensor(t, status);
|
||||
*value = TF_TensorFromTensor(t, &status->status);
|
||||
}
|
||||
|
||||
void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
|
||||
@ -1504,7 +1504,7 @@ void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
|
||||
if (!status->status.ok()) return;
|
||||
const auto len = std::min(max_values, static_cast<int>(ts.size()));
|
||||
for (int i = 0; i < len; ++i) {
|
||||
values[i] = TF_TensorFromTensor(ts[i], status);
|
||||
values[i] = TF_TensorFromTensor(ts[i], &status->status);
|
||||
}
|
||||
}
|
||||
|
||||
@ -2398,7 +2398,7 @@ unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output,
|
||||
graph->graph.versions().producer(), &evaluated, &result_tensor);
|
||||
if (evaluated) {
|
||||
DCHECK(status->status.ok());
|
||||
*result = TF_TensorFromTensor(result_tensor, status);
|
||||
*result = TF_TensorFromTensor(result_tensor, &status->status);
|
||||
if (!status->status.ok()) evaluated = false;
|
||||
}
|
||||
return evaluated;
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
@ -549,7 +550,7 @@ TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(TFE_Op* op,
|
||||
TF_Status* status) {
|
||||
TFE_ExecuteOpNotification* n = new TFE_ExecuteOpNotification;
|
||||
|
||||
n->thread.reset(op->operation.EagerContext()->TFEnv()->StartThread(
|
||||
n->thread.reset(op->operation.EagerContext().TFEnv()->StartThread(
|
||||
tensorflow::ThreadOptions(), "ExecuteOpThread",
|
||||
[op, retvals, num_retvals, n]() {
|
||||
TFE_Execute(op, retvals, num_retvals, n->status.get());
|
||||
@ -634,7 +635,7 @@ TF_Tensor* TF_CheckpointReaderGetTensor(TF_CheckpointReader* reader,
|
||||
std::unique_ptr<tensorflow::Tensor> tensor;
|
||||
reader->GetTensor(name, &tensor, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
return tensorflow::TF_TensorFromTensor(*tensor, status);
|
||||
return tensorflow::TF_TensorFromTensor(*tensor, &status->status);
|
||||
}
|
||||
|
||||
void TF_CheckpointReaderGetVariableShape(TF_CheckpointReader* reader,
|
||||
@ -767,8 +768,9 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
|
||||
} while (0);
|
||||
|
||||
// New server created for new server_def. Unused if updating server_def.
|
||||
tensorflow::EagerContext* context = ctx->context;
|
||||
tensorflow::GrpcServer* grpc_server =
|
||||
dynamic_cast<tensorflow::GrpcServer*>(ctx->context->GetServer());
|
||||
dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
|
||||
if (grpc_server == nullptr) {
|
||||
std::unique_ptr<tensorflow::ServerInterface> new_server;
|
||||
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
|
||||
@ -779,12 +781,12 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
|
||||
}
|
||||
LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
|
||||
|
||||
LOG_AND_RETURN_IF_ERROR(ctx->context->StoreCollectiveOpsServer(
|
||||
LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
|
||||
std::move(new_server), grpc_server->worker_env()->device_mgr,
|
||||
grpc_server->worker_env()->collective_executor_mgr));
|
||||
} else {
|
||||
LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
|
||||
LOG_AND_RETURN_IF_ERROR(ctx->context->StoreCollectiveOpsServer(
|
||||
LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
|
||||
/*new_server=*/nullptr, grpc_server->worker_env()->device_mgr,
|
||||
grpc_server->worker_env()->collective_executor_mgr));
|
||||
}
|
||||
|
@ -1260,11 +1260,10 @@ TEST_F(CApiFunctionTest, GraphToFunctionDefWithPlaceholderAttr) {
|
||||
NodeWithPlaceholderAttrHelper(func_graph.get(), s.get(), "node3", "v2",
|
||||
&node3);
|
||||
|
||||
TF_Output inputs[] = {};
|
||||
TF_Output outputs[] = {{node1, 0}, {node2, 0}, {node3, 0}};
|
||||
func_ = TF_GraphToFunction(
|
||||
func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
|
||||
/*opers=*/nullptr, 0, inputs, 3, outputs,
|
||||
/*opers=*/nullptr, 0, nullptr, 3, outputs,
|
||||
/*output_names=*/nullptr,
|
||||
/*opts=*/nullptr, /*description=*/nullptr, s.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
|
||||
@ -1300,10 +1299,9 @@ TEST_F(CApiFunctionTest, GraphToFunctionDefWithArgAttr) {
|
||||
&node);
|
||||
|
||||
TF_Output inputs[] = {{node, 0}};
|
||||
TF_Output outputs[] = {};
|
||||
func_ = TF_GraphToFunction(
|
||||
func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
|
||||
/*opers=*/nullptr, 1, inputs, 0, outputs,
|
||||
/*opers=*/nullptr, 1, inputs, 0, nullptr,
|
||||
/*output_names=*/nullptr,
|
||||
/*opts=*/nullptr, /*description=*/nullptr, s.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
|
||||
@ -1603,11 +1601,10 @@ void DefineStatefulFunction(const char* name, TF_Function** func) {
|
||||
TF_Operation* random =
|
||||
RandomUniform(shape, TF_FLOAT, func_graph.get(), s.get());
|
||||
|
||||
TF_Output inputs[] = {};
|
||||
TF_Output outputs[] = {{random, 0}};
|
||||
*func = TF_GraphToFunction(func_graph.get(), name,
|
||||
/*append_hash_to_fn_name=*/false, -1,
|
||||
/*opers=*/nullptr, 0, inputs, 1, outputs,
|
||||
/*opers=*/nullptr, 0, nullptr, 1, outputs,
|
||||
/*output_names=*/nullptr,
|
||||
/*opts=*/nullptr, "", s.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
|
||||
|
@ -188,7 +188,7 @@ namespace tensorflow {
|
||||
|
||||
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
|
||||
|
||||
TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
|
||||
TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
|
||||
|
||||
Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
|
||||
TF_Buffer* out);
|
||||
|
@ -51,7 +51,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/util/equal_graph_def.h"
|
||||
|
||||
namespace tensorflow {
|
||||
TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
|
||||
TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
|
||||
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
|
||||
|
||||
namespace {
|
||||
@ -227,7 +227,7 @@ TEST(CAPI, LibraryLoadFunctions) {
|
||||
|
||||
void TestEncodeDecode(int line, const std::vector<string>& data) {
|
||||
const tensorflow::int64 n = data.size();
|
||||
TF_Status* status = TF_NewStatus();
|
||||
Status status;
|
||||
for (const std::vector<tensorflow::int64>& dims :
|
||||
std::vector<std::vector<tensorflow::int64>>{
|
||||
{n}, {1, n}, {n, 1}, {n / 2, 2}}) {
|
||||
@ -236,8 +236,8 @@ void TestEncodeDecode(int line, const std::vector<string>& data) {
|
||||
for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) {
|
||||
src.flat<tstring>()(i) = data[i];
|
||||
}
|
||||
TF_Tensor* dst = TF_TensorFromTensor(src, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_Tensor* dst = TF_TensorFromTensor(src, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.error_message();
|
||||
|
||||
// Convert back to a C++ Tensor and ensure we get expected output.
|
||||
Tensor output;
|
||||
@ -249,7 +249,6 @@ void TestEncodeDecode(int line, const std::vector<string>& data) {
|
||||
|
||||
TF_DeleteTensor(dst);
|
||||
}
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
TEST(CAPI, TensorEncodeDecodeStrings) {
|
||||
@ -1394,8 +1393,9 @@ TEST(CAPI, SavedModel) {
|
||||
TF_Operation* input_op =
|
||||
TF_GraphOperationByName(graph, input_op_name.c_str());
|
||||
ASSERT_TRUE(input_op != nullptr);
|
||||
csession.SetInputs({{input_op, TF_TensorFromTensor(input, s)}});
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
Status status;
|
||||
csession.SetInputs({{input_op, TF_TensorFromTensor(input, &status)}});
|
||||
ASSERT_TRUE(status.ok()) << status.error_message();
|
||||
|
||||
const tensorflow::string output_op_name(
|
||||
tensorflow::ParseTensorName(output_name).first);
|
||||
@ -2522,12 +2522,11 @@ TEST(CAPI, TestTensorIsNotAligned) {
|
||||
|
||||
// Take an unaligned slice.
|
||||
Tensor y = x.Slice(1, 13);
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_Tensor* a = TF_TensorFromTensor(y, status);
|
||||
Status status;
|
||||
TF_Tensor* a = TF_TensorFromTensor(y, &status);
|
||||
if (EIGEN_MAX_ALIGN_BYTES > 0) {
|
||||
EXPECT_FALSE(TF_TensorIsAligned(a));
|
||||
}
|
||||
TF_DeleteStatus(status);
|
||||
TF_DeleteTensor(a);
|
||||
}
|
||||
|
||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||
#include <memory.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <sys/time.h>
|
||||
#include <time.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
@ -58,12 +58,8 @@ int main(int argc, char** argv) {
|
||||
}
|
||||
|
||||
char file_name[100];
|
||||
struct timeval t;
|
||||
if (gettimeofday(&t, NULL)) {
|
||||
perror("gettimeofday failed");
|
||||
return 1;
|
||||
}
|
||||
snprintf(file_name, sizeof(file_name), "test-%d-%ld.txt", getpid(), t.tv_sec);
|
||||
time_t t = time(NULL);
|
||||
snprintf(file_name, sizeof(file_name), "test-%d-%ld.txt", getpid(), t);
|
||||
|
||||
size_t length = 2 + strlen(path) + strlen(file_name);
|
||||
char* full_path = malloc(length);
|
||||
|
@ -26,8 +26,8 @@ tf_cuda_library(
|
||||
"c_api.cc",
|
||||
"c_api_debug.cc",
|
||||
"c_api_experimental.h",
|
||||
"c_api_internal.cc",
|
||||
"c_api_internal.h",
|
||||
"tensor_handle_interface.h",
|
||||
],
|
||||
hdrs = ["c_api.h"],
|
||||
copts = tf_copts() + tfe_xla_copts(),
|
||||
@ -93,6 +93,7 @@ filegroup(
|
||||
srcs = [
|
||||
"c_api_experimental.h",
|
||||
"c_api_internal.h",
|
||||
"tensor_handle_interface.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/core:__pkg__",
|
||||
@ -102,7 +103,10 @@ filegroup(
|
||||
|
||||
tf_cuda_library(
|
||||
name = "c_api_internal",
|
||||
srcs = ["c_api_experimental.h"],
|
||||
srcs = [
|
||||
"c_api_experimental.h",
|
||||
"tensor_handle_interface.h",
|
||||
],
|
||||
hdrs = ["c_api_internal.h"],
|
||||
visibility = [
|
||||
"//learning/deepmind/courier:__subpackages__",
|
||||
|
@ -31,6 +31,7 @@ limitations under the License.
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
@ -81,6 +82,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
|
||||
#include "tensorflow/core/lib/random/random.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
@ -93,10 +95,8 @@ using tensorflow::string;
|
||||
namespace {
|
||||
|
||||
const tensorflow::OpDef* GetOpDef(TFE_Op* op, TF_Status* status) {
|
||||
if (op->inference_ctx) {
|
||||
return op->inference_ctx->op_def;
|
||||
}
|
||||
const tensorflow::OpDef* op_def;
|
||||
const tensorflow::OpDef* op_def = op->operation.OpDef();
|
||||
if (op_def) return op_def;
|
||||
status->status =
|
||||
tensorflow::OpDefForOp(op->operation.Name().c_str(), &op_def);
|
||||
return op_def;
|
||||
@ -409,6 +409,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
|
||||
// New server created for new server_def. Unused if updating server_def.
|
||||
std::unique_ptr<tensorflow::ServerInterface> new_server;
|
||||
tensorflow::EagerContext* context = ctx->context;
|
||||
tensorflow::GrpcServer* grpc_server;
|
||||
if (reset_context) {
|
||||
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
|
||||
@ -416,26 +417,25 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
|
||||
} else {
|
||||
LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(
|
||||
ctx->context->GetServer(), worker_name, &curr_remote_workers));
|
||||
LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name,
|
||||
&curr_remote_workers));
|
||||
// No need to check the cast here, since `ListRemoteWorkers` already checks
|
||||
// if the server is a GRPC server or not.
|
||||
grpc_server =
|
||||
dynamic_cast<tensorflow::GrpcServer*>(ctx->context->GetServer());
|
||||
grpc_server = dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
|
||||
LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
|
||||
}
|
||||
|
||||
tensorflow::uint64 context_id = ctx->context->GetContextId();
|
||||
tensorflow::uint64 context_view_id = ctx->context->GetContextViewId();
|
||||
tensorflow::uint64 context_id = context->GetContextId();
|
||||
tensorflow::uint64 context_view_id = context->GetContextViewId();
|
||||
if (reset_context) {
|
||||
context_id = tensorflow::EagerContext::NewContextId();
|
||||
context_view_id = 0;
|
||||
// Make master eager context accessible by local eager service, which might
|
||||
// receive send tensor requests from remote workers.
|
||||
LOG_AND_RETURN_IF_ERROR(grpc_server->AddMasterEagerContextToEagerService(
|
||||
context_id, ctx->context));
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
grpc_server->AddMasterEagerContextToEagerService(context_id, context));
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
|
||||
@ -464,11 +464,11 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
&new_remote_device_mgr));
|
||||
remote_device_mgr = new_remote_device_mgr.get();
|
||||
} else {
|
||||
ctx->context->ClearCachesAndDefaultExecutor();
|
||||
context->ClearCachesAndDefaultExecutor();
|
||||
// TODO(b/143914772): Potential memory leak if rendezvous has pending
|
||||
// tensors for removed / replaced workers.
|
||||
|
||||
remote_device_mgr = ctx->context->GetOwnedRemoteDeviceMgr();
|
||||
remote_device_mgr = context->GetOwnedRemoteDeviceMgr();
|
||||
if (remote_device_mgr == nullptr) {
|
||||
LOG_AND_RETURN_IF_ERROR(tensorflow::errors::InvalidArgument(
|
||||
"Updating context with an invalid set of remote devices."));
|
||||
@ -479,8 +479,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
&added_workers, &removed_workers,
|
||||
&existing_workers);
|
||||
LOG_AND_RETURN_IF_ERROR(GetReplacedFromExistingWorkers(
|
||||
&existing_workers, context_id, ctx->context->GetContextViewId(),
|
||||
server_def, remote_eager_workers.get(), &replaced_workers));
|
||||
&existing_workers, context_id, context->GetContextViewId(), server_def,
|
||||
remote_eager_workers.get(), &replaced_workers));
|
||||
if (VLOG_IS_ON(1)) {
|
||||
VLOG(1) << "Updating cluster with following changes";
|
||||
for (const string& w : added_workers) VLOG(1) << " Added worker " << w;
|
||||
@ -516,7 +516,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
grpc_server->worker_env()->device_mgr->ListDeviceAttributes(
|
||||
&local_device_attributes);
|
||||
|
||||
// This request make sure that we can create Rendevzous properly between
|
||||
// This request make sure that we can create Rendezvous properly between
|
||||
// Local and Remote context.
|
||||
tensorflow::eager::CreateContextRequest base_request;
|
||||
for (const auto& da : cluster_device_attributes) {
|
||||
@ -534,9 +534,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
if (reset_context) {
|
||||
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
|
||||
remote_workers, context_id, context_view_id, keep_alive_secs,
|
||||
server_def, remote_eager_workers.get(),
|
||||
ctx->context->Executor().Async(),
|
||||
ctx->context->LazyCopyFunctionRemoteInputs(), base_request));
|
||||
server_def, remote_eager_workers.get(), context->Executor().Async(),
|
||||
context->LazyCopyFunctionRemoteInputs(), base_request));
|
||||
} else {
|
||||
// The master's context_view_id will be incremented by one
|
||||
// the UpdateRemoteMaster call later. We want all new workers and
|
||||
@ -545,9 +544,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
// context_view_id + 1.
|
||||
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
|
||||
added_workers, context_id, context_view_id + 1, keep_alive_secs,
|
||||
server_def, remote_eager_workers.get(),
|
||||
ctx->context->Executor().Async(),
|
||||
ctx->context->LazyCopyFunctionRemoteInputs(), base_request));
|
||||
server_def, remote_eager_workers.get(), context->Executor().Async(),
|
||||
context->LazyCopyFunctionRemoteInputs(), base_request));
|
||||
if (!existing_workers.empty()) {
|
||||
if (VLOG_IS_ON(1)) {
|
||||
for (const string& w : existing_workers) {
|
||||
@ -578,12 +576,12 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
|
||||
|
||||
tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
|
||||
tensorflow::eager::CreateClusterFLR(context_id, ctx->context,
|
||||
tensorflow::eager::CreateClusterFLR(context_id, context,
|
||||
worker_session.get());
|
||||
auto remote_mgr = absl::make_unique<tensorflow::eager::RemoteMgr>(
|
||||
/*is_master=*/true, ctx->context);
|
||||
/*is_master=*/true, context);
|
||||
|
||||
LOG_AND_RETURN_IF_ERROR(ctx->context->InitializeRemoteMaster(
|
||||
LOG_AND_RETURN_IF_ERROR(context->InitializeRemoteMaster(
|
||||
std::move(new_server), grpc_server->worker_env(), worker_session,
|
||||
std::move(remote_eager_workers), std::move(new_remote_device_mgr),
|
||||
remote_workers, context_id, r, device_mgr, keep_alive_secs, cluster_flr,
|
||||
@ -601,9 +599,9 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
|
||||
session_name, &worker_session));
|
||||
tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
|
||||
tensorflow::eager::CreateClusterFLR(context_id, ctx->context,
|
||||
tensorflow::eager::CreateClusterFLR(context_id, context,
|
||||
worker_session.get());
|
||||
LOG_AND_RETURN_IF_ERROR(ctx->context->UpdateRemoteMaster(
|
||||
LOG_AND_RETURN_IF_ERROR(context->UpdateRemoteMaster(
|
||||
grpc_server->worker_env(), std::move(remote_eager_workers),
|
||||
added_workers, removed_workers, context_id, r, device_mgr,
|
||||
keep_alive_secs, cluster_flr));
|
||||
@ -614,77 +612,6 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
}
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
|
||||
tensorflow::Status OpInferSingleInputAttrs(TFE_Op* op,
|
||||
TFE_TensorHandle* input) {
|
||||
TFE_OpInferenceContext* ictx = op->inference_ctx.get();
|
||||
const auto& input_def = ictx->op_def->input_arg(ictx->input_arg_idx++);
|
||||
if (!input_def.number_attr().empty() || !input_def.type_list_attr().empty()) {
|
||||
// Some clients that are still setting their input attributes manually are
|
||||
// adding input list to their op by calling `TFE_OpAddInput` for each of
|
||||
// its elements instead of calling `TFE_OpAddInputList`. When this happens,
|
||||
// we cannot detect the end of such list, thus lose track of the input
|
||||
// arguments in the op definition. To guarantee backward compatibility with
|
||||
// those clients, disable automatic inference in this case.
|
||||
op->inference_ctx.reset(nullptr);
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
const std::string& type_attr = input_def.type_attr();
|
||||
if (!type_attr.empty() && ictx->attrs.find(type_attr) == ictx->attrs.end()) {
|
||||
op->operation.MutableAttrs()->Set(type_attr, input->handle->dtype);
|
||||
ictx->attrs.insert(type_attr);
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
void OpInferSingleTypeInputListAttrs(TFE_Op* op,
|
||||
const tensorflow::OpDef::ArgDef& input_def,
|
||||
TFE_TensorHandle** inputs,
|
||||
int num_inputs) {
|
||||
TFE_OpInferenceContext* ictx = op->inference_ctx.get();
|
||||
if (ictx->attrs.find(input_def.number_attr()) == ictx->attrs.end()) {
|
||||
op->operation.MutableAttrs()->Set(input_def.number_attr(), num_inputs);
|
||||
ictx->attrs.insert(input_def.number_attr());
|
||||
}
|
||||
if (ictx->attrs.find(input_def.type_attr()) == ictx->attrs.end()) {
|
||||
op->operation.MutableAttrs()->Set(input_def.type_attr(),
|
||||
inputs[0]->handle->dtype);
|
||||
ictx->attrs.insert(input_def.type_attr());
|
||||
}
|
||||
}
|
||||
|
||||
void OpInferMixedTypeInputListAttrs(TFE_Op* op,
|
||||
const tensorflow::OpDef::ArgDef& input_def,
|
||||
TFE_TensorHandle** inputs, int num_inputs) {
|
||||
TFE_OpInferenceContext* ictx = op->inference_ctx.get();
|
||||
if (ictx->attrs.find(input_def.type_list_attr()) == ictx->attrs.end()) {
|
||||
std::unique_ptr<tensorflow::DataType[]> dtypes(
|
||||
new tensorflow::DataType[num_inputs]);
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
dtypes[i] = inputs[i]->handle->dtype;
|
||||
}
|
||||
op->operation.MutableAttrs()->Set(
|
||||
input_def.type_list_attr(),
|
||||
tensorflow::gtl::ArraySlice<const tensorflow::DataType>(dtypes.get(),
|
||||
num_inputs));
|
||||
ictx->attrs.insert(input_def.type_list_attr());
|
||||
}
|
||||
}
|
||||
|
||||
tensorflow::Status OpInferInputListAttrs(TFE_Op* op, TFE_TensorHandle** inputs,
|
||||
int num_inputs) {
|
||||
TFE_OpInferenceContext* ictx = op->inference_ctx.get();
|
||||
const auto& input_def = ictx->op_def->input_arg(ictx->input_arg_idx++);
|
||||
if (!input_def.type_list_attr().empty()) {
|
||||
OpInferMixedTypeInputListAttrs(op, input_def, inputs, num_inputs);
|
||||
} else if (!input_def.type_attr().empty() &&
|
||||
!input_def.number_attr().empty()) {
|
||||
OpInferSingleTypeInputListAttrs(op, input_def, inputs, num_inputs);
|
||||
} else {
|
||||
return tensorflow::errors::InvalidArgument("Invalid input list definition");
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
@ -720,12 +647,14 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
||||
tensorflow::Rendezvous* r =
|
||||
new tensorflow::IntraProcessRendezvous(device_mgr.get());
|
||||
|
||||
return new TFE_Context(opts->session_options.options,
|
||||
opts->device_placement_policy, opts->mirroring_policy,
|
||||
opts->async, opts->lazy_remote_inputs_copy,
|
||||
device_mgr.release(),
|
||||
/*device_mgr_owned*/ true, r,
|
||||
tensorflow::GetDefaultCustomKernelCreator());
|
||||
return new TFE_Context{new tensorflow::EagerContext(
|
||||
opts->session_options.options,
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
||||
opts->device_placement_policy),
|
||||
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
|
||||
opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
|
||||
/*device_mgr_owned*/ true, r,
|
||||
tensorflow::GetDefaultCustomKernelCreator())};
|
||||
}
|
||||
|
||||
TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
|
||||
@ -736,22 +665,28 @@ TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
|
||||
tensorflow::Rendezvous* r =
|
||||
new tensorflow::IntraProcessRendezvous(device_mgr);
|
||||
|
||||
return new TFE_Context(opts->session_options.options,
|
||||
opts->device_placement_policy, opts->mirroring_policy,
|
||||
opts->async, opts->lazy_remote_inputs_copy, device_mgr,
|
||||
/*device_mgr_owned*/ false, r,
|
||||
tensorflow::GetDefaultCustomKernelCreator());
|
||||
return new TFE_Context{new tensorflow::EagerContext(
|
||||
opts->session_options.options,
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
||||
opts->device_placement_policy),
|
||||
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
|
||||
opts->async, opts->lazy_remote_inputs_copy, device_mgr,
|
||||
/*device_mgr_owned*/ false, r,
|
||||
tensorflow::GetDefaultCustomKernelCreator())};
|
||||
}
|
||||
|
||||
void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; }
|
||||
void TFE_DeleteContext(TFE_Context* ctx) {
|
||||
// context->RefCountIsOne() should be true here.
|
||||
// TODO(iga): Remove EagerContext refcounting.
|
||||
ctx->context->Unref();
|
||||
|
||||
delete ctx;
|
||||
}
|
||||
|
||||
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
|
||||
TF_DeviceList* list = new TF_DeviceList;
|
||||
ctx->context->local_device_mgr()->ListDeviceAttributes(&list->response);
|
||||
if (ctx->context->remote_device_mgr()) {
|
||||
ctx->context->remote_device_mgr()->ListDeviceAttributes(&list->response);
|
||||
}
|
||||
return list;
|
||||
TF_DeviceList* l = new TF_DeviceList;
|
||||
ctx->context->ListDevices(&l->response);
|
||||
return l;
|
||||
}
|
||||
|
||||
void TFE_ContextClearCaches(TFE_Context* ctx) {
|
||||
@ -812,8 +747,9 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
|
||||
"TFE_ContextSetServerDef not supported on mobile");
|
||||
return false;
|
||||
#else // !defined(IS_MOBILE_PLATFORM)
|
||||
tensorflow::EagerContext* context = ctx->context;
|
||||
tensorflow::GrpcServer* grpc_server =
|
||||
static_cast<tensorflow::GrpcServer*>(ctx->context->GetServer());
|
||||
static_cast<tensorflow::GrpcServer*>(context->GetServer());
|
||||
|
||||
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
|
||||
status->status = grpc_server->master_env()->worker_cache->GetEagerClientCache(
|
||||
@ -832,7 +768,7 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
|
||||
|
||||
// Send a rpc request to the worker to check aliveness.
|
||||
tensorflow::eager::KeepAliveRequest request;
|
||||
request.set_context_id(ctx->context->GetContextId());
|
||||
request.set_context_id(context->GetContextId());
|
||||
tensorflow::eager::KeepAliveResponse response;
|
||||
|
||||
tensorflow::Status keep_alive_status;
|
||||
@ -887,108 +823,180 @@ void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
|
||||
if (h == nullptr) return;
|
||||
tensorflow::profiler::TraceMe activity(
|
||||
"TFE_DeleteTensorHandle", tensorflow::profiler::TraceMeLevel::kInfo);
|
||||
VLOG(1) << "Deleting tensor handle " << h << " with internal handle "
|
||||
<< h->handle;
|
||||
if (h->handle) {
|
||||
h->handle->Unref();
|
||||
}
|
||||
delete h;
|
||||
}
|
||||
|
||||
tensorflow::TensorHandleInterface::~TensorHandleInterface() {
|
||||
VLOG(1) << "Deleting tensor handle " << this << " with internal handle "
|
||||
<< handle_;
|
||||
if (handle_) {
|
||||
handle_->Unref();
|
||||
}
|
||||
}
|
||||
|
||||
bool tensorflow::TensorHandleInterface::IsValid(Status* status) const {
|
||||
if (handle_ == nullptr) {
|
||||
*status = tensorflow::errors::InvalidArgument(
|
||||
"The passed in handle is a nullptr");
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
|
||||
return static_cast<TF_DataType>(h->handle->dtype);
|
||||
return h->handle->DataType();
|
||||
}
|
||||
|
||||
TF_DataType tensorflow::TensorHandleInterface::DataType() const {
|
||||
return static_cast<TF_DataType>(handle_->dtype);
|
||||
}
|
||||
|
||||
int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"The passed in handle is a nullptr");
|
||||
return -1;
|
||||
}
|
||||
|
||||
return h->handle->NumDims(&status->status);
|
||||
}
|
||||
|
||||
int tensorflow::TensorHandleInterface::NumDims(Status* status) const {
|
||||
if (!IsValid(status)) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
int result;
|
||||
status->status = h->handle->NumDims(&result);
|
||||
*status = handle_->NumDims(&result);
|
||||
return result;
|
||||
}
|
||||
|
||||
int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"The passed in handle is a nullptr");
|
||||
return -1;
|
||||
}
|
||||
|
||||
return h->handle->NumElements(&status->status);
|
||||
}
|
||||
|
||||
int64_t tensorflow::TensorHandleInterface::NumElements(Status* status) const {
|
||||
if (!IsValid(status)) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
tensorflow::int64 result;
|
||||
status->status = h->handle->NumElements(&result);
|
||||
*status = handle_->NumElements(&result);
|
||||
return result;
|
||||
}
|
||||
|
||||
int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
|
||||
TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"The passed in handle is a nullptr");
|
||||
return -1;
|
||||
}
|
||||
|
||||
return h->handle->Dim(dim_index, &status->status);
|
||||
}
|
||||
|
||||
int64_t tensorflow::TensorHandleInterface::Dim(int dim_index,
|
||||
Status* status) const {
|
||||
if (!IsValid(status)) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
tensorflow::int64 result;
|
||||
status->status = h->handle->Dim(dim_index, &result);
|
||||
*status = handle_->Dim(dim_index, &result);
|
||||
return result;
|
||||
}
|
||||
|
||||
const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"The passed in handle is a nullptr");
|
||||
return nullptr;
|
||||
}
|
||||
tensorflow::Device* d = h->handle->op_device();
|
||||
return h->handle->DeviceName(&status->status);
|
||||
}
|
||||
|
||||
const char* tensorflow::TensorHandleInterface::DeviceName(
|
||||
Status* status) const {
|
||||
if (!IsValid(status)) {
|
||||
return nullptr;
|
||||
}
|
||||
tensorflow::Device* d = handle_->op_device();
|
||||
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
|
||||
: d->name().c_str();
|
||||
}
|
||||
|
||||
const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
|
||||
TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"The passed in handle is a nullptr");
|
||||
return nullptr;
|
||||
}
|
||||
tensorflow::Device* d = h->handle->device();
|
||||
return h->handle->BackingDeviceName(&status->status);
|
||||
}
|
||||
|
||||
const char* tensorflow::TensorHandleInterface::BackingDeviceName(
|
||||
Status* status) const {
|
||||
if (!IsValid(status)) {
|
||||
return nullptr;
|
||||
}
|
||||
tensorflow::Device* d = handle_->device();
|
||||
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
|
||||
: d->name().c_str();
|
||||
}
|
||||
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
|
||||
TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr || !h->handle->IsValid(&status->status)) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"The passed in handle is a nullptr");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
h->handle->Ref();
|
||||
return new TFE_TensorHandle{
|
||||
std::unique_ptr<AbstractTensorHandleInterface>(h->handle->Copy())};
|
||||
}
|
||||
|
||||
return new TFE_TensorHandle(h->handle);
|
||||
AbstractTensorHandleInterface* tensorflow::TensorHandleInterface::Copy() {
|
||||
handle_->Ref();
|
||||
return new TensorHandleInterface(handle_);
|
||||
}
|
||||
|
||||
TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"The passed in handle is a nullptr");
|
||||
return nullptr;
|
||||
}
|
||||
tensorflow::TensorHandle* handle = h->handle;
|
||||
|
||||
return h->handle->Resolve(&status->status);
|
||||
}
|
||||
|
||||
TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
|
||||
if (!IsValid(status)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// TODO(agarwal): move this implementation inside TFE_TensorHandle.
|
||||
if (handle->IsRemote()) {
|
||||
if (handle_->IsRemote()) {
|
||||
const tensorflow::Tensor* t = nullptr;
|
||||
tensorflow::TensorHandle* h_cpu = nullptr;
|
||||
status->status = EagerCopyToDevice(
|
||||
handle, handle->Context(), &handle->Context()->Executor(),
|
||||
handle->Context()->HostCPU(), false, &h_cpu);
|
||||
if (!status->status.ok()) {
|
||||
*status = EagerCopyToDevice(handle_, handle_->Context(),
|
||||
&handle_->Context()->Executor(),
|
||||
handle_->Context()->HostCPU(), false, &h_cpu);
|
||||
if (!status->ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
status->status = h_cpu->Tensor(&t);
|
||||
if (!status->status.ok()) {
|
||||
*status = h_cpu->Tensor(&t);
|
||||
if (!status->ok()) {
|
||||
h_cpu->Unref();
|
||||
return nullptr;
|
||||
}
|
||||
@ -997,28 +1005,30 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
|
||||
return retval;
|
||||
} else {
|
||||
tensorflow::Tensor tensor;
|
||||
if (IsCPU(handle->device())) {
|
||||
if (IsCPU(handle_->device())) {
|
||||
const tensorflow::Tensor* src = nullptr;
|
||||
status->status = handle->Tensor(&src);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
*status = handle_->Tensor(&src);
|
||||
if (!status->ok()) return nullptr;
|
||||
tensor = *src;
|
||||
} else {
|
||||
tensorflow::EagerContext* ctx = handle->Context();
|
||||
tensorflow::EagerContext* ctx = handle_->Context();
|
||||
CHECK_NE(ctx, nullptr);
|
||||
status->status = h->handle->CopyToDevice(ctx, ctx->HostCPU(), &tensor);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
*status = handle_->CopyToDevice(*ctx, ctx->HostCPU(), &tensor);
|
||||
if (!status->ok()) return nullptr;
|
||||
}
|
||||
return tensorflow::TF_TensorFromTensor(tensor, status);
|
||||
}
|
||||
}
|
||||
|
||||
void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr || !h->handle->IsValid(&status->status)) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"The passed in handle is a nullptr");
|
||||
return nullptr;
|
||||
}
|
||||
tensorflow::TensorHandle* handle = h->handle;
|
||||
tensorflow::TensorHandle* handle =
|
||||
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
|
||||
->Handle();
|
||||
|
||||
if (handle->IsRemote()) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
@ -1047,7 +1057,8 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
||||
void (*deallocator)(void* data, size_t len, void* arg),
|
||||
void* deallocator_arg, TF_Status* status) {
|
||||
tensorflow::Device* device;
|
||||
status->status = ctx->context->FindDeviceFromName(device_name, &device);
|
||||
tensorflow::EagerContext* context = ctx->context;
|
||||
status->status = context->FindDeviceFromName(device_name, &device);
|
||||
if (!status->status.ok()) {
|
||||
deallocator(data, len, deallocator_arg);
|
||||
return nullptr;
|
||||
@ -1075,11 +1086,12 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
||||
buf->Unref();
|
||||
tensorflow::TensorHandle* ret_handle;
|
||||
status->status = tensorflow::TensorHandle::CreateLocalHandle(
|
||||
t, device, ctx->context, &ret_handle);
|
||||
t, device, context, &ret_handle);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return new TFE_TensorHandle(ret_handle);
|
||||
return new TFE_TensorHandle{
|
||||
std::make_unique<tensorflow::TensorHandleInterface>(ret_handle)};
|
||||
}
|
||||
|
||||
// This function will block till the operation that produces `h` has
|
||||
@ -1087,12 +1099,14 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
||||
// bytes of the memory pointed to by the device pointer returned above.
|
||||
size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
|
||||
TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr || !h->handle->IsValid(&status->status)) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"The passed in handle is a nullptr");
|
||||
return 0;
|
||||
}
|
||||
tensorflow::TensorHandle* handle = h->handle;
|
||||
tensorflow::TensorHandle* handle =
|
||||
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
|
||||
->Handle();
|
||||
|
||||
if (handle->IsRemote()) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
@ -1110,8 +1124,14 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
|
||||
|
||||
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
TF_Status* status) {
|
||||
return NewOrResetOp(ctx, op_or_function_name, nullptr, status,
|
||||
/* op_to_reset= */ nullptr);
|
||||
std::unique_ptr<TFE_Op> new_op(
|
||||
new TFE_Op{tensorflow::EagerOperation(ctx->context)});
|
||||
status->status =
|
||||
new_op->operation.Reset(op_or_function_name, nullptr, false, nullptr);
|
||||
if (!status->status.ok()) {
|
||||
new_op.reset();
|
||||
}
|
||||
return new_op.release();
|
||||
}
|
||||
|
||||
void TFE_DeleteOp(TFE_Op* op) { delete op; }
|
||||
@ -1122,7 +1142,7 @@ void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
|
||||
|
||||
const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
|
||||
tensorflow::Device* device = (op->operation.Device() == nullptr)
|
||||
? op->operation.EagerContext()->HostCPU()
|
||||
? op->operation.EagerContext().HostCPU()
|
||||
: op->operation.Device();
|
||||
return device->name().c_str();
|
||||
}
|
||||
@ -1136,20 +1156,23 @@ void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
|
||||
}
|
||||
|
||||
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
|
||||
op->operation.AddInput(input->handle);
|
||||
if (op->inference_ctx) {
|
||||
status->status = OpInferSingleInputAttrs(op, input);
|
||||
}
|
||||
tensorflow::TensorHandle* h =
|
||||
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||
input->handle.get())
|
||||
->Handle();
|
||||
op->operation.AddInput(h);
|
||||
status->status = op->operation.MaybeInferSingleInputAttrs(h);
|
||||
}
|
||||
|
||||
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
|
||||
TF_Status* status) {
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
op->operation.AddInput(inputs[i]->handle);
|
||||
}
|
||||
if (op->inference_ctx) {
|
||||
status->status = OpInferInputListAttrs(op, inputs, num_inputs);
|
||||
op->operation.AddInput(
|
||||
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||
inputs[i]->handle.get())
|
||||
->Handle());
|
||||
}
|
||||
status->status = op->operation.InferInputListAttrs(num_inputs);
|
||||
}
|
||||
|
||||
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
|
||||
@ -1382,15 +1405,16 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
|
||||
|
||||
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
|
||||
TF_Status* status) {
|
||||
VLOG(1) << "Calling TFE_Execute() on op " << op;
|
||||
absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
|
||||
VLOG(1) << "Calling TFE_Execute() on op " << op;
|
||||
status->status = tensorflow::EagerExecute(&op->operation,
|
||||
handle_retvals.data(), num_retvals);
|
||||
if (!status->status.ok()) {
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < *num_retvals; ++i) {
|
||||
retvals[i] = new TFE_TensorHandle(handle_retvals[i]);
|
||||
retvals[i] = new TFE_TensorHandle{
|
||||
std::make_unique<tensorflow::TensorHandleInterface>(handle_retvals[i])};
|
||||
}
|
||||
}
|
||||
|
||||
@ -1400,15 +1424,18 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
|
||||
TF_Status* status) {
|
||||
tensorflow::TensorHandle* handle = nullptr;
|
||||
tensorflow::Device* device;
|
||||
status->status = ctx->context->FindDeviceFromName(device_name, &device);
|
||||
tensorflow::EagerContext* context = ctx->context;
|
||||
status->status = context->FindDeviceFromName(device_name, &device);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context,
|
||||
&ctx->context->Executor(),
|
||||
device, false, &handle);
|
||||
status->status = tensorflow::EagerCopyToDevice(
|
||||
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
|
||||
->Handle(),
|
||||
context, &context->Executor(), device, false, &handle);
|
||||
if (status->status.ok()) {
|
||||
return new TFE_TensorHandle(handle);
|
||||
return new TFE_TensorHandle{
|
||||
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
@ -1456,11 +1483,12 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
|
||||
|
||||
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
|
||||
TF_Status* status) {
|
||||
status->status = ctx->context->Executor().WaitForAllPendingNodes();
|
||||
tensorflow::EagerContext* context = ctx->context;
|
||||
status->status = context->Executor().WaitForAllPendingNodes();
|
||||
if (!status->status.ok()) return;
|
||||
tensorflow::mutex_lock ml(*ctx->context->MetadataMu());
|
||||
status->status = MessageToBuffer(*ctx->context->RunMetadataProto(), buf);
|
||||
ctx->context->ClearRunMetadata();
|
||||
tensorflow::mutex_lock ml(*context->MetadataMu());
|
||||
status->status = MessageToBuffer(*context->RunMetadataProto(), buf);
|
||||
context->ClearRunMetadata();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -206,14 +206,14 @@ typedef struct TFE_TensorDebugInfo TFE_TensorDebugInfo;
|
||||
// error and nullptr is returned. This function can block till the operation
|
||||
// that produces `handle` has completed.
|
||||
TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
|
||||
TFE_TensorHandle* handle, TF_Status* status);
|
||||
TFE_TensorHandle* h, TF_Status* status);
|
||||
|
||||
// Deletes `debug_info`.
|
||||
TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo(
|
||||
TFE_TensorDebugInfo* debug_info);
|
||||
|
||||
// Returns the number of dimensions used to represent the tensor on its device.
|
||||
// The number of dimensions used to reprensent the tensor on device can be
|
||||
// The number of dimensions used to represent the tensor on device can be
|
||||
// different from the number returned by TFE_TensorHandleNumDims.
|
||||
// The return value was current at the time of TFE_TensorDebugInfo creation.
|
||||
TF_CAPI_EXPORT extern int TFE_TensorDebugInfoOnDeviceNumDims(
|
||||
|
@ -28,19 +28,22 @@ using tensorflow::string;
|
||||
|
||||
namespace {
|
||||
|
||||
std::vector<int64> TensorShapeAsVector(TFE_TensorHandle* handle,
|
||||
TF_Status* status) {
|
||||
std::vector<int64> TensorShapeAsVector(const tensorflow::TensorHandle& handle,
|
||||
tensorflow::Status* status) {
|
||||
std::vector<int64> shape;
|
||||
int rank = TFE_TensorHandleNumDims(handle, status);
|
||||
if (TF_GetCode(status) != TF_OK) {
|
||||
int rank = -1;
|
||||
*status = handle.NumDims(&rank);
|
||||
if (!status->ok()) {
|
||||
return shape;
|
||||
}
|
||||
shape.reserve(rank);
|
||||
for (int i = 0; i < rank; ++i) {
|
||||
shape.push_back(TFE_TensorHandleDim(handle, i, status));
|
||||
if (TF_GetCode(status) != TF_OK) {
|
||||
tensorflow::int64 dim;
|
||||
*status = handle.Dim(i, &dim);
|
||||
if (!status->ok()) {
|
||||
return shape;
|
||||
}
|
||||
shape.push_back(dim);
|
||||
}
|
||||
return shape;
|
||||
}
|
||||
@ -50,15 +53,20 @@ std::vector<int64> TensorShapeAsVector(TFE_TensorHandle* handle,
|
||||
extern "C" {
|
||||
|
||||
TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
|
||||
TFE_TensorHandle* handle, TF_Status* status) {
|
||||
TFE_TensorHandle* h, TF_Status* status) {
|
||||
return h->handle->TensorDebugInfo(&status->status);
|
||||
}
|
||||
|
||||
TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
|
||||
Status* status) {
|
||||
const tensorflow::Tensor* tensor;
|
||||
status->status = handle->handle->Tensor(&tensor);
|
||||
if (TF_GetCode(status) != TF_OK) {
|
||||
*status = handle_->Tensor(&tensor);
|
||||
if (!status->ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
#ifdef TENSORFLOW_EAGER_USE_XLA
|
||||
tensorflow::Device* device = handle->handle->device();
|
||||
tensorflow::Device* device = handle_->device();
|
||||
|
||||
// If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
|
||||
tensorflow::XlaDevice* xla_device =
|
||||
@ -67,15 +75,15 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
|
||||
tensorflow::XlaDevice::PaddedShapeFn shape_fn =
|
||||
xla_device->metadata().padded_shape_fn();
|
||||
xla::Shape padded_shape;
|
||||
status->status = shape_fn(*tensor, &padded_shape);
|
||||
if (!status->status.ok()) {
|
||||
*status = shape_fn(*tensor, &padded_shape);
|
||||
if (!status->ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
if (VLOG_IS_ON(3)) {
|
||||
std::vector<int64> shape_to_log = TensorShapeAsVector(handle, status);
|
||||
if (!status->status.ok()) {
|
||||
std::vector<int64> shape_to_log = TensorShapeAsVector(*handle_, status);
|
||||
if (!status->ok()) {
|
||||
// Ignore the status here as we are simply logging.
|
||||
status->status = tensorflow::Status::OK();
|
||||
*status = tensorflow::Status::OK();
|
||||
} else {
|
||||
VLOG(3) << "Fully padded shape of ["
|
||||
<< absl::StrJoin(shape_to_log, ", ") << "] is "
|
||||
@ -88,7 +96,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
|
||||
// Currently, the only case of XlaTensor containing a tuple shape is to
|
||||
// represent 64 bit ints, doubles, and complex numbers (we don't support
|
||||
// 64bit complex numbers).
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
*status = tensorflow::errors::InvalidArgument(
|
||||
"XlaTensors should only contain tuples of size 2. Shape: ",
|
||||
padded_shape.DebugString());
|
||||
return nullptr;
|
||||
@ -100,13 +108,13 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
|
||||
const xla::Shape& shape1 =
|
||||
xla::ShapeUtil::GetTupleElementShape(padded_shape, 1);
|
||||
if (shape0.IsTuple() || shape1.IsTuple()) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
*status = tensorflow::errors::InvalidArgument(
|
||||
"XlaTensors should not contain nested tuples. Shape: ",
|
||||
padded_shape.DebugString());
|
||||
return nullptr;
|
||||
}
|
||||
if (!xla::ShapeUtil::Equal(shape0, shape1)) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
*status = tensorflow::errors::InvalidArgument(
|
||||
"Subshapes of XlaTensors should be the same. Shape: ",
|
||||
padded_shape.DebugString());
|
||||
return nullptr;
|
||||
@ -131,15 +139,15 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
|
||||
dev_dims.push_back(padded_shape.dimensions(dim_index));
|
||||
}
|
||||
}
|
||||
status->status = tensorflow::Status::OK();
|
||||
*status = tensorflow::Status::OK();
|
||||
return new TFE_TensorDebugInfo(dev_dims);
|
||||
}
|
||||
#endif // TENSORFLOW_EAGER_USE_XLA
|
||||
|
||||
// If the tensor is not an XLA tensor, the device shape is
|
||||
// the same as regular tensor shape.
|
||||
std::vector<int64> dev_dims = TensorShapeAsVector(handle, status);
|
||||
if (TF_GetCode(status) != TF_OK) {
|
||||
std::vector<int64> dev_dims = TensorShapeAsVector(*handle_, status);
|
||||
if (!status->ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return new TFE_TensorDebugInfo(dev_dims);
|
||||
|
@ -22,18 +22,18 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/monitoring/gauge.h"
|
||||
#include "tensorflow/core/lib/monitoring/sampler.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/profiler/rpc/client/capture_profile.h"
|
||||
#include "tensorflow/core/profiler/rpc/profiler_server.h"
|
||||
|
||||
using tensorflow::string;
|
||||
|
||||
void TFE_OpReset(TFE_Context* ctx, const char* op_or_function_name,
|
||||
const char* raw_device_name, TF_Status* status,
|
||||
TFE_Op* op_to_reset) {
|
||||
void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
|
||||
const char* raw_device_name, TF_Status* status) {
|
||||
if (op_to_reset) {
|
||||
NewOrResetOp(ctx, op_or_function_name, raw_device_name, status,
|
||||
op_to_reset);
|
||||
status->status = op_to_reset->operation.Reset(
|
||||
op_or_function_name, raw_device_name, false, nullptr);
|
||||
} else {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
"op_to_reset should not be nullptr");
|
||||
@ -41,7 +41,9 @@ void TFE_OpReset(TFE_Context* ctx, const char* op_or_function_name,
|
||||
}
|
||||
|
||||
void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
|
||||
op->operation.ConsumeInput(h->handle);
|
||||
op->operation.ConsumeInput(
|
||||
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
|
||||
->Handle());
|
||||
}
|
||||
|
||||
TFE_Profiler* TFE_NewProfiler() { return new TFE_Profiler(); }
|
||||
|
@ -29,10 +29,10 @@ extern "C" {
|
||||
// and set the device name. It's effectively `TFE_OpSetDevice`, but it is faster
|
||||
// than separately calling it because if the existing op has the same
|
||||
// `raw_device_name`, it skips parsing and just leaves it as it is.
|
||||
TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Context* ctx,
|
||||
TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Op* op_to_reset,
|
||||
const char* op_or_function_name,
|
||||
const char* raw_device_name,
|
||||
TF_Status* status, TFE_Op* op_to_reset);
|
||||
TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,
|
||||
TF_Status* status);
|
||||
|
@ -1,66 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/host_info.h"
|
||||
|
||||
TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
const char* raw_device_name, TF_Status* status,
|
||||
TFE_Op* op_to_reset) {
|
||||
const char* name = op_or_function_name; // Shorthand
|
||||
const tensorflow::AttrTypeMap* types;
|
||||
bool is_function = false;
|
||||
status->status = tensorflow::AttrTypeMapForOp(name, &types, &is_function);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (op_to_reset && op_to_reset->ctx != ctx) {
|
||||
status->status = tensorflow::errors::Internal(
|
||||
"Cannot reset a TFE_Op from another TFE_Context");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
|
||||
if (!is_function) {
|
||||
const tensorflow::OpDef* op_def;
|
||||
status->status = tensorflow::OpDefForOp(op_or_function_name, &op_def);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
inference_ctx.reset(new TFE_OpInferenceContext(op_def));
|
||||
} else if (!ctx->context->FindFunctionByName(name)) {
|
||||
status->status = tensorflow::errors::NotFound(
|
||||
"'", name,
|
||||
"' is neither a type of a primitive operation nor a name "
|
||||
"of a function registered in binary running on ",
|
||||
tensorflow::port::Hostname(),
|
||||
". Make sure the operation or function is "
|
||||
"registered in the binary running in this process.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (op_to_reset) {
|
||||
status->status = op_to_reset->Reset(
|
||||
name, is_function, types, raw_device_name, std::move(inference_ctx));
|
||||
return op_to_reset;
|
||||
}
|
||||
|
||||
TFE_Op* new_op =
|
||||
new TFE_Op(ctx, name, is_function, types, std::move(inference_ctx));
|
||||
status->status = new_op->operation.SetDeviceName(raw_device_name);
|
||||
return new_op;
|
||||
}
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
@ -62,36 +63,10 @@ struct TFE_ContextOptions {
|
||||
};
|
||||
|
||||
struct TFE_Context {
|
||||
TFE_Context(const tensorflow::SessionOptions& opts,
|
||||
TFE_ContextDevicePlacementPolicy default_device_placement_policy,
|
||||
TFE_ContextMirroringPolicy default_mirroring_policy, bool async,
|
||||
const bool lazy_remote_inputs_copy,
|
||||
const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned,
|
||||
tensorflow::Rendezvous* rendezvous,
|
||||
const tensorflow::CustomKernelCreator* custom_kernel_creator)
|
||||
: context(new tensorflow::EagerContext(
|
||||
opts,
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
||||
default_device_placement_policy),
|
||||
static_cast<tensorflow::ContextMirroringPolicy>(
|
||||
default_mirroring_policy),
|
||||
async, lazy_remote_inputs_copy, device_mgr, device_mgr_owned,
|
||||
rendezvous, custom_kernel_creator)) {}
|
||||
|
||||
~TFE_Context() {
|
||||
// TODO(iga): Add a separate API method to shutdown TFE_Context so that we
|
||||
// don't send RPCs and block in destructor.
|
||||
context->WaitForAndCloseRemoteContexts();
|
||||
// context->RefCountIsOne() should be true here.
|
||||
// TODO(iga): Remove EagerContext refcounting.
|
||||
context->Unref();
|
||||
}
|
||||
|
||||
tensorflow::EagerContext* context;
|
||||
};
|
||||
|
||||
struct TFE_TensorHandle {
|
||||
explicit TFE_TensorHandle(tensorflow::TensorHandle* h) : handle(h) {}
|
||||
static TFE_TensorHandle* CreateLocalHandle(const class tensorflow::Tensor& t,
|
||||
TF_Status* s) {
|
||||
tensorflow::TensorHandle* handle;
|
||||
@ -99,10 +74,11 @@ struct TFE_TensorHandle {
|
||||
if (!s->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return new TFE_TensorHandle(handle);
|
||||
return new TFE_TensorHandle{
|
||||
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
|
||||
}
|
||||
|
||||
tensorflow::TensorHandle* handle;
|
||||
std::unique_ptr<AbstractTensorHandleInterface> handle;
|
||||
};
|
||||
|
||||
struct TFE_TensorDebugInfo {
|
||||
@ -113,46 +89,10 @@ struct TFE_TensorDebugInfo {
|
||||
std::vector<tensorflow::int64> dev_dims;
|
||||
};
|
||||
|
||||
struct TFE_OpInferenceContext {
|
||||
explicit TFE_OpInferenceContext(const tensorflow::OpDef* op_def)
|
||||
: op_def(op_def) {}
|
||||
|
||||
const tensorflow::OpDef* op_def; // op definition from protobuf
|
||||
int input_arg_idx = 0; // arg definition index for the next input to be added
|
||||
tensorflow::gtl::FlatSet<std::string> attrs; // attributes inferred so far
|
||||
};
|
||||
|
||||
struct TFE_Op {
|
||||
TFE_Op(TFE_Context* ctx, const char* op, bool is_function,
|
||||
const tensorflow::AttrTypeMap* t,
|
||||
std::unique_ptr<TFE_OpInferenceContext> inference_ctx)
|
||||
: ctx(ctx),
|
||||
operation(ctx->context, op, is_function, t),
|
||||
inference_ctx(std::move(inference_ctx)) {}
|
||||
|
||||
void Clear() {
|
||||
operation.Clear();
|
||||
inference_ctx.reset();
|
||||
}
|
||||
|
||||
tensorflow::Status Reset(const char* op, bool is_function,
|
||||
const tensorflow::AttrTypeMap* t,
|
||||
const char* raw_device_name,
|
||||
std::unique_ptr<TFE_OpInferenceContext> infer_ctx) {
|
||||
inference_ctx = std::move(infer_ctx);
|
||||
return operation.Reset(ctx->context, op, is_function, t, raw_device_name,
|
||||
nullptr);
|
||||
}
|
||||
|
||||
TFE_Context* ctx;
|
||||
tensorflow::EagerOperation operation;
|
||||
std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
|
||||
};
|
||||
|
||||
TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
const char* raw_device_name, TF_Status* status,
|
||||
TFE_Op* op_to_reset = nullptr);
|
||||
|
||||
struct TFE_Profiler {
|
||||
explicit TFE_Profiler() { profiler = tensorflow::ProfilerSession::Create(); }
|
||||
|
||||
|
@ -1362,10 +1362,11 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) {
|
||||
TFE_TensorHandle* inputs[] = {input1, input2};
|
||||
TFE_OpAddInput(concatOp, dim, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
CHECK(concatOp->inference_ctx);
|
||||
CHECK(concatOp->operation.OpDef());
|
||||
TFE_OpAddInput(concatOp, inputs[0], status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_FALSE(concatOp->inference_ctx) << "Inference context is still present";
|
||||
EXPECT_FALSE(concatOp->operation.OpDef())
|
||||
<< "Inference context is still present";
|
||||
TFE_OpAddInput(concatOp, inputs[1], status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
|
@ -284,7 +284,7 @@ class ForwardAccumulator {
|
||||
// Temporarily push or pop transient state for this accumulator.
|
||||
//
|
||||
// Allows an accumulator which is currently processing an operation to
|
||||
// temporarily reset its state. Without pushing and poping, accumulators
|
||||
// temporarily reset its state. Without pushing and popping, accumulators
|
||||
// ignore operations executed as a direct result of their own jvp
|
||||
// computations.
|
||||
void PushState() { call_state_.emplace(nullptr, false); }
|
||||
|
90
tensorflow/c/eager/tensor_handle_interface.h
Normal file
90
tensorflow/c/eager/tensor_handle_interface.h
Normal file
@ -0,0 +1,90 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
|
||||
#define TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
|
||||
// Abstract interface to a TensorHandle.
|
||||
//
|
||||
// A TensorHandle is a management class around a Tensor which may track additional
|
||||
// metadata and synchronization.
|
||||
//
|
||||
// This allows us to hide concrete implementations of TensorHandle from header
|
||||
// files. The interface lists the common functionality that must be provided by
|
||||
// any concrete implementation. However, in cases where the true concrete class
|
||||
// is needed a static_cast can be applied.
|
||||
class AbstractTensorHandleInterface {
|
||||
public:
|
||||
virtual ~AbstractTensorHandleInterface() {}
|
||||
|
||||
// Check if the handle is in a valid initialized state.
|
||||
virtual bool IsValid(tensorflow::Status* status) const = 0;
|
||||
// Returns tensor dtype.
|
||||
virtual TF_DataType DataType() const = 0;
|
||||
// Returns number of dimensions.
|
||||
virtual int NumDims(tensorflow::Status* status) const = 0;
|
||||
// Returns number of elements across all dimensions.
|
||||
virtual int64_t NumElements(tensorflow::Status* status) const = 0;
|
||||
// Returns size of specified dimension.
|
||||
virtual int64_t Dim(int dim_index, tensorflow::Status* status) const = 0;
|
||||
|
||||
// Returns the device which created the handle.
|
||||
virtual const char* DeviceName(tensorflow::Status* status) const = 0;
|
||||
// Returns the device where the tensor was placed.
|
||||
virtual const char* BackingDeviceName(tensorflow::Status* status) const = 0;
|
||||
// Returns a tensor for the handle. If tensor is remote, it will be copied.
|
||||
virtual TF_Tensor* Resolve(tensorflow::Status* status) = 0;
|
||||
// Returns debug information about the tensor.
|
||||
virtual TFE_TensorDebugInfo* TensorDebugInfo(tensorflow::Status* status) = 0;
|
||||
|
||||
// Return a copy of the handle.
|
||||
virtual AbstractTensorHandleInterface* Copy() = 0;
|
||||
};
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class TensorHandleInterface : public AbstractTensorHandleInterface {
|
||||
public:
|
||||
explicit TensorHandleInterface(TensorHandle* h) : handle_(h) {}
|
||||
~TensorHandleInterface() override;
|
||||
|
||||
bool IsValid(Status* status) const override;
|
||||
TF_DataType DataType() const override;
|
||||
int NumDims(Status* status) const override;
|
||||
int64_t NumElements(Status* status) const override;
|
||||
int64_t Dim(int dim_index, Status* status) const override;
|
||||
|
||||
const char* DeviceName(Status* status) const override;
|
||||
const char* BackingDeviceName(Status* status) const override;
|
||||
TF_Tensor* Resolve(Status* status) override;
|
||||
TFE_TensorDebugInfo* TensorDebugInfo(Status* status) override;
|
||||
|
||||
AbstractTensorHandleInterface* Copy() override;
|
||||
|
||||
// TODO(gjn): This is not a very generic interface, but is needed for specific
|
||||
// use cases.
|
||||
TensorHandle* Handle() { return handle_; }
|
||||
|
||||
private:
|
||||
TensorHandle* handle_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
|
@ -18,37 +18,23 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
# Core TensorFlow depends on this, this will be included in main library
|
||||
cc_library(
|
||||
name = "filesystem_interface_impl",
|
||||
srcs = ["filesystem_interface.cc"],
|
||||
hdrs = ["filesystem_interface.h"],
|
||||
deps = [
|
||||
":modular_filesystem",
|
||||
"//tensorflow/c:tf_file_statistics",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c:tf_status_internal",
|
||||
"//tensorflow/core:ptr_util",
|
||||
"//tensorflow/core/platform:env",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"//tensorflow/core/platform:strcat",
|
||||
"//tensorflow/core/platform:stringpiece",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# Core TensorFlow depends on this, will be included in main library
|
||||
cc_library(
|
||||
name = "modular_filesystem",
|
||||
srcs = ["modular_filesystem.cc"],
|
||||
srcs = [
|
||||
"modular_filesystem.cc",
|
||||
"modular_filesystem_registration.cc",
|
||||
"modular_filesystem_registration.h",
|
||||
],
|
||||
hdrs = ["modular_filesystem.h"],
|
||||
deps = [
|
||||
":filesystem_interface",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/c:tf_status_internal",
|
||||
"//tensorflow/core:ptr_util",
|
||||
"//tensorflow/core/platform:env",
|
||||
"//tensorflow/core/platform:strcat",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"//tensorflow/core/platform:status",
|
||||
],
|
||||
)
|
||||
|
||||
@ -63,16 +49,12 @@ tf_cc_test(
|
||||
"notap", # b/139060984, requires implementing modular support for Google filesystem
|
||||
],
|
||||
deps = [
|
||||
":filesystem_interface_impl",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c:tf_status_internal",
|
||||
":modular_filesystem",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core/lib/io:path",
|
||||
"//tensorflow/core/platform:env",
|
||||
"//tensorflow/core/platform:error",
|
||||
"//tensorflow/core/platform:stacktrace_handler",
|
||||
"//tensorflow/core/platform:str_util",
|
||||
"//tensorflow/core/platform:strcat",
|
||||
"//tensorflow/core/platform:test",
|
||||
],
|
||||
)
|
||||
|
@ -1,366 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
||||
|
||||
#include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
|
||||
#include "tensorflow/c/tf_status_internal.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
#include "tensorflow/core/util/ptr_util.h"
|
||||
|
||||
/// This translation unit is linked in core TensorFlow and provides the
|
||||
/// functionality needed for plugin registration to check ABI/API compatibility,
|
||||
/// to ensure required methods are present, to ensure plugins are not allowed to
|
||||
/// change functionality after being loaded and to register the filesystems
|
||||
/// provided by a plugin. Consult the header file for more information about
|
||||
/// how this is achieved.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Checks if the plugin and core ABI numbers match, filling in `status`.
|
||||
//
|
||||
// If the numbers don't match, plugin cannot be loaded.
|
||||
static bool CheckABIHelper(int pluginABI, int coreABI, StringPiece where,
|
||||
TF_Status* status) {
|
||||
if (pluginABI != coreABI) {
|
||||
TF_SetStatus(
|
||||
status, TF_FAILED_PRECONDITION,
|
||||
strings::StrCat("Plugin ABI (", pluginABI, ") for ", where,
|
||||
" operations doesn't match expected core ABI (",
|
||||
coreABI, "). Plugin cannot be loaded.")
|
||||
.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Checks if the plugin and core ABI numbers match, for all operations.
|
||||
//
|
||||
// If the numbers don't match, plugin cannot be loaded.
|
||||
//
|
||||
// Uses the simpler `CheckABIHelper(int, int, StringPiece, TF_Status*)`
|
||||
static bool CheckABI(
|
||||
int plugin_filesystem_ops_ABI,
|
||||
const TF_RandomAccessFileOps* plugin_random_access_file_ops,
|
||||
int plugin_random_access_file_ops_ABI,
|
||||
const TF_WritableFileOps* plugin_writable_file_ops,
|
||||
int plugin_writable_file_ops_ABI,
|
||||
const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
|
||||
int plugin_read_only_memory_region_ops_ABI, TF_Status* status) {
|
||||
if (!CheckABIHelper(plugin_filesystem_ops_ABI, TF_FILESYSTEM_OPS_ABI,
|
||||
"filesystem", status))
|
||||
return false;
|
||||
|
||||
if (plugin_random_access_file_ops != nullptr &&
|
||||
!CheckABIHelper(plugin_random_access_file_ops_ABI,
|
||||
TF_RANDOM_ACCESS_FILE_OPS_ABI, "random access file",
|
||||
status))
|
||||
return false;
|
||||
|
||||
if (plugin_writable_file_ops != nullptr &&
|
||||
!CheckABIHelper(plugin_writable_file_ops_ABI, TF_WRITABLE_FILE_OPS_ABI,
|
||||
"writable file", status))
|
||||
return false;
|
||||
|
||||
if (plugin_read_only_memory_region_ops != nullptr &&
|
||||
!CheckABIHelper(plugin_read_only_memory_region_ops_ABI,
|
||||
TF_READ_ONLY_MEMORY_REGION_OPS_ABI,
|
||||
"read only memory region", status))
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Checks if the plugin and core API numbers match, logging mismatches.
|
||||
static void CheckAPIHelper(int plugin_API, int core_API, StringPiece where) {
|
||||
if (plugin_API != core_API) {
|
||||
VLOG(0) << "Plugin API (" << plugin_API << ") for " << where
|
||||
<< " operations doesn't match expected core API (" << core_API
|
||||
<< "). Plugin will be loaded but functionality might be missing.";
|
||||
}
|
||||
}
|
||||
|
||||
// Checks if the plugin and core API numbers match, for all operations.
|
||||
//
|
||||
// Uses the simpler `CheckAPIHelper(int, int, StringPiece)`.
|
||||
static void CheckAPI(
|
||||
int plugin_filesystem_ops_API,
|
||||
const TF_RandomAccessFileOps* plugin_random_access_file_ops,
|
||||
int plugin_random_access_file_ops_API,
|
||||
const TF_WritableFileOps* plugin_writable_file_ops,
|
||||
int plugin_writable_file_ops_API,
|
||||
const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
|
||||
int plugin_read_only_memory_region_ops_API) {
|
||||
CheckAPIHelper(plugin_filesystem_ops_API, TF_FILESYSTEM_OPS_API,
|
||||
"filesystem");
|
||||
|
||||
if (plugin_random_access_file_ops != nullptr)
|
||||
CheckAPIHelper(plugin_random_access_file_ops_API,
|
||||
TF_RANDOM_ACCESS_FILE_OPS_API, "random access file");
|
||||
|
||||
if (plugin_writable_file_ops != nullptr)
|
||||
CheckAPIHelper(plugin_writable_file_ops_API, TF_WRITABLE_FILE_OPS_API,
|
||||
"writable file");
|
||||
|
||||
if (plugin_read_only_memory_region_ops != nullptr)
|
||||
CheckAPIHelper(plugin_read_only_memory_region_ops_API,
|
||||
TF_READ_ONLY_MEMORY_REGION_OPS_API,
|
||||
"read only memory region");
|
||||
}
|
||||
|
||||
// Validates the filesystem operations supplied by the plugin.
|
||||
static bool ValidateHelper(const TF_FilesystemOps* ops, TF_Status* status) {
|
||||
if (ops == nullptr) {
|
||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
||||
"Trying to register filesystem without operations");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ops->init == nullptr) {
|
||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
||||
"Trying to register filesystem without `init` operation");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ops->cleanup == nullptr) {
|
||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
||||
"Trying to register filesystem without `cleanup` operation");
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Validates the random access file operations supplied by the plugin.
|
||||
static bool ValidateHelper(const TF_RandomAccessFileOps* ops,
|
||||
TF_Status* status) {
|
||||
if (ops == nullptr) {
|
||||
// We allow filesystems where files can only be written to (from TF code)
|
||||
return true;
|
||||
}
|
||||
|
||||
if (ops->cleanup == nullptr) {
|
||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
||||
"Trying to register filesystem without `cleanup` operation on "
|
||||
"random access files");
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Validates the writable file operations supplied by the plugin.
|
||||
static bool ValidateHelper(const TF_WritableFileOps* ops, TF_Status* status) {
|
||||
if (ops == nullptr) {
|
||||
// We allow read-only filesystems
|
||||
return true;
|
||||
}
|
||||
|
||||
if (ops->cleanup == nullptr) {
|
||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
||||
"Trying to register filesystem without `cleanup` operation on "
|
||||
"writable files");
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Validates the read only memory region operations given by the plugin.
|
||||
static bool ValidateHelper(const TF_ReadOnlyMemoryRegionOps* ops,
|
||||
TF_Status* status) {
|
||||
if (ops == nullptr) {
|
||||
// read only memory region support is always optional
|
||||
return true;
|
||||
}
|
||||
|
||||
if (ops->cleanup == nullptr) {
|
||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
||||
"Trying to register filesystem without `cleanup` operation on "
|
||||
"read only memory regions");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ops->data == nullptr) {
|
||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
||||
"Trying to register filesystem without `data` operation on "
|
||||
"read only memory regions");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ops->length == nullptr) {
|
||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
||||
"Trying to register filesystem without `length` operation on "
|
||||
"read only memory regions");
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Validates the operations supplied by the plugin.
|
||||
//
|
||||
// Uses the 4 simpler `ValidateHelper(const TF_..., TF_Status*)` to validate
|
||||
// each individual function table and then checks that the function table for a
|
||||
// specific file type exists if the plugin offers support for creating that
|
||||
// type of files.
|
||||
static bool Validate(
|
||||
const TF_FilesystemOps* plugin_filesystem_ops,
|
||||
const TF_RandomAccessFileOps* plugin_random_access_file_ops,
|
||||
const TF_WritableFileOps* plugin_writable_file_ops,
|
||||
const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
|
||||
TF_Status* status) {
|
||||
if (!ValidateHelper(plugin_filesystem_ops, status)) return false;
|
||||
if (!ValidateHelper(plugin_random_access_file_ops, status)) return false;
|
||||
if (!ValidateHelper(plugin_writable_file_ops, status)) return false;
|
||||
if (!ValidateHelper(plugin_read_only_memory_region_ops, status)) return false;
|
||||
|
||||
if (plugin_filesystem_ops->new_random_access_file != nullptr &&
|
||||
plugin_random_access_file_ops == nullptr) {
|
||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
||||
"Filesystem allows creation of random access files but no "
|
||||
"operations on them have been supplied.");
|
||||
return false;
|
||||
}
|
||||
|
||||
if ((plugin_filesystem_ops->new_writable_file != nullptr ||
|
||||
plugin_filesystem_ops->new_appendable_file != nullptr) &&
|
||||
plugin_writable_file_ops == nullptr) {
|
||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
||||
"Filesystem allows creation of writable files but no "
|
||||
"operations on them have been supplied.");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (plugin_filesystem_ops->new_read_only_memory_region_from_file != nullptr &&
|
||||
plugin_read_only_memory_region_ops == nullptr) {
|
||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
||||
"Filesystem allows creation of readonly memory regions but no "
|
||||
"operations on them have been supplied.");
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Copies a function table from plugin memory space to core memory space.
|
||||
//
|
||||
// This has three benefits:
|
||||
// * allows having newer plugins than the current core TensorFlow: the
|
||||
// additional entries in the plugin's table are just discarded;
|
||||
// * allows having older plugins than the current core TensorFlow (though
|
||||
// we are still warning users): the entries that core TensorFlow expects
|
||||
// but plugins didn't provide will be set to `nullptr` values and core
|
||||
// TensorFlow will know to not call these on behalf of users;
|
||||
// * increased security as plugins will not be able to alter function table
|
||||
// after loading up. Thus, malicious plugins can't alter functionality to
|
||||
// probe for gadgets inside core TensorFlow. We can even protect the area
|
||||
// of memory where the copies reside to not allow any more writes to it
|
||||
// after all copies are created.
|
||||
template <typename T>
|
||||
static std::unique_ptr<const T> CopyToCore(const T* plugin_ops,
|
||||
size_t plugin_size) {
|
||||
if (plugin_ops == nullptr) return nullptr;
|
||||
|
||||
size_t copy_size = sizeof(T);
|
||||
if (plugin_size < copy_size) {
|
||||
copy_size = plugin_size;
|
||||
}
|
||||
|
||||
auto core_ops = tensorflow::MakeUnique<T>();
|
||||
memcpy(const_cast<T*>(core_ops.get()), plugin_ops, copy_size);
|
||||
return core_ops;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
||||
void RegisterFilesystemPlugin(
|
||||
int plugin_filesystem_ops_ABI, int plugin_filesystem_ops_API,
|
||||
size_t plugin_filesystem_ops_size, int plugin_random_access_file_ops_ABI,
|
||||
int plugin_random_access_file_ops_API,
|
||||
size_t plugin_random_access_file_ops_size, int plugin_writable_file_ops_ABI,
|
||||
int plugin_writable_file_ops_API, size_t plugin_writable_file_ops_size,
|
||||
int plugin_read_only_memory_region_ops_ABI,
|
||||
int plugin_read_only_memory_region_ops_API,
|
||||
size_t plugin_read_only_memory_region_ops_size, const char* scheme,
|
||||
const TF_FilesystemOps* plugin_filesystem_ops,
|
||||
const TF_RandomAccessFileOps* plugin_random_access_file_ops,
|
||||
const TF_WritableFileOps* plugin_writable_file_ops,
|
||||
const TF_ReadOnlyMemoryRegionOps* plugin_read_only_memory_region_ops,
|
||||
TF_Status* status) {
|
||||
if (scheme == nullptr) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
"`scheme` argument must not be `nullptr`.");
|
||||
return;
|
||||
}
|
||||
|
||||
// ABI numbers must match exactly for plugin to be loaded
|
||||
if (!tensorflow::CheckABI(
|
||||
plugin_filesystem_ops_ABI, plugin_random_access_file_ops,
|
||||
plugin_random_access_file_ops_ABI, plugin_writable_file_ops,
|
||||
plugin_writable_file_ops_ABI, plugin_read_only_memory_region_ops,
|
||||
plugin_read_only_memory_region_ops_ABI, status)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// API numbers should match but mismatch doesn't block plugin load
|
||||
tensorflow::CheckAPI(plugin_filesystem_ops_API, plugin_random_access_file_ops,
|
||||
plugin_random_access_file_ops_API,
|
||||
plugin_writable_file_ops, plugin_writable_file_ops_API,
|
||||
plugin_read_only_memory_region_ops,
|
||||
plugin_read_only_memory_region_ops_API);
|
||||
|
||||
// Plugin can only be loaded if all supplied ops are valid
|
||||
if (!tensorflow::Validate(plugin_filesystem_ops,
|
||||
plugin_random_access_file_ops,
|
||||
plugin_writable_file_ops,
|
||||
plugin_read_only_memory_region_ops, status)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Copy all the function tables to core TensorFlow memory space
|
||||
auto core_filesystem_ops = tensorflow::CopyToCore<TF_FilesystemOps>(
|
||||
plugin_filesystem_ops, plugin_filesystem_ops_size);
|
||||
auto core_random_access_file_ops =
|
||||
tensorflow::CopyToCore<TF_RandomAccessFileOps>(
|
||||
plugin_random_access_file_ops, plugin_random_access_file_ops_size);
|
||||
auto core_writable_file_ops = tensorflow::CopyToCore<TF_WritableFileOps>(
|
||||
plugin_writable_file_ops, plugin_writable_file_ops_size);
|
||||
auto core_read_only_memory_region_ops =
|
||||
tensorflow::CopyToCore<TF_ReadOnlyMemoryRegionOps>(
|
||||
plugin_read_only_memory_region_ops,
|
||||
plugin_read_only_memory_region_ops_size);
|
||||
|
||||
// Initialize the opaque filesystem structure
|
||||
auto filesystem = tensorflow::MakeUnique<TF_Filesystem>();
|
||||
core_filesystem_ops->init(filesystem.get(), status);
|
||||
if (!status->status.ok()) {
|
||||
core_filesystem_ops->cleanup(filesystem.get());
|
||||
return;
|
||||
}
|
||||
|
||||
// Register new filesystem
|
||||
status->status = tensorflow::Env::Default()->RegisterFileSystem(
|
||||
scheme, tensorflow::MakeUnique<tensorflow::ModularFileSystem>(
|
||||
std::move(filesystem), std::move(core_filesystem_ops),
|
||||
std::move(core_random_access_file_ops),
|
||||
std::move(core_writable_file_ops),
|
||||
std::move(core_read_only_memory_region_ops)));
|
||||
}
|
@ -529,7 +529,7 @@ typedef struct TF_FilesystemOps {
|
||||
/// If `statuses` is not null, plugins must fill each element with detailed
|
||||
/// status for each file, as if calling `path_exists` on each one. Core
|
||||
/// TensorFlow initializes the `statuses` array and plugins must use
|
||||
/// `TF_SetStatus` to set each element instead of dirrectly assigning.
|
||||
/// `TF_SetStatus` to set each element instead of directly assigning.
|
||||
///
|
||||
/// DEFAULT IMPLEMENTATION: Checks existence of every file. Needs
|
||||
/// `path_exists`.
|
||||
@ -736,95 +736,108 @@ constexpr size_t TF_FILESYSTEM_OPS_SIZE = sizeof(TF_FilesystemOps);
|
||||
/// SECTION 4. Plugin registration and initialization
|
||||
/// ----------------------------------------------------------------------------
|
||||
///
|
||||
/// In this section we define two functions:
|
||||
/// * `TF_InitPlugin`: must be present in the plugin shared object as it will
|
||||
/// be called by core TensorFlow when the filesystem plugin is loaded;
|
||||
/// * `RegisterFilesystemPlugin`: it is implemented by core TensorFlow but
|
||||
/// plugins must call it in their `TF_InitPlugin`, usually using the macro
|
||||
/// `TF_REGISTER_FILESYSTEM_PLUGIN`.
|
||||
/// In this section we define the API used by core TensorFlow to initialize a
|
||||
/// filesystem provided by a plugin. That is, we define the following:
|
||||
/// * `TF_InitPlugin` function: must be present in the plugin shared object as
|
||||
/// it will be called by core TensorFlow when the filesystem plugin is
|
||||
/// loaded;
|
||||
/// * `TF_FilesystemPluginInfo` struct: used to transfer information between
|
||||
/// plugins and core TensorFlow about the operations provided and metadata;
|
||||
/// * `TF_SetFilesystemVersionMetadata` function: must be called by plugins in
|
||||
/// their `TF_InitPlugin` to record the versioning information the plugins
|
||||
/// are compiled against.
|
||||
///
|
||||
/// The `TF_InitPlugin` function is used by plugins to set up the data
|
||||
/// structures that implement this interface, as presented in Section 2.
|
||||
///
|
||||
/// The `RegisterFilesystemPlugin` is used by core TensorFlow to check that
|
||||
/// plugins satisfy the requirements expected by core TensorFlow, as follows:
|
||||
/// 1. If ABI numbers don't match we don't load the plugin, else we continue.
|
||||
/// 2. If the API numbers are mismatched, we warn the user and continue
|
||||
/// loading the plugin.
|
||||
/// 3. If any required operation is missing, we stop loading the plugin.
|
||||
///
|
||||
/// If all these checks succeed, we copy the plugin operations to a different
|
||||
/// memory location so that core TensorFlow has the guarantee that they won't be
|
||||
/// changed by plugins at a later time. Finally, we initialize the opaque
|
||||
/// pointer of `TF_Filesystem` by calling the required `init` function of
|
||||
/// `TF_FilesystemOps` and if that succeeds we register the filesystem.
|
||||
/// structures that implement this interface, as presented in Section 2. In
|
||||
/// order to not have plugin shared objects call back symbols defined in core
|
||||
/// TensorFlow, `TF_InitPlugin` has a `TF_FilesystemPluginInfo` argument which
|
||||
/// the plugin must fill (using the `TF_SetFilesystemVersionMetadata` for the
|
||||
/// metadata and setting up all the supported operations and the URI schemes
|
||||
/// that are supported).
|
||||
|
||||
// Initializes a TensorFlow plugin.
|
||||
//
|
||||
// Must be implemented by the plugin DSO. It is called by TensorFlow runtime.
|
||||
//
|
||||
// Filesystem plugins can be loaded on demand by users via
|
||||
// `Env::LoadLibrary` or during TensorFlow's startup if they are on certain
|
||||
// paths (although this has a security risk if two plugins register for the
|
||||
// same filesystem and the malicious one loads before the legitimate one -
|
||||
// but we consider this to be something that users should care about and
|
||||
// manage themselves). In both of these cases, core TensorFlow looks for
|
||||
// the `TF_InitPlugin` symbol and calls that function.
|
||||
//
|
||||
// A plugin is loaded only if this `status` is `TF_OK` after the call.
|
||||
TF_CAPI_EXPORT extern void TF_InitPlugin(TF_Status* status);
|
||||
/// This structure incorporates the operations defined in Section 2 and the
|
||||
/// metadata defined in Section 3, allowing plugins to define different ops
|
||||
/// for different URI schemes.
|
||||
///
|
||||
/// Every URI scheme is of the form "fs" for URIs of form "fs:///path/to/file".
|
||||
/// For local filesystems (i.e., when the URI is "/path/to/file"), the scheme
|
||||
/// must be "". The scheme must never be `nullptr`.
|
||||
///
|
||||
/// Every plugin fills this in `TF_InitPlugin`, using the allocator passed as
|
||||
/// argument to allocate memory. After `TF_InitPlugin` finishes, core
|
||||
/// TensorFlow uses the information present in this to initialize filesystems
|
||||
/// for the URI schemes that the plugin requests.
|
||||
///
|
||||
/// All pointers defined in this structure point to memory allocated by the DSO
|
||||
/// using an allocator provided by core TensorFlow when calling `TF_InitPlugin`.
|
||||
///
|
||||
/// IMPORTANT: To maintain binary compatibility, the layout of this structure
|
||||
/// must not change! In the unlikely case that a new type of file needs to be
|
||||
/// supported, add the new ops and metadata at the end of the structure.
|
||||
typedef struct TF_FilesystemPluginInfo {
|
||||
char* scheme;
|
||||
int filesystem_ops_abi;
|
||||
int filesystem_ops_api;
|
||||
size_t filesystem_ops_size;
|
||||
TF_FilesystemOps* filesystem_ops;
|
||||
int random_access_file_ops_abi;
|
||||
int random_access_file_ops_api;
|
||||
size_t random_access_file_ops_size;
|
||||
TF_RandomAccessFileOps* random_access_file_ops;
|
||||
int writable_file_ops_abi;
|
||||
int writable_file_ops_api;
|
||||
size_t writable_file_ops_size;
|
||||
TF_WritableFileOps* writable_file_ops;
|
||||
int read_only_memory_region_ops_abi;
|
||||
int read_only_memory_region_ops_api;
|
||||
size_t read_only_memory_region_ops_size;
|
||||
TF_ReadOnlyMemoryRegionOps* read_only_memory_region_ops;
|
||||
} TF_FilesystemPluginInfo;
|
||||
|
||||
/// Registers a filesystem plugin so that core TensorFlow can use it.
|
||||
/// Convenience function for setting the versioning metadata.
|
||||
///
|
||||
/// Must be called by the plugin during `TF_InitPlugin`, usually by using the
|
||||
/// convenience `TF_REGISTER_FILESYSTEM_PLUGIN` macro.
|
||||
/// The argument is guaranteed to not be `nullptr`.
|
||||
///
|
||||
/// Arguments (grouped by category):
|
||||
/// * `..ABI`: ABI compatibility numbers (see Section 3.).
|
||||
/// * `..API`: API compatibility numbers (see Section 3.).
|
||||
/// * `..Size`: Sizes of the operation tables (see Section 3.).
|
||||
/// * `scheme`: The URI scheme that plugin is registering filesystems for.
|
||||
/// Must be of the form "fs" for URIs of form "fs:///path/to/file". For
|
||||
/// local filesystems (i.e., when the URI is "/path/to/file"), `scheme`
|
||||
/// must be "". Must never be `nullptr`.
|
||||
/// * `..Ops`: The function tables provided by the plugin. Owned by the
|
||||
/// plugin, but core TensorFlow makes a copy of these.
|
||||
/// * `status`: The output variable for representing success/failure.
|
||||
///
|
||||
/// Sets `status` to `TF_OK` if plugin was registered and filesystem operations
|
||||
/// can be invoked from anywhere during TensorFlow's runtime. Any other value of
|
||||
/// `status` means that plugin failed to load properly and as such the
|
||||
/// operations it provides cannot be used at all (i.e., core TensorFlow will
|
||||
/// never run them, returning early with `TF_UNIMPLEMENTED` or similar error
|
||||
/// values).
|
||||
TF_CAPI_EXPORT extern void RegisterFilesystemPlugin(
|
||||
int pluginFilesystemOpsABI, int pluginFilesystemOpsAPI,
|
||||
size_t pluginFilesystemOpsSize, int pluginRandomAccessFileOpsABI,
|
||||
int pluginRandomAccessFileOpsAPI, size_t pluginRandomAccessFileOpsSize,
|
||||
int pluginWritableFileOpsABI, int pluginWritableFileOpsAPI,
|
||||
size_t pluginWritableFileOpsSize, int pluginReadOnlyMemoryRegionOpsABI,
|
||||
int pluginReadOnlyMemoryRegionOpsAPI,
|
||||
size_t pluginReadOnlyMemoryRegionOpsSize, const char* scheme,
|
||||
const TF_FilesystemOps* pluginFilesystemOps,
|
||||
const TF_RandomAccessFileOps* pluginRandomAccessFileOps,
|
||||
const TF_WritableFileOps* pluginWritableFileOps,
|
||||
const TF_ReadOnlyMemoryRegionOps* pluginReadOnlyMemoryRegionOps,
|
||||
TF_Status* status);
|
||||
/// We want this to be defined in the plugin's memory space and we guarantee
|
||||
/// that core TensorFlow will never call this.
|
||||
static inline void TF_SetFilesystemVersionMetadata(
|
||||
TF_FilesystemPluginInfo* info) {
|
||||
info->filesystem_ops_abi = TF_FILESYSTEM_OPS_ABI;
|
||||
info->filesystem_ops_api = TF_FILESYSTEM_OPS_API;
|
||||
info->filesystem_ops_size = TF_FILESYSTEM_OPS_SIZE;
|
||||
info->random_access_file_ops_abi = TF_RANDOM_ACCESS_FILE_OPS_ABI;
|
||||
info->random_access_file_ops_api = TF_RANDOM_ACCESS_FILE_OPS_API;
|
||||
info->random_access_file_ops_size = TF_RANDOM_ACCESS_FILE_OPS_SIZE;
|
||||
info->writable_file_ops_abi = TF_WRITABLE_FILE_OPS_ABI;
|
||||
info->writable_file_ops_api = TF_WRITABLE_FILE_OPS_API;
|
||||
info->writable_file_ops_size = TF_WRITABLE_FILE_OPS_SIZE;
|
||||
info->read_only_memory_region_ops_abi = TF_READ_ONLY_MEMORY_REGION_OPS_ABI;
|
||||
info->read_only_memory_region_ops_api = TF_READ_ONLY_MEMORY_REGION_OPS_API;
|
||||
info->read_only_memory_region_ops_size = TF_READ_ONLY_MEMORY_REGION_OPS_SIZE;
|
||||
}
|
||||
|
||||
/// Convenience wrapper around `RegisterFilesystemPlugin`.
///
/// Fills in the ABI number, API number and size of each operation table from
/// the `TF_*_OPS_*` constants, so that the caller only has to supply the
/// scheme, the four operation tables and the output `status`.
/// Plugins should prefer using this macro instead of a direct call.
#define TF_REGISTER_FILESYSTEM_PLUGIN(                              \
    scheme, pluginFilesystemOps, pluginRandomAccessFileOps,         \
    pluginWritableFileOps, pluginReadOnlyMemoryRegionOps, status)   \
  RegisterFilesystemPlugin(                                         \
      TF_FILESYSTEM_OPS_ABI,                                        \
      TF_FILESYSTEM_OPS_API,                                        \
      TF_FILESYSTEM_OPS_SIZE,                                       \
      TF_RANDOM_ACCESS_FILE_OPS_ABI,                                \
      TF_RANDOM_ACCESS_FILE_OPS_API,                                \
      TF_RANDOM_ACCESS_FILE_OPS_SIZE,                               \
      TF_WRITABLE_FILE_OPS_ABI,                                     \
      TF_WRITABLE_FILE_OPS_API,                                     \
      TF_WRITABLE_FILE_OPS_SIZE,                                    \
      TF_READ_ONLY_MEMORY_REGION_OPS_ABI,                           \
      TF_READ_ONLY_MEMORY_REGION_OPS_API,                           \
      TF_READ_ONLY_MEMORY_REGION_OPS_SIZE,                          \
      scheme,                                                       \
      pluginFilesystemOps,                                          \
      pluginRandomAccessFileOps,                                    \
      pluginWritableFileOps,                                        \
      pluginReadOnlyMemoryRegionOps,                                \
      status)
|
||||
/// Initializes a TensorFlow plugin.
|
||||
///
|
||||
/// Must be implemented by the plugin DSO. It is called by TensorFlow runtime.
|
||||
///
|
||||
/// Filesystem plugins can be loaded on demand by users via
|
||||
/// `Env::LoadLibrary` or during TensorFlow's startup if they are on certain
|
||||
/// paths (although this has a security risk if two plugins register for the
|
||||
/// same filesystem and the malicious one loads before the legitimate one -
|
||||
/// but we consider this to be something that users should care about and
|
||||
/// manage themselves). In both of these cases, core TensorFlow looks for
|
||||
/// the `TF_InitPlugin` symbol and calls this function.
|
||||
///
|
||||
/// All memory allocated by this function must be allocated via the `allocator`
|
||||
/// argument.
|
||||
///
|
||||
/// For every filesystem URI scheme that this plugin supports, the plugin must
|
||||
/// add one `TF_FilesystemPluginInfo` entry in `plugin_info`.
|
||||
///
|
||||
/// Returns number of entries in `plugin_info` (i.e., number of URI schemes
|
||||
/// supported).
|
||||
TF_CAPI_EXPORT extern int TF_InitPlugin(void* (*allocator)(size_t size),
|
||||
TF_FilesystemPluginInfo** plugin_info);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // end extern "C"
|
||||
|
@ -18,11 +18,10 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/file_system_helper.h"
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
#include "tensorflow/core/util/ptr_util.h"
|
||||
|
||||
// TODO(mihaimaruseac): After all filesystems are converted, all calls to
|
||||
@ -435,4 +434,8 @@ Status ModularWritableFile::Tell(int64* position) {
|
||||
return StatusFromTF_Status(plugin_status.get());
|
||||
}
|
||||
|
||||
// Registers the filesystem plugin stored in the shared object at `dso_path`.
//
// Thin forwarding wrapper: all the work is done by
// `filesystem_registration::RegisterFilesystemPluginImpl`, whose status is
// returned unchanged.
Status RegisterFilesystemPlugin(const std::string& dso_path) {
  using filesystem_registration::RegisterFilesystemPluginImpl;
  return RegisterFilesystemPluginImpl(dso_path);
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -32,7 +32,7 @@ namespace tensorflow {
|
||||
// TODO(b/143949615): After all filesystems are converted, this file will be
|
||||
// moved to core/platform, and this class can become a singleton and replace the
|
||||
// need for `Env::Default()`. At that time, we might decide to remove the need
|
||||
// for `Env::Default()` altoghether, but that's a different project, not in
|
||||
// for `Env::Default()` altogether, but that's a different project, not in
|
||||
// scope for now. I'm just mentioning this here as that transition will mean
|
||||
// removal of the registration part from `Env` and adding it here instead: we
|
||||
// will need tables to hold for each scheme the function tables that implement
|
||||
@ -156,6 +156,9 @@ class ModularReadOnlyMemoryRegion final : public ReadOnlyMemoryRegion {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(ModularReadOnlyMemoryRegion);
|
||||
};
|
||||
|
||||
// Registers a filesystem plugin so that core TensorFlow can use it.
|
||||
Status RegisterFilesystemPlugin(const std::string& dso_path);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_H_
|
||||
|
@ -0,0 +1,325 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h"
|
||||
|
||||
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
||||
#include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
|
||||
#include "tensorflow/c/tf_status_internal.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/util/ptr_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Checks that all schemes provided by a plugin are valid.
|
||||
// TODO(mihaimaruseac): More validation could be done here, based on supported
|
||||
// charset, maximum length, etc. Punting it for later.
|
||||
static Status ValidateScheme(const char* scheme) {
|
||||
if (scheme == nullptr)
|
||||
return errors::InvalidArgument(
|
||||
"Attempted to register filesystem with `nullptr` URI scheme");
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Checks if the plugin and core ABI numbers match.
|
||||
//
|
||||
// If the numbers don't match, plugin cannot be loaded.
|
||||
static Status CheckABI(int pluginABI, int coreABI, StringPiece where) {
|
||||
if (pluginABI != coreABI)
|
||||
return errors::FailedPrecondition(
|
||||
strings::StrCat("Plugin ABI (", pluginABI, ") for ", where,
|
||||
" operations doesn't match expected core ABI (",
|
||||
coreABI, "). Plugin cannot be loaded."));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Checks if the plugin and core ABI numbers match, for all operations.
|
||||
//
|
||||
// If the numbers don't match, plugin cannot be loaded.
|
||||
//
|
||||
// Uses the simpler `CheckABI(int, int, StringPiece)`.
|
||||
static Status ValidateABI(const TF_FilesystemPluginInfo* info) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
CheckABI(info->filesystem_ops_abi, TF_FILESYSTEM_OPS_ABI, "filesystem"));
|
||||
|
||||
if (info->random_access_file_ops != nullptr)
|
||||
TF_RETURN_IF_ERROR(CheckABI(info->random_access_file_ops_abi,
|
||||
TF_RANDOM_ACCESS_FILE_OPS_ABI,
|
||||
"random access file"));
|
||||
|
||||
if (info->writable_file_ops != nullptr)
|
||||
TF_RETURN_IF_ERROR(CheckABI(info->writable_file_ops_abi,
|
||||
TF_WRITABLE_FILE_OPS_ABI, "writable file"));
|
||||
|
||||
if (info->read_only_memory_region_ops != nullptr)
|
||||
TF_RETURN_IF_ERROR(CheckABI(info->read_only_memory_region_ops_abi,
|
||||
TF_READ_ONLY_MEMORY_REGION_OPS_ABI,
|
||||
"read only memory region"));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Checks if the plugin and core API numbers match, logging mismatches.
|
||||
static void CheckAPI(int plugin_API, int core_API, StringPiece where) {
|
||||
if (plugin_API != core_API) {
|
||||
VLOG(0) << "Plugin API (" << plugin_API << ") for " << where
|
||||
<< " operations doesn't match expected core API (" << core_API
|
||||
<< "). Plugin will be loaded but functionality might be missing.";
|
||||
}
|
||||
}
|
||||
|
||||
// Checks if the plugin and core API numbers match, for all operations.
|
||||
//
|
||||
// Uses the simpler `CheckAPIHelper(int, int, StringPiece)`.
|
||||
static void ValidateAPI(const TF_FilesystemPluginInfo* info) {
|
||||
CheckAPI(info->filesystem_ops_api, TF_FILESYSTEM_OPS_API, "filesystem");
|
||||
|
||||
if (info->random_access_file_ops != nullptr)
|
||||
CheckAPI(info->random_access_file_ops_api, TF_RANDOM_ACCESS_FILE_OPS_API,
|
||||
"random access file");
|
||||
|
||||
if (info->writable_file_ops != nullptr)
|
||||
CheckAPI(info->writable_file_ops_api, TF_WRITABLE_FILE_OPS_API,
|
||||
"writable file");
|
||||
|
||||
if (info->read_only_memory_region_ops != nullptr)
|
||||
CheckAPI(info->read_only_memory_region_ops_api,
|
||||
TF_READ_ONLY_MEMORY_REGION_OPS_API, "read only memory region");
|
||||
}
|
||||
|
||||
// Validates the filesystem operations supplied by the plugin.
|
||||
static Status ValidateHelper(const TF_FilesystemOps* ops) {
|
||||
if (ops == nullptr)
|
||||
return errors::FailedPrecondition(
|
||||
"Trying to register filesystem without operations");
|
||||
|
||||
if (ops->init == nullptr)
|
||||
return errors::FailedPrecondition(
|
||||
"Trying to register filesystem without `init` operation");
|
||||
|
||||
if (ops->cleanup == nullptr)
|
||||
return errors::FailedPrecondition(
|
||||
"Trying to register filesystem without `cleanup` operation");
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Validates the random access file operations supplied by the plugin.
|
||||
static Status ValidateHelper(const TF_RandomAccessFileOps* ops) {
|
||||
if (ops == nullptr) {
|
||||
// We allow filesystems where files can only be written to (from TF code)
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (ops->cleanup == nullptr)
|
||||
return errors::FailedPrecondition(
|
||||
"Trying to register filesystem without `cleanup` operation on random "
|
||||
"access files");
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Validates the writable file operations supplied by the plugin.
|
||||
static Status ValidateHelper(const TF_WritableFileOps* ops) {
|
||||
if (ops == nullptr) {
|
||||
// We allow read-only filesystems
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (ops->cleanup == nullptr)
|
||||
return errors::FailedPrecondition(
|
||||
"Trying to register filesystem without `cleanup` operation on writable "
|
||||
"files");
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Validates the read only memory region operations given by the plugin.
|
||||
static Status ValidateHelper(const TF_ReadOnlyMemoryRegionOps* ops) {
|
||||
if (ops == nullptr) {
|
||||
// read only memory region support is always optional
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (ops->cleanup == nullptr)
|
||||
return errors::FailedPrecondition(
|
||||
"Trying to register filesystem without `cleanup` operation on read "
|
||||
"only memory regions");
|
||||
|
||||
if (ops->data == nullptr)
|
||||
return errors::FailedPrecondition(
|
||||
"Trying to register filesystem without `data` operation on read only "
|
||||
"memory regions");
|
||||
|
||||
if (ops->length == nullptr)
|
||||
return errors::FailedPrecondition(
|
||||
"Trying to register filesystem without `length` operation on read only "
|
||||
"memory regions");
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Validates the operations supplied by the plugin.
|
||||
//
|
||||
// Uses the 4 simpler `ValidateHelper(const TF_...*)` to validate each
|
||||
// individual function table and then checks that the function table for a
|
||||
// specific file type exists if the plugin offers support for creating that
|
||||
// type of files.
|
||||
static Status ValidateOperations(const TF_FilesystemPluginInfo* info) {
|
||||
TF_RETURN_IF_ERROR(ValidateHelper(info->filesystem_ops));
|
||||
TF_RETURN_IF_ERROR(ValidateHelper(info->random_access_file_ops));
|
||||
TF_RETURN_IF_ERROR(ValidateHelper(info->writable_file_ops));
|
||||
TF_RETURN_IF_ERROR(ValidateHelper(info->read_only_memory_region_ops));
|
||||
|
||||
if (info->filesystem_ops->new_random_access_file != nullptr &&
|
||||
info->random_access_file_ops == nullptr)
|
||||
return errors::FailedPrecondition(
|
||||
"Filesystem allows creation of random access files but no "
|
||||
"operations on them have been supplied.");
|
||||
|
||||
if ((info->filesystem_ops->new_writable_file != nullptr ||
|
||||
info->filesystem_ops->new_appendable_file != nullptr) &&
|
||||
info->writable_file_ops == nullptr)
|
||||
return errors::FailedPrecondition(
|
||||
"Filesystem allows creation of writable files but no "
|
||||
"operations on them have been supplied.");
|
||||
|
||||
if (info->filesystem_ops->new_read_only_memory_region_from_file != nullptr &&
|
||||
info->read_only_memory_region_ops == nullptr)
|
||||
return errors::FailedPrecondition(
|
||||
"Filesystem allows creation of readonly memory regions but no "
|
||||
"operations on them have been supplied.");
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Copies a function table from plugin memory space to core memory space.
|
||||
//
|
||||
// This has three benefits:
|
||||
// * allows having newer plugins than the current core TensorFlow: the
|
||||
// additional entries in the plugin's table are just discarded;
|
||||
// * allows having older plugins than the current core TensorFlow (though
|
||||
// we are still warning users): the entries that core TensorFlow expects
|
||||
// but plugins didn't provide will be set to `nullptr` values and core
|
||||
// TensorFlow will know to not call these on behalf of users;
|
||||
// * increased security as plugins will not be able to alter function table
|
||||
// after loading up. Thus, malicious plugins can't alter functionality to
|
||||
// probe for gadgets inside core TensorFlow. We can even protect the area
|
||||
// of memory where the copies reside to not allow any more writes to it
|
||||
// after all copies are created.
|
||||
template <typename T>
|
||||
static std::unique_ptr<const T> CopyToCore(const T* plugin_ops,
|
||||
size_t plugin_size) {
|
||||
if (plugin_ops == nullptr) return nullptr;
|
||||
|
||||
size_t copy_size = std::min(plugin_size, sizeof(T));
|
||||
auto core_ops = tensorflow::MakeUnique<T>();
|
||||
memset(core_ops.get(), 0, sizeof(T));
|
||||
memcpy(core_ops.get(), plugin_ops, copy_size);
|
||||
return core_ops;
|
||||
}
|
||||
|
||||
// Registers one filesystem from the plugin.
|
||||
static Status RegisterFileSystem(const TF_FilesystemPluginInfo* info) {
|
||||
// Step 1: Copy all the function tables to core TensorFlow memory space
|
||||
auto core_filesystem_ops = CopyToCore<TF_FilesystemOps>(
|
||||
info->filesystem_ops, info->filesystem_ops_size);
|
||||
auto core_random_access_file_ops = CopyToCore<TF_RandomAccessFileOps>(
|
||||
info->random_access_file_ops, info->random_access_file_ops_size);
|
||||
auto core_writable_file_ops = CopyToCore<TF_WritableFileOps>(
|
||||
info->writable_file_ops, info->writable_file_ops_size);
|
||||
auto core_read_only_memory_region_ops =
|
||||
CopyToCore<TF_ReadOnlyMemoryRegionOps>(
|
||||
info->read_only_memory_region_ops,
|
||||
info->read_only_memory_region_ops_size);
|
||||
|
||||
// Step 2: Initialize the opaque filesystem structure
|
||||
auto filesystem = tensorflow::MakeUnique<TF_Filesystem>();
|
||||
TF_Status* c_status = TF_NewStatus();
|
||||
Status status = Status::OK();
|
||||
core_filesystem_ops->init(filesystem.get(), c_status);
|
||||
status = Status(c_status->status);
|
||||
TF_DeleteStatus(c_status);
|
||||
if (!status.ok()) return status;
|
||||
|
||||
// Step 3: Actual registration
|
||||
return Env::Default()->RegisterFileSystem(
|
||||
info->scheme, tensorflow::MakeUnique<tensorflow::ModularFileSystem>(
|
||||
std::move(filesystem), std::move(core_filesystem_ops),
|
||||
std::move(core_random_access_file_ops),
|
||||
std::move(core_writable_file_ops),
|
||||
std::move(core_read_only_memory_region_ops)));
|
||||
}
|
||||
|
||||
// Registers all filesystems, if plugin is providing valid information.
|
||||
//
|
||||
// Extracted to a separate function so that pointers inside `info` are freed
|
||||
// by the caller regardless of whether validation/registration failed or not.
|
||||
static Status ValidateAndRegisterFilesystems(
|
||||
const TF_FilesystemPluginInfo* info) {
|
||||
TF_RETURN_IF_ERROR(ValidateScheme(info->scheme));
|
||||
TF_RETURN_IF_ERROR(ValidateABI(info));
|
||||
ValidateAPI(info); // we just warn on API number mismatch
|
||||
TF_RETURN_IF_ERROR(ValidateOperations(info));
|
||||
TF_RETURN_IF_ERROR(RegisterFileSystem(info));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Allocates memory in plugin DSO.
|
||||
//
|
||||
// Provided by core TensorFlow so that it can free this memory after DSO is
|
||||
// loaded and filesystem information has been used to register the filesystem.
|
||||
static void* basic_allocator(size_t size) {
  // `calloc` hands back zero-filled memory, so every field of whatever the
  // plugin stores here starts out as 0 / `nullptr`. The block is released
  // with `free`.
  void* const zeroed_block = calloc(1, size);
  return zeroed_block;
}
|
||||
|
||||
namespace filesystem_registration {
|
||||
|
||||
// Loads the plugin DSO at `dso_path` and registers every filesystem it offers.
//
// The DSO must export `TF_InitPlugin`. It is called with `basic_allocator`
// and must fill in an array of `TF_FilesystemPluginInfo`, returning the number
// of entries. Each entry is validated and registered independently, so one bad
// entry does not prevent the others from being registered.
//
// Returns:
//   * the error from loading the library or looking up the symbol, if any;
//   * `InvalidArgument` if the plugin reports a negative count or no array;
//   * otherwise the accumulated status of all per-entry registrations (via
//     `Status::Update`).
//
// All memory handed over by the plugin in `info` is released with `free`
// before returning.
Status RegisterFilesystemPluginImpl(const std::string& dso_path) {
  // Step 1: Load plugin
  Env* env = Env::Default();
  void* dso_handle = nullptr;
  TF_RETURN_IF_ERROR(env->LoadLibrary(dso_path.c_str(), &dso_handle));

  // Step 2: Load symbol for `TF_InitPlugin`
  void* dso_symbol = nullptr;
  TF_RETURN_IF_ERROR(
      env->GetSymbolFromLibrary(dso_handle, "TF_InitPlugin", &dso_symbol));

  // Step 3: Call `TF_InitPlugin`
  TF_FilesystemPluginInfo* info = nullptr;
  auto TF_InitPlugin = reinterpret_cast<int (*)(
      decltype(&basic_allocator), TF_FilesystemPluginInfo**)>(dso_symbol);
  int num_schemes = TF_InitPlugin(&basic_allocator, &info);
  if (num_schemes < 0 || info == nullptr) {
    // The plugin may have allocated the array and still reported failure.
    // The number of valid entries is unknown here, so only the array itself
    // can be released; `free(nullptr)` is a no-op.
    free(info);
    return errors::InvalidArgument("DSO returned invalid filesystem data");
  }

  // Step 4: Validate and register all filesystems
  // Try to register as many filesystems as possible.
  // Free memory once we no longer need it
  Status status;
  for (int i = 0; i < num_schemes; i++) {
    status.Update(ValidateAndRegisterFilesystems(&info[i]));
    // The function tables were copied during registration, so the plugin's
    // originals can be released whether or not registration succeeded.
    free(info[i].scheme);
    free(info[i].filesystem_ops);
    free(info[i].random_access_file_ops);
    free(info[i].writable_file_ops);
    free(info[i].read_only_memory_region_ops);
  }
  free(info);
  return status;
}
|
||||
|
||||
} // namespace filesystem_registration
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,28 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_
|
||||
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace filesystem_registration {
|
||||
|
||||
Status RegisterFilesystemPluginImpl(const std::string& dso_path);
|
||||
|
||||
} // namespace filesystem_registration
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_
|
File diff suppressed because it is too large
Load Diff
@ -1,35 +1,47 @@
|
||||
# Experimental posix filesystem plugin.
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:private"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
# Although this target results in a shared object that will be loaded at
|
||||
# runtime, this target must be a `cc_library` instead of a `cc_binary`. Making
|
||||
# it a `cc_binary` requires `linkshared = True`. In turn, this brings in several
|
||||
# TensorFlow symbols under `tensorflow::` namespace, for which we have no ABI
|
||||
# guarantees. Hence, in order to maintain ABI compatibility, this is marked as a
|
||||
# `cc_library` for now and we will revisit in the future.
|
||||
# TODO(mihaimaruseac): Determine if `cc_binary` makes more sense (when all
|
||||
# filesystems are converted and BUILD files are refactored to be modular).
|
||||
# TODO(b/144585140): The helpers should be separated into a different BUILD target
|
||||
# but doing that would result in symbols not being visible when loading plugin.
|
||||
# Revisit this once POSIX filesystem completely lands. See also the other TODO.
|
||||
# This also has the unfortunate effect that both versions of copy_file get
|
||||
# compiled, regardless of which one actually gets used!
|
||||
# Filesystem implementation for POSIX environments: Linux, MacOS, Android, etc.
|
||||
tf_cc_shared_object(
|
||||
name = "libposix_filesystem.so",
|
||||
framework_so = [],
|
||||
linkstatic = False,
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":posix_filesystem_impl"],
|
||||
)
|
||||
|
||||
# The real implementation of the filesystem.
|
||||
cc_library(
|
||||
name = "posix_filesystem",
|
||||
srcs = [
|
||||
"posix_filesystem.cc",
|
||||
"posix_filesystem_helper.cc",
|
||||
"posix_filesystem_helper.h",
|
||||
"copy_file.h",
|
||||
] + select({
|
||||
"//tensorflow:linux_x86_64": ["copy_file_linux.cc"],
|
||||
"//conditions:default": ["copy_file_portable.cc"],
|
||||
}),
|
||||
name = "posix_filesystem_impl",
|
||||
srcs = ["posix_filesystem.cc"],
|
||||
deps = [
|
||||
":posix_filesystem_helper",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c/experimental/filesystem:filesystem_interface",
|
||||
],
|
||||
)
|
||||
|
||||
# Library implementing helper functionality, so that the above only contains
|
||||
# the API implementation for modular filesystems.
|
||||
cc_library(
|
||||
name = "posix_filesystem_helper",
|
||||
srcs = ["posix_filesystem_helper.cc"],
|
||||
hdrs = ["posix_filesystem_helper.h"],
|
||||
deps = [":copy_file"],
|
||||
)
|
||||
|
||||
# On Linux, we can copy files faster using `sendfile`. But not elsewhere.
|
||||
# Hence, this private library to select which implementation to use.
|
||||
cc_library(
|
||||
name = "copy_file",
|
||||
srcs = select({
|
||||
"//tensorflow:linux_x86_64": ["copy_file_linux.cc"],
|
||||
"//conditions:default": ["copy_file_portable.cc"],
|
||||
}),
|
||||
hdrs = ["copy_file.h"],
|
||||
)
|
||||
|
@ -24,8 +24,6 @@ limitations under the License.
|
||||
#include <sys/stat.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
||||
#include "tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem_helper.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
@ -396,48 +394,65 @@ static int GetChildren(const TF_Filesystem* filesystem, const char* path,
|
||||
|
||||
} // namespace tf_posix_filesystem
|
||||
|
||||
void TF_InitPlugin(TF_Status* status) {
|
||||
TF_RandomAccessFileOps random_access_file_ops = {
|
||||
tf_random_access_file::Cleanup,
|
||||
tf_random_access_file::Read,
|
||||
};
|
||||
TF_WritableFileOps writable_file_ops = {
|
||||
tf_writable_file::Cleanup, tf_writable_file::Append,
|
||||
tf_writable_file::Tell, tf_writable_file::Flush,
|
||||
tf_writable_file::Sync, tf_writable_file::Close,
|
||||
};
|
||||
TF_ReadOnlyMemoryRegionOps read_only_memory_region_ops = {
|
||||
tf_read_only_memory_region::Cleanup,
|
||||
tf_read_only_memory_region::Data,
|
||||
tf_read_only_memory_region::Length,
|
||||
};
|
||||
TF_FilesystemOps filesystem_ops = {
|
||||
tf_posix_filesystem::Init,
|
||||
tf_posix_filesystem::Cleanup,
|
||||
tf_posix_filesystem::NewRandomAccessFile,
|
||||
tf_posix_filesystem::NewWritableFile,
|
||||
tf_posix_filesystem::NewAppendableFile,
|
||||
tf_posix_filesystem::NewReadOnlyMemoryRegionFromFile,
|
||||
tf_posix_filesystem::CreateDir,
|
||||
/*recursively_create_dir=*/nullptr,
|
||||
tf_posix_filesystem::DeleteFile,
|
||||
tf_posix_filesystem::DeleteDir,
|
||||
/*delete_recursively=*/nullptr,
|
||||
tf_posix_filesystem::RenameFile,
|
||||
tf_posix_filesystem::CopyFile,
|
||||
tf_posix_filesystem::PathExists,
|
||||
/*paths_exist=*/nullptr,
|
||||
tf_posix_filesystem::Stat,
|
||||
/*is_directory=*/nullptr,
|
||||
/*get_file_size=*/nullptr,
|
||||
/*translate_name=*/nullptr,
|
||||
tf_posix_filesystem::GetChildren,
|
||||
/*get_matching_paths=*/nullptr,
|
||||
/*flush_caches=*/nullptr,
|
||||
};
|
||||
int TF_InitPlugin(void* (*allocator)(size_t), TF_FilesystemPluginInfo** info) {
|
||||
const int num_schemes = 2;
|
||||
*info = static_cast<TF_FilesystemPluginInfo*>(
|
||||
allocator(num_schemes * sizeof((*info)[0])));
|
||||
|
||||
for (const char* scheme : {"", "file"})
|
||||
TF_REGISTER_FILESYSTEM_PLUGIN(scheme, &filesystem_ops,
|
||||
&random_access_file_ops, &writable_file_ops,
|
||||
&read_only_memory_region_ops, status);
|
||||
for (int i = 0; i < num_schemes; i++) {
|
||||
TF_FilesystemPluginInfo* current_info = &((*info)[i]);
|
||||
TF_SetFilesystemVersionMetadata(current_info);
|
||||
|
||||
current_info->random_access_file_ops = static_cast<TF_RandomAccessFileOps*>(
|
||||
allocator(TF_RANDOM_ACCESS_FILE_OPS_SIZE));
|
||||
current_info->random_access_file_ops->cleanup =
|
||||
tf_random_access_file::Cleanup;
|
||||
current_info->random_access_file_ops->read = tf_random_access_file::Read;
|
||||
|
||||
current_info->writable_file_ops =
|
||||
static_cast<TF_WritableFileOps*>(allocator(TF_WRITABLE_FILE_OPS_SIZE));
|
||||
current_info->writable_file_ops->cleanup = tf_writable_file::Cleanup;
|
||||
current_info->writable_file_ops->append = tf_writable_file::Append;
|
||||
current_info->writable_file_ops->tell = tf_writable_file::Tell;
|
||||
current_info->writable_file_ops->flush = tf_writable_file::Flush;
|
||||
current_info->writable_file_ops->sync = tf_writable_file::Sync;
|
||||
current_info->writable_file_ops->close = tf_writable_file::Close;
|
||||
|
||||
current_info->read_only_memory_region_ops =
|
||||
static_cast<TF_ReadOnlyMemoryRegionOps*>(
|
||||
allocator(TF_READ_ONLY_MEMORY_REGION_OPS_SIZE));
|
||||
current_info->read_only_memory_region_ops->cleanup =
|
||||
tf_read_only_memory_region::Cleanup;
|
||||
current_info->read_only_memory_region_ops->data =
|
||||
tf_read_only_memory_region::Data;
|
||||
current_info->read_only_memory_region_ops->length =
|
||||
tf_read_only_memory_region::Length;
|
||||
|
||||
current_info->filesystem_ops =
|
||||
static_cast<TF_FilesystemOps*>(allocator(TF_FILESYSTEM_OPS_SIZE));
|
||||
current_info->filesystem_ops->init = tf_posix_filesystem::Init;
|
||||
current_info->filesystem_ops->cleanup = tf_posix_filesystem::Cleanup;
|
||||
current_info->filesystem_ops->new_random_access_file =
|
||||
tf_posix_filesystem::NewRandomAccessFile;
|
||||
current_info->filesystem_ops->new_writable_file =
|
||||
tf_posix_filesystem::NewWritableFile;
|
||||
current_info->filesystem_ops->new_appendable_file =
|
||||
tf_posix_filesystem::NewAppendableFile;
|
||||
current_info->filesystem_ops->new_read_only_memory_region_from_file =
|
||||
tf_posix_filesystem::NewReadOnlyMemoryRegionFromFile;
|
||||
current_info->filesystem_ops->create_dir = tf_posix_filesystem::CreateDir;
|
||||
current_info->filesystem_ops->delete_file = tf_posix_filesystem::DeleteFile;
|
||||
current_info->filesystem_ops->delete_dir = tf_posix_filesystem::DeleteDir;
|
||||
current_info->filesystem_ops->rename_file = tf_posix_filesystem::RenameFile;
|
||||
current_info->filesystem_ops->copy_file = tf_posix_filesystem::CopyFile;
|
||||
current_info->filesystem_ops->path_exists = tf_posix_filesystem::PathExists;
|
||||
current_info->filesystem_ops->stat = tf_posix_filesystem::Stat;
|
||||
current_info->filesystem_ops->get_children =
|
||||
tf_posix_filesystem::GetChildren;
|
||||
}
|
||||
|
||||
(*info)[0].scheme = strdup("");
|
||||
(*info)[1].scheme = strdup("file");
|
||||
|
||||
return num_schemes;
|
||||
}
|
||||
|
@ -44,7 +44,7 @@ int TransferFileContents(const char* src, const char* dst, mode_t mode,
|
||||
}
|
||||
|
||||
// Both files have been opened, do the transfer.
|
||||
// Since errno would be overriden by `close` below, save it here.
|
||||
// Since errno would be overridden by `close` below, save it here.
|
||||
int error_code = 0;
|
||||
if (CopyFileContents(dst_fd, src_fd, size) < 0) error_code = errno;
|
||||
|
||||
|
36
tensorflow/c/experimental/filesystem/plugins/windows/BUILD
Normal file
36
tensorflow/c/experimental/filesystem/plugins/windows/BUILD
Normal file
@ -0,0 +1,36 @@
|
||||
# Experimental windows filesystem plugin.
|
||||
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object")
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
# Filesystem implementation for Windows environment
|
||||
tf_cc_shared_object(
|
||||
name = "windows_filesystem.dll",
|
||||
framework_so = [],
|
||||
linkstatic = False,
|
||||
tags = [
|
||||
"manual",
|
||||
"nobuilder",
|
||||
"notap",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":windows_filesystem_impl"],
|
||||
)
|
||||
|
||||
# The real implementation of the filesystem.
|
||||
cc_library(
|
||||
name = "windows_filesystem_impl",
|
||||
srcs = ["windows_filesystem.cc"],
|
||||
copts = get_win_copts(),
|
||||
tags = [
|
||||
"manual",
|
||||
"nobuilder",
|
||||
"notap",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c/experimental/filesystem:filesystem_interface",
|
||||
],
|
||||
)
|
@ -0,0 +1,70 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
// Implementation of a filesystem for Windows environments.
|
||||
// This filesystem will support `file://` and empty (local) URI schemes.
|
||||
|
||||
// SECTION 1. Implementation for `TF_RandomAccessFile`
|
||||
// ----------------------------------------------------------------------------
|
||||
namespace tf_random_access_file {
|
||||
|
||||
// TODO(mihaimaruseac): Implement later
|
||||
|
||||
} // namespace tf_random_access_file
|
||||
|
||||
// SECTION 2. Implementation for `TF_WritableFile`
|
||||
// ----------------------------------------------------------------------------
|
||||
namespace tf_writable_file {
|
||||
|
||||
// TODO(mihaimaruseac): Implement later
|
||||
|
||||
} // namespace tf_writable_file
|
||||
|
||||
// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion`
|
||||
// ----------------------------------------------------------------------------
|
||||
namespace tf_read_only_memory_region {
|
||||
|
||||
// TODO(mihaimaruseac): Implement later
|
||||
|
||||
} // namespace tf_read_only_memory_region
|
||||
|
||||
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
|
||||
// ----------------------------------------------------------------------------
|
||||
namespace tf_windows_filesystem {
|
||||
|
||||
// TODO(mihaimaruseac): Implement later
|
||||
|
||||
} // namespace tf_windows_filesystem
|
||||
|
||||
int TF_InitPlugin(void* (*allocator)(size_t), TF_FilesystemPluginInfo** info) {
|
||||
const int num_schemes = 2;
|
||||
*info = static_cast<TF_FilesystemPluginInfo*>(
|
||||
allocator(num_schemes * sizeof((*info)[0])));
|
||||
|
||||
for (int i = 0; i < num_schemes; i++) {
|
||||
TF_FilesystemPluginInfo* current_info = &((*info)[i]);
|
||||
TF_SetFilesystemVersionMetadata(current_info);
|
||||
}
|
||||
|
||||
(*info)[0].scheme = strdup("");
|
||||
(*info)[1].scheme = strdup("file");
|
||||
|
||||
return num_schemes;
|
||||
}
|
@ -181,7 +181,8 @@ void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor,
|
||||
return;
|
||||
}
|
||||
const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i));
|
||||
TF_Tensor* result = ::tensorflow::TF_TensorFromTensor(cc_tensor, status);
|
||||
TF_Tensor* result =
|
||||
::tensorflow::TF_TensorFromTensor(cc_tensor, &status->status);
|
||||
if (TF_GetCode(status) == TF_OK) {
|
||||
*tensor = result;
|
||||
}
|
||||
|
@ -133,7 +133,7 @@ TEST(OpsTest, TestShapeInference_VectorizeFunction) {
|
||||
|
||||
TEST(OpsTest, AttributeAccessors) {
|
||||
TF_OpDefinitionBuilder* builder =
|
||||
TF_NewOpDefinitionBuilder("AttributeAccesorsOp");
|
||||
TF_NewOpDefinitionBuilder("AttributeAccessorsOp");
|
||||
TF_OpDefinitionBuilderAddAttr(builder, "foo1: int >= 2");
|
||||
TF_OpDefinitionBuilderAddAttr(builder, "foo2: string=\"my string\"");
|
||||
TF_OpDefinitionBuilderSetIsCommutative(builder, true);
|
||||
@ -151,7 +151,7 @@ TEST(OpsTest, AttributeAccessors) {
|
||||
op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length);
|
||||
bool found = false;
|
||||
for (const auto& op : op_list.op()) {
|
||||
if (op.name() == "AttributeAccesorsOp") {
|
||||
if (op.name() == "AttributeAccessorsOp") {
|
||||
ASSERT_TRUE(op.is_commutative());
|
||||
ASSERT_TRUE(op.is_aggregate());
|
||||
ASSERT_TRUE(op.allows_uninitialized_input());
|
||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
@ -103,49 +105,35 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
|
||||
buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
|
||||
}
|
||||
|
||||
TF_Tensor* ret =
|
||||
new TF_Tensor{Tensor(static_cast<tensorflow::DataType>(dtype),
|
||||
tensorflow::TensorShape(dimvec), buf)};
|
||||
// TODO(gjn): Make the choice of interface a compile-time configuration.
|
||||
tensorflow::TensorInterface ret(
|
||||
Tensor(static_cast<tensorflow::DataType>(dtype),
|
||||
tensorflow::TensorShape(dimvec), buf));
|
||||
buf->Unref();
|
||||
size_t elem_size = TF_DataTypeSize(dtype);
|
||||
if (elem_size > 0 && len < (elem_size * ret->tensor.NumElements())) {
|
||||
delete ret;
|
||||
if (elem_size > 0 && len < (elem_size * ret.NumElements())) {
|
||||
return nullptr;
|
||||
}
|
||||
return ret;
|
||||
return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(ret)};
|
||||
}
|
||||
|
||||
TF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor) {
|
||||
// It is safe to move the Tensor if and only if we own the unique reference to
|
||||
// it. In that case, we might as well not delete and reallocate, but a future
|
||||
// implementation might need to do so.
|
||||
TensorBuffer* buf = tensorflow::TensorCApi::Buffer(tensor->tensor);
|
||||
if (buf->RefCountIsOne() && buf->root_buffer()->RefCountIsOne() &&
|
||||
buf->OwnsMemory()) {
|
||||
return tensor;
|
||||
}
|
||||
return nullptr;
|
||||
TF_Tensor* TF_TensorMaybeMove(TF_Tensor* t) {
|
||||
return t->tensor->CanMove() ? t : nullptr;
|
||||
}
|
||||
|
||||
void TF_DeleteTensor(TF_Tensor* t) { delete t; }
|
||||
|
||||
TF_DataType TF_TensorType(const TF_Tensor* t) {
|
||||
return static_cast<TF_DataType>(t->tensor.dtype());
|
||||
}
|
||||
TF_DataType TF_TensorType(const TF_Tensor* t) { return t->tensor->Type(); }
|
||||
|
||||
int TF_NumDims(const TF_Tensor* t) { return t->tensor.dims(); }
|
||||
int TF_NumDims(const TF_Tensor* t) { return t->tensor->NumDims(); }
|
||||
|
||||
int64_t TF_Dim(const TF_Tensor* t, int dim_index) {
|
||||
return static_cast<int64_t>(t->tensor.dim_size(dim_index));
|
||||
return t->tensor->Dim(dim_index);
|
||||
}
|
||||
|
||||
size_t TF_TensorByteSize(const TF_Tensor* t) {
|
||||
return tensorflow::TensorCApi::Buffer(t->tensor)->size();
|
||||
}
|
||||
size_t TF_TensorByteSize(const TF_Tensor* t) { return t->tensor->ByteSize(); }
|
||||
|
||||
void* TF_TensorData(const TF_Tensor* t) {
|
||||
return tensorflow::TensorCApi::Buffer(t->tensor)->data();
|
||||
}
|
||||
void* TF_TensorData(const TF_Tensor* t) { return t->tensor->Data(); }
|
||||
|
||||
int64_t TF_TensorElementCount(const TF_Tensor* t) {
|
||||
int64_t result = 1;
|
||||
@ -160,16 +148,69 @@ void TF_TensorBitcastFrom(const TF_Tensor* from, TF_DataType type,
|
||||
TF_Tensor* to, const int64_t* new_dims,
|
||||
int num_new_dims, TF_Status* status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
Status cc_status(
|
||||
static_cast<tensorflow::TensorInterface*>(to->tensor.get())
|
||||
->BitcastFrom(*static_cast<const tensorflow::TensorInterface*>(
|
||||
from->tensor.get()),
|
||||
type, new_dims, num_new_dims));
|
||||
Set_TF_Status_from_Status(status, cc_status);
|
||||
}
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
bool TensorInterface::CanMove() const {
|
||||
// It is safe to move the Tensor if and only if we own the unique reference to
|
||||
// it. In that case, we might as well not delete and reallocate, but a future
|
||||
// implementation might need to do so.
|
||||
TensorBuffer* buf = tensorflow::TensorCApi::Buffer(tensor_);
|
||||
if (buf->RefCountIsOne() && buf->root_buffer()->RefCountIsOne() &&
|
||||
buf->OwnsMemory()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Returns the wrapped tensor's dtype converted to the C API enum.
TF_DataType TensorInterface::Type() const {
  const auto cc_dtype = tensor_.dtype();
  return static_cast<TF_DataType>(cc_dtype);
}
|
||||
|
||||
// Returns the number of dimensions reported by the wrapped tensor.
int TensorInterface::NumDims() const {
  const int rank = tensor_.dims();
  return rank;
}
|
||||
|
||||
int64_t TensorInterface::Dim(int dim_index) const {
|
||||
return static_cast<int64_t>(tensor_.dim_size(dim_index));
|
||||
}
|
||||
|
||||
int64_t TensorInterface::NumElements() const {
|
||||
return static_cast<int64_t>(tensor_.NumElements());
|
||||
}
|
||||
|
||||
size_t TensorInterface::ByteSize() const {
|
||||
return tensorflow::TensorCApi::Buffer(tensor_)->size();
|
||||
}
|
||||
|
||||
void* TensorInterface::Data() const {
|
||||
return tensorflow::TensorCApi::Buffer(tensor_)->data();
|
||||
}
|
||||
|
||||
Status TensorInterface::BitcastFrom(const TensorInterface& from,
|
||||
TF_DataType type, const int64_t* new_dims,
|
||||
int num_new_dims) {
|
||||
tensorflow::TensorShape s;
|
||||
for (int i = 0; i < num_new_dims; ++i) {
|
||||
s.AddDim(new_dims[i]);
|
||||
}
|
||||
Status cc_status(to->tensor.BitcastFrom(
|
||||
from->tensor, static_cast<tensorflow::DataType>(type), s));
|
||||
Set_TF_Status_from_Status(status, cc_status);
|
||||
return tensor_.BitcastFrom(from.tensor_,
|
||||
static_cast<tensorflow::DataType>(type), s);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Encodes one string into `dst`: the length `src_len` is written as a varint,
// and the `src_len` bytes of `src` are copied to the position returned by
// `EncodeVarint64` (presumably the first byte after the varint).
//
// No bounds checking is done here; the callers in this file size the
// destination with `TF_StringEncodedSize(src_len)`.
void StringEncode(const char* src, size_t src_len, char* dst) {
  char* payload = tensorflow::core::EncodeVarint64(dst, src_len);
  memcpy(payload, src, src_len);
}
|
||||
|
||||
size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
|
||||
size_t dst_len, TF_Status* status) {
|
||||
const size_t sz = TF_StringEncodedSize(src_len);
|
||||
@ -185,8 +226,7 @@ size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
|
||||
src_len, "-byte string"));
|
||||
return 0;
|
||||
}
|
||||
dst = tensorflow::core::EncodeVarint64(dst, src_len);
|
||||
memcpy(dst, src, src_len);
|
||||
StringEncode(src, src_len, dst);
|
||||
return sz;
|
||||
}
|
||||
|
||||
@ -245,13 +285,11 @@ static TF_Tensor* EmptyTensor(TF_DataType dtype,
|
||||
namespace tensorflow {
|
||||
|
||||
// Non-static for testing.
|
||||
TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
|
||||
TF_Status* status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status) {
|
||||
*status = tensorflow::Status::OK();
|
||||
if (!src.IsInitialized()) {
|
||||
Set_TF_Status_from_Status(
|
||||
status, FailedPrecondition(
|
||||
"attempt to use a tensor with an uninitialized value"));
|
||||
*status = FailedPrecondition(
|
||||
"attempt to use a tensor with an uninitialized value");
|
||||
return nullptr;
|
||||
}
|
||||
if (src.NumElements() == 0) {
|
||||
@ -259,14 +297,13 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
|
||||
}
|
||||
if (src.dtype() == tensorflow::DT_RESOURCE) {
|
||||
if (src.shape().dims() != 0) {
|
||||
Set_TF_Status_from_Status(
|
||||
status, InvalidArgument(
|
||||
"Unexpected non-scalar DT_RESOURCE tensor seen (shape: ",
|
||||
src.shape().DebugString(),
|
||||
"). Please file a bug at "
|
||||
"https://github.com/tensorflow/tensorflow/issues/new, "
|
||||
"ideally with a "
|
||||
"short code snippet that reproduces this error."));
|
||||
*status = InvalidArgument(
|
||||
"Unexpected non-scalar DT_RESOURCE tensor seen (shape: ",
|
||||
src.shape().DebugString(),
|
||||
"). Please file a bug at "
|
||||
"https://github.com/tensorflow/tensorflow/issues/new, "
|
||||
"ideally with a "
|
||||
"short code snippet that reproduces this error.");
|
||||
return nullptr;
|
||||
}
|
||||
const string str =
|
||||
@ -276,12 +313,11 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
|
||||
return t;
|
||||
}
|
||||
if (src.dtype() != tensorflow::DT_STRING) {
|
||||
auto* result = new TF_Tensor();
|
||||
if (!result->tensor.CopyFrom(src, src.shape())) {
|
||||
delete result;
|
||||
Tensor tensor;
|
||||
if (!tensor.CopyFrom(src, src.shape())) {
|
||||
return nullptr;
|
||||
}
|
||||
return result;
|
||||
return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(tensor)};
|
||||
}
|
||||
// DT_STRING tensors require copying since TF_Tensor.buffer expects a flatly
|
||||
// encoded sequence of strings.
|
||||
@ -305,23 +341,15 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
|
||||
*offsets = (dst - data_start);
|
||||
offsets++;
|
||||
const string& s = srcarray(i);
|
||||
size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status);
|
||||
if (TF_GetCode(status) != TF_OK) {
|
||||
Set_TF_Status_from_Status(
|
||||
status,
|
||||
InvalidArgument("invalid string tensor encoding (string #", i, " of ",
|
||||
srcarray.size(), "): ", TF_Message(status)));
|
||||
delete[] base;
|
||||
return nullptr;
|
||||
}
|
||||
const size_t consumed = TF_StringEncodedSize(s.size());
|
||||
StringEncode(s.data(), s.size(), dst);
|
||||
dst += consumed;
|
||||
dst_len -= consumed;
|
||||
}
|
||||
if (dst != base + size) {
|
||||
Set_TF_Status_from_Status(
|
||||
status, InvalidArgument(
|
||||
"invalid string tensor encoding (decoded ", (dst - base),
|
||||
" bytes, but the tensor is encoded in ", size, " bytes"));
|
||||
*status = InvalidArgument(
|
||||
"invalid string tensor encoding (decoded ", (dst - base),
|
||||
" bytes, but the tensor is encoded in ", size, " bytes");
|
||||
delete[] base;
|
||||
return nullptr;
|
||||
}
|
||||
@ -339,31 +367,35 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
|
||||
}
|
||||
|
||||
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
|
||||
if (src->tensor.dtype() == DT_RESOURCE) {
|
||||
if (src->tensor.dims() != 0) {
|
||||
return static_cast<const tensorflow::TensorInterface*>(src->tensor.get())
|
||||
->ToTensor(dst);
|
||||
}
|
||||
|
||||
Status TensorInterface::ToTensor(Tensor* dst) const {
|
||||
if (tensor_.dtype() == DT_RESOURCE) {
|
||||
if (tensor_.dims() != 0) {
|
||||
return InvalidArgument(
|
||||
"Malformed TF_RESOURCE tensor: expected a scalar, got a tensor with "
|
||||
"shape ",
|
||||
src->tensor.shape().DebugString());
|
||||
tensor_.shape().DebugString());
|
||||
}
|
||||
*dst = Tensor(tensorflow::DT_RESOURCE, src->tensor.shape());
|
||||
*dst = Tensor(tensorflow::DT_RESOURCE, tensor_.shape());
|
||||
if (!dst->scalar<tensorflow::ResourceHandle>()().ParseFromString(
|
||||
string(static_cast<const char*>(TF_TensorData(src)),
|
||||
TF_TensorByteSize(src)))) {
|
||||
string(static_cast<const char*>(Data()), ByteSize()))) {
|
||||
return InvalidArgument(
|
||||
"Malformed TF_RESOUCE tensor: unable to parse resource handle");
|
||||
"Malformed TF_RESOURCE tensor: unable to parse resource handle");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
if (src->tensor.dtype() != DT_STRING) {
|
||||
*dst = src->tensor;
|
||||
if (tensor_.dtype() != DT_STRING) {
|
||||
*dst = tensor_;
|
||||
return Status::OK();
|
||||
}
|
||||
// TF_STRING tensors require copying since Tensor class expects a sequence of
|
||||
// string objects.
|
||||
const tensorflow::int64 num_elements = src->tensor.NumElements();
|
||||
const char* input = reinterpret_cast<const char*>(TF_TensorData(src));
|
||||
const size_t src_size = TF_TensorByteSize(src);
|
||||
const tensorflow::int64 num_elements = tensor_.NumElements();
|
||||
const char* input = reinterpret_cast<const char*>(Data());
|
||||
const size_t src_size = ByteSize();
|
||||
if (static_cast<tensorflow::int64>(src_size / sizeof(tensorflow::uint64)) <
|
||||
num_elements) {
|
||||
return InvalidArgument(
|
||||
@ -372,7 +404,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
|
||||
const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
|
||||
const char* limit = input + src_size;
|
||||
|
||||
*dst = Tensor(src->tensor.dtype(), src->tensor.shape());
|
||||
*dst = Tensor(tensor_.dtype(), tensor_.shape());
|
||||
auto dstarray = dst->flat<tstring>();
|
||||
for (tensorflow::int64 i = 0; i < num_elements; ++i) {
|
||||
tensorflow::uint64 offset =
|
||||
@ -391,8 +423,8 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool TensorInterface::IsAligned() const { return tensor_.IsAligned(); }
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
bool TF_TensorIsAligned(const TF_Tensor* tensor) {
|
||||
return tensor->tensor.IsAligned();
|
||||
}
|
||||
bool TF_TensorIsAligned(const TF_Tensor* t) { return t->tensor->IsAligned(); }
|
||||
|
@ -16,9 +16,12 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
|
||||
#define TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_interface.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
|
||||
// Internal structures used by the C API. These are likely to change and should
|
||||
@ -28,7 +31,7 @@ limitations under the License.
|
||||
// passed to or returned from C functions *by pointer*. Otherwise, changes to
|
||||
// its internal structure will break the C API's binary interface.
|
||||
typedef struct TF_Tensor {
|
||||
::tensorflow::Tensor tensor;
|
||||
std::unique_ptr<AbstractTensorInterface> tensor;
|
||||
} TF_Tensor;
|
||||
|
||||
class TF_ManagedBuffer : public tensorflow::TensorBuffer {
|
||||
@ -83,4 +86,5 @@ void* allocate_tensor(const char* operation, size_t len, Allocator* allocator);
|
||||
// a different Allocator as `arg`.
|
||||
void deallocate_buffer(void* data, size_t len, void* arg);
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
|
||||
|
@ -96,7 +96,7 @@ class SymbolicGradientBuilder {
|
||||
// Used to identify nodes at which to stop backprop.
|
||||
std::unordered_set<int> GetStopBackpropNodes(
|
||||
const std::vector<bool>& reachable_nodes,
|
||||
const std::unordered_set<int>& output_nodes);
|
||||
const std::unordered_set<int>& output_nodes) const;
|
||||
|
||||
const Scope& scope_;
|
||||
const ops::GradOpRegistry* registry_;
|
||||
@ -190,7 +190,7 @@ std::vector<bool> SymbolicGradientBuilder::GetReachableNodes() {
|
||||
|
||||
std::unordered_set<int> SymbolicGradientBuilder::GetStopBackpropNodes(
|
||||
const std::vector<bool>& reachable_nodes,
|
||||
const std::unordered_set<int>& output_nodes) {
|
||||
const std::unordered_set<int>& output_nodes) const {
|
||||
// Output nodes that get transitively consumed by other `outputs_` are stored
|
||||
// in `internal_outputs`.
|
||||
std::unordered_set<int> internal_outputs;
|
||||
@ -346,8 +346,8 @@ Status SymbolicGradientBuilder::SumGradients(const Output& src, Output* grad) {
|
||||
"Unable to find backprop list for node.id ", src.node()->name());
|
||||
}
|
||||
const auto& grads = iter->second;
|
||||
// Filter any backproped 'NoGradient' Outputs from 'grads' (if needed).
|
||||
// Return any valid backproped gradients that remain after filtering,
|
||||
// Filter any backpropped 'NoGradient' Outputs from 'grads' (if needed).
|
||||
// Return any valid backpropped gradients that remain after filtering,
|
||||
// or 'NoGradient' otherwise.
|
||||
std::vector<Output> grads_to_keep;
|
||||
for (const Output& o : grads) {
|
||||
@ -519,7 +519,7 @@ Status SymbolicGradientBuilder::AddGradients() {
|
||||
// Backprop along the in edges.
|
||||
// TODO(andydavis) Find cleaner way to map each grad output returned by
|
||||
// gradient function to the src node/output to which it should be
|
||||
// backproped. Maybe grad functions can return a vector of Output pairs to
|
||||
// backpropped. Maybe grad functions can return a vector of Output pairs to
|
||||
// make this association explicit.
|
||||
size_t dx_index = 0;
|
||||
for (const Edge* e : n->in_edges()) {
|
||||
|
@ -64,7 +64,7 @@ bool IsZero(const Scope& scope, const Output& grad) {
|
||||
// Multiply after broadcasting vec to match dimensions of mat.
|
||||
// Args:
|
||||
// vec: A 1-D tensor of dimension [D0]
|
||||
// mat: A 2-D tensor of dimesnion [D0, D1]
|
||||
// mat: A 2-D tensor of dimension [D0, D1]
|
||||
//
|
||||
// Returns:
|
||||
// A tensor of dimension [D0, D1], the result of vec * mat.
|
||||
|
@ -259,6 +259,9 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) {
|
||||
RunTest(x, x_init_value, y, y_shape);
|
||||
}
|
||||
|
||||
// TODO(rocm):
|
||||
// Re-enable this test once 3D pooling is supported on ROCm platform
|
||||
#ifndef TENSORFLOW_USE_ROCM
|
||||
TEST_F(NNGradTest, MaxPool3DGradHelper) {
|
||||
TensorShape x_shape({1, 3, 3, 3, 1});
|
||||
TensorShape y_shape({1, 1, 1, 1, 1});
|
||||
@ -271,6 +274,7 @@ TEST_F(NNGradTest, MaxPool3DGradHelper) {
|
||||
SetRandomValuesForMaxPooling<float>(&x_init_value);
|
||||
RunTest(x, x_init_value, y, y_shape);
|
||||
}
|
||||
#endif
|
||||
|
||||
TEST_F(NNGradTest, AvgPoolGradHelper) {
|
||||
TensorShape x_shape({1, 2, 2, 1});
|
||||
@ -283,6 +287,9 @@ TEST_F(NNGradTest, AvgPoolGradHelper) {
|
||||
RunTest(x, x_shape, y, y_shape);
|
||||
}
|
||||
|
||||
// TODO(rocm):
|
||||
// Re-enable this test once 3D pooling is supported on ROCm platform
|
||||
#ifndef TENSORFLOW_USE_ROCM
|
||||
TEST_F(NNGradTest, AvgPool3DGradHelper) {
|
||||
TensorShape x_shape({1, 3, 3, 3, 1});
|
||||
TensorShape y_shape({1, 1, 1, 1, 1});
|
||||
@ -293,6 +300,7 @@ TEST_F(NNGradTest, AvgPool3DGradHelper) {
|
||||
auto y = AvgPool3D(scope_, x, ksize, strides, "SAME");
|
||||
RunTest(x, x_shape, y, y_shape);
|
||||
}
|
||||
#endif
|
||||
|
||||
TEST_F(NNGradTest, LRN) {
|
||||
TensorShape x_shape({1, 1, 2, 1});
|
||||
|
@ -124,13 +124,12 @@ cc_library(
|
||||
hdrs = ["bundle_v2.h"],
|
||||
deps = [
|
||||
":constants",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
] + if_not_mobile([
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:strcat",
|
||||
"//tensorflow/core/util/tensor_bundle",
|
||||
]),
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
|
@ -1,5 +1,6 @@
|
||||
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
|
||||
load("//tensorflow/core/platform:build_config.bzl", "if_llvm_aarch64_available")
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:private"],
|
||||
@ -27,9 +28,14 @@ cc_library(
|
||||
"compile.h",
|
||||
"flags.h",
|
||||
],
|
||||
defines = if_llvm_aarch64_available(["TF_LLVM_AARCH64_AVAILABLE=1"]),
|
||||
visibility = ["//tensorflow/python:__pkg__"],
|
||||
deps = [
|
||||
":aot_only_var_handle_op",
|
||||
":embedded_protocol_buffers",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"//tensorflow/compiler/tf2xla",
|
||||
"//tensorflow/compiler/tf2xla:mlir_tf2xla",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
|
||||
@ -53,10 +59,13 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:target",
|
||||
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
|
||||
] + if_llvm_aarch64_available([
|
||||
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
|
||||
]),
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
@ -86,6 +95,19 @@ tf_cc_binary(
|
||||
deps = [":tfcompile_main"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "llvm_targets",
|
||||
visibility = ["//tensorflow/python:__pkg__"],
|
||||
deps = [
|
||||
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:target",
|
||||
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
|
||||
] + if_llvm_aarch64_available([
|
||||
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
|
||||
]),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfcompile_main",
|
||||
srcs = ["tfcompile_main.cc"],
|
||||
@ -104,11 +126,6 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:aarch64_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:target",
|
||||
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
|
||||
],
|
||||
)
|
||||
|
||||
@ -214,8 +231,13 @@ cc_library(
|
||||
cc_library(
|
||||
name = "aot_only_var_handle_op",
|
||||
srcs = ["aot_only_var_handle_op.cc"],
|
||||
hdrs = ["aot_only_var_handle_op.h"],
|
||||
visibility = [
|
||||
"//tensorflow/compiler/tf2xla:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -13,9 +13,12 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/aot/aot_only_var_handle_op.h"
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/xla_context.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
@ -51,6 +54,31 @@ void XlaAotOnlyVarHandleOp::Compile(XlaOpKernelContext* context) {
|
||||
}
|
||||
} // namespace
|
||||
|
||||
REGISTER_XLA_OP(Name("VarHandleOp").CompilationOnly(), XlaAotOnlyVarHandleOp);
|
||||
REGISTER_OP(tfcompile::kXlaAotOnlyVarHandleOp)
|
||||
.Doc(R"doc(
|
||||
Internal VarHandleOp registration used for XLA AOT compilation.
|
||||
)doc")
|
||||
.Attr("container: string = ''")
|
||||
.Attr("shared_name: string = ''")
|
||||
.Attr("dtype: type")
|
||||
.Attr("shape: shape")
|
||||
.Output("resource: resource")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
c->set_output(0, c->Scalar());
|
||||
DataType t;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("dtype", &t));
|
||||
PartialTensorShape p;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("shape", &p));
|
||||
shape_inference::ShapeHandle s;
|
||||
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s));
|
||||
c->set_output_handle_shapes_and_types(
|
||||
0, std::vector<shape_inference::ShapeAndType>{{s, t}});
|
||||
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
REGISTER_XLA_OP(Name(tfcompile::kXlaAotOnlyVarHandleOp).CompilationOnly(),
|
||||
XlaAotOnlyVarHandleOp);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
27
tensorflow/compiler/aot/aot_only_var_handle_op.h
Normal file
27
tensorflow/compiler/aot/aot_only_var_handle_op.h
Normal file
@ -0,0 +1,27 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_
|
||||
#define TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tfcompile {
|
||||
|
||||
// Op name under which the AOT-only variant of VarHandleOp is registered
// (passed to both REGISTER_OP and REGISTER_XLA_OP in
// aot_only_var_handle_op.cc).
static constexpr const char* const kXlaAotOnlyVarHandleOp =
|
||||
"_XlaAotOnlyVarHandleOp";
|
||||
|
||||
} // namespace tfcompile
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_
|
@ -74,16 +74,16 @@ void DumpStatsToStdout(const Stats& stats) {
|
||||
const int kBufSize = 1000;
|
||||
char buf[kBufSize];
|
||||
snprintf(buf, kBufSize, "Mean with %2.0f%% trimmed:", trim_ratio * 100);
|
||||
const string label_trimmed(buf);
|
||||
std::string label_trimmed(buf);
|
||||
snprintf(buf, kBufSize, "Mean of %2.0f%% best:", best_ratio * 100);
|
||||
const string label_best(buf);
|
||||
std::vector<std::pair<string, double>> groups = {
|
||||
std::string label_best(buf);
|
||||
std::vector<std::pair<std::string, double>> groups = {
|
||||
{"Best:", sorted_us.front()},
|
||||
{"Worst:", sorted_us.back()},
|
||||
{"Median:", sorted_us[count_us / 2]},
|
||||
{"Mean:", sum_us / count_us},
|
||||
{label_trimmed, sum_us_trimmed / count_us_trimmed},
|
||||
{label_best, sum_us_best / count_us_best},
|
||||
{std::move(label_trimmed), sum_us_trimmed / count_us_trimmed},
|
||||
{std::move(label_best), sum_us_best / count_us_best},
|
||||
};
|
||||
int max_label_size = 0;
|
||||
double max_us = 0;
|
||||
@ -102,7 +102,7 @@ void DumpStatsToStdout(const Stats& stats) {
|
||||
}
|
||||
// Dump stats out.
|
||||
printf("Benchmark ran %zu iterations over %lld us\n", count_us,
|
||||
stats.total_us);
|
||||
static_cast<long long>(stats.total_us)); // NOLINT
|
||||
for (const auto& g : groups) {
|
||||
printf(" %-*s %*.3f us\n", max_label_size, g.first.c_str(), max_digits + 4,
|
||||
g.second);
|
||||
@ -114,7 +114,8 @@ void Benchmark(const Options& options, const BenchmarkFn& fn, Stats* stats) {
|
||||
const int64 max_us = (options.max_micros <= 0 && options.max_iters <= 0)
|
||||
? Options::kDefaultMicros
|
||||
: options.max_micros;
|
||||
printf("Running benchmark for %lld us\n", max_us);
|
||||
// NOLINTNEXTLINE
|
||||
printf("Running benchmark for %lld us\n", static_cast<long long>(max_us));
|
||||
const int64 start_us = NowMicros();
|
||||
int64 iters = 0;
|
||||
while (true) {
|
||||
|
@ -423,8 +423,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
|
||||
GenNameToIndexCode(config.fetch(), opts.gen_name_to_index);
|
||||
const string include_xla_data_proto =
|
||||
opts.gen_program_shape
|
||||
?
|
||||
R"(#include "tensorflow/compiler/xla/xla_data.pb.h")"
|
||||
? R"(#include "tensorflow/compiler/xla/xla_data.pb.h")"
|
||||
: "";
|
||||
|
||||
const string include_hlo_profile_printer_data_proto =
|
||||
|
@ -20,6 +20,8 @@ limitations under the License.
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "llvm-c/Target.h"
|
||||
#include "tensorflow/compiler/aot/codegen.h"
|
||||
#include "tensorflow/compiler/aot/flags.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||
@ -90,7 +92,7 @@ Status CompileXla(xla::CompileOnlyClient* client,
|
||||
|
||||
} // namespace
|
||||
|
||||
Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
|
||||
Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
|
||||
const MainFlags& flags, CompileResult* compile_result) {
|
||||
// Converts the graph into an XLA computation, and compiles the
|
||||
// computation.
|
||||
@ -108,8 +110,8 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
|
||||
if (!flags.mlir_components.empty()) {
|
||||
return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
ConvertGraphDefToXla(graph_def, config, client, &computation));
|
||||
TF_RETURN_IF_ERROR(ConvertGraphDefToXla(std::move(graph_def), config,
|
||||
client, &computation));
|
||||
}
|
||||
if (!flags.out_session_module.empty()) {
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
|
||||
@ -132,5 +134,96 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
|
||||
return CompileXla(client, computation, aot_opts, compile_result);
|
||||
}
|
||||
|
||||
// Parses the file `fname` into `proto`. A name ending in ".pbtxt" selects the
// text-format reader; any other name is read as a binary proto.
static Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
  Env* const env = Env::Default();
  const bool is_text_format = absl::EndsWith(fname, ".pbtxt");
  if (!is_text_format) {
    return ReadBinaryProto(env, fname, proto);
  }
  return ReadTextProto(env, fname, proto);
}
|
||||
|
||||
static std::once_flag targets_init;
|
||||
|
||||
// Registers the LLVM backends (target, target info, MC layer and asm printer
// for each) so that code can be generated for a target other than the host,
// i.e. so we can cross compile.
static void InitializeTargets() {
  // AArch64: only when the build makes this backend available.
#if TF_LLVM_AARCH64_AVAILABLE
  LLVMInitializeAArch64Target();
  LLVMInitializeAArch64TargetInfo();
  LLVMInitializeAArch64TargetMC();
  LLVMInitializeAArch64AsmPrinter();
#endif

  // ARM.
  LLVMInitializeARMTarget();
  LLVMInitializeARMTargetInfo();
  LLVMInitializeARMTargetMC();
  LLVMInitializeARMAsmPrinter();

  // PowerPC.
  LLVMInitializePowerPCTarget();
  LLVMInitializePowerPCTargetInfo();
  LLVMInitializePowerPCTargetMC();
  LLVMInitializePowerPCAsmPrinter();

  // X86.
  LLVMInitializeX86Target();
  LLVMInitializeX86TargetInfo();
  LLVMInitializeX86TargetMC();
  LLVMInitializeX86AsmPrinter();
}
|
||||
|
||||
// The full compilation method, for reuse in a library setting.
//
// Reads the tf2xla config named by `flags.config` and the graph named by
// `flags.graph`, compiles the graph, and writes three outputs: the function
// object file (`flags.out_function_object`), the metadata object file
// (`flags.out_metadata_object`) and the generated header (`flags.out_header`).
// When `flags.dump_fetch_nodes` is set, it only prints the comma-separated
// fetch node names to stdout and returns OK without compiling anything.
//
// Returns InvalidArgument when `flags.config`, `flags.graph` or
// `flags.cpp_class` is empty; otherwise the first error produced by a read,
// validation, compile, codegen or write step.
Status Main(const MainFlags& flags) {
  // LLVM target registration only needs to happen once per process.
  std::call_once(targets_init, &InitializeTargets);

  // Process config.
  if (flags.config.empty()) {
    return errors::InvalidArgument("Must specify --config");
  }
  tf2xla::Config config;
  TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
  TF_RETURN_IF_ERROR(ValidateConfig(config));
  if (flags.dump_fetch_nodes) {
    // std::set de-duplicates and orders the node names before printing.
    std::set<string> nodes;
    for (const tf2xla::Fetch& fetch : config.fetch()) {
      nodes.insert(fetch.id().node_name());
    }
    std::cout << absl::StrJoin(nodes, ",");
    return Status::OK();
  }

  // Validate the remaining required flags before doing any expensive work, so
  // that a missing or malformed --cpp_class cannot leave a freshly written
  // object file behind without its matching metadata and header.
  if (flags.graph.empty()) {
    return errors::InvalidArgument("Must specify --graph");
  }
  if (flags.cpp_class.empty()) {
    return errors::InvalidArgument("Must specify --cpp_class");
  }
  CodegenOpts codegen_opts;
  codegen_opts.gen_name_to_index = flags.gen_name_to_index;
  codegen_opts.gen_program_shape = flags.gen_program_shape;
  codegen_opts.target_triple = flags.target_triple;
  codegen_opts.gen_hlo_profile_printer_data =
      xla::GetDebugOptionsFromFlags().xla_hlo_profile();
  TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
                                   &codegen_opts.namespaces));

  // Read and initialize the graph.
  GraphDef graph_def;
  TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
  CompileResult compile_result;
  // CompileGraph takes the GraphDef by value; it is not used again here, so
  // move it instead of copying.
  TF_RETURN_IF_ERROR(
      CompileGraph(std::move(graph_def), config, flags, &compile_result));

  // Write output files.
  Env* env = Env::Default();
  const std::vector<char>& obj = compile_result.aot->object_file_data();
  TF_RETURN_IF_ERROR(
      WriteStringToFile(env, flags.out_function_object,
                        absl::string_view(obj.data(), obj.size())));

  MetadataResult metadata_result;
  TF_RETURN_IF_ERROR(
      GenerateMetadata(codegen_opts, compile_result, &metadata_result));
  TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_metadata_object,
                                       metadata_result.object_file_data));
  string header;
  TF_RETURN_IF_ERROR(GenerateHeader(codegen_opts, config, compile_result,
                                    metadata_result, &header));
  TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header));
  return Status::OK();
}
|
||||
|
||||
} // namespace tfcompile
|
||||
} // namespace tensorflow
|
||||
|
@ -42,9 +42,12 @@ struct CompileResult {
|
||||
// that performs the graph operations.
|
||||
//
|
||||
// The XLA compilation options are specified in the flags.
|
||||
Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
|
||||
Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
|
||||
const MainFlags& flags, CompileResult* compile_result);
|
||||
|
||||
// The full compilation method, for reuse in a library setting.
|
||||
Status Main(const MainFlags& flags);
|
||||
|
||||
} // namespace tfcompile
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -25,6 +25,7 @@ namespace tensorflow {
|
||||
namespace tfcompile {
|
||||
|
||||
// Flags for the tfcompile binary. See *.cc file for descriptions.
|
||||
|
||||
struct MainFlags {
|
||||
string graph;
|
||||
string config;
|
||||
|
@ -25,6 +25,7 @@ test_suite(
|
||||
":test_graph_tfmatmulandadd_test",
|
||||
":test_graph_tfsplits_test",
|
||||
":test_graph_tftop_k_test",
|
||||
":test_graph_tfvariable_readonly_test",
|
||||
":test_graph_tfvariable_sequential_updates_test",
|
||||
":test_graph_tfvariable_test",
|
||||
":tfcompile_test",
|
||||
@ -73,6 +74,7 @@ genrule(
|
||||
"test_graph_tfsplits.pb",
|
||||
"test_graph_tftop_k.pb",
|
||||
"test_graph_tfvariable.pb",
|
||||
"test_graph_tfvariable_readonly.pb",
|
||||
"test_graph_tfvariable_sequential_updates.pb",
|
||||
],
|
||||
# Set CUDA_VISIBLE_DEVICES='' to prevent the code we launch from using any
|
||||
@ -238,6 +240,17 @@ tf_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_library(
|
||||
name = "test_graph_tfvariable_readonly",
|
||||
testonly = 1,
|
||||
config = "test_graph_tfvariable_readonly.config.pbtxt",
|
||||
cpp_class = "VariableReadonlyComp",
|
||||
graph = "test_graph_tfvariable_readonly.pb",
|
||||
tags = [
|
||||
"manual",
|
||||
],
|
||||
)
|
||||
|
||||
tf_library(
|
||||
name = "test_graph_tfvariable_sequential_updates",
|
||||
testonly = 1,
|
||||
@ -269,6 +282,7 @@ tf_cc_test(
|
||||
":test_graph_tfsplits",
|
||||
":test_graph_tftop_k",
|
||||
":test_graph_tfvariable",
|
||||
":test_graph_tfvariable_readonly",
|
||||
":test_graph_tfvariable_sequential_updates",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
@ -323,6 +337,42 @@ tf_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_library(
|
||||
name = "test_graph_tfcond_mlir_bridge",
|
||||
testonly = 1,
|
||||
config = "test_graph_tfcond.config.pbtxt",
|
||||
cpp_class = "CondComp",
|
||||
graph = "test_graph_tfcond.pb",
|
||||
mlir_components = "Bridge",
|
||||
tags = [
|
||||
"manual",
|
||||
],
|
||||
)
|
||||
|
||||
tf_library(
|
||||
name = "test_graph_tfassert_eq_mlir_bridge",
|
||||
testonly = 1,
|
||||
config = "test_graph_tfassert_eq.config.pbtxt",
|
||||
cpp_class = "AssertComp",
|
||||
graph = "test_graph_tfassert_eq.pb",
|
||||
mlir_components = "Bridge",
|
||||
tags = [
|
||||
"manual",
|
||||
],
|
||||
)
|
||||
|
||||
tf_library(
|
||||
name = "test_graph_tfgather_mlir_bridge",
|
||||
testonly = 1,
|
||||
config = "test_graph_tfgather.config.pbtxt",
|
||||
cpp_class = "GatherComp",
|
||||
graph = "test_graph_tfgather.pb",
|
||||
mlir_components = "Bridge",
|
||||
tags = [
|
||||
"manual",
|
||||
],
|
||||
)
|
||||
|
||||
tf_library(
|
||||
name = "test_graph_tfmatmul_mlir_bridge",
|
||||
testonly = 1,
|
||||
@ -361,6 +411,42 @@ tf_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_library(
|
||||
name = "test_graph_tfsplits_mlir_bridge",
|
||||
testonly = 1,
|
||||
config = "test_graph_tfsplits.config.pbtxt",
|
||||
cpp_class = "SplitsComp",
|
||||
graph = "test_graph_tfsplits.pb",
|
||||
mlir_components = "Bridge",
|
||||
tags = [
|
||||
"manual",
|
||||
],
|
||||
)
|
||||
|
||||
tf_library(
|
||||
name = "test_graph_tftop_k_mlir_bridge",
|
||||
testonly = 1,
|
||||
config = "test_graph_tftop_k.config.pbtxt",
|
||||
cpp_class = "TopKComp",
|
||||
graph = "test_graph_tftop_k.pb",
|
||||
mlir_components = "Bridge",
|
||||
tags = [
|
||||
"manual",
|
||||
],
|
||||
)
|
||||
|
||||
tf_library(
|
||||
name = "test_graph_tfvariable_readonly_mlir_bridge",
|
||||
testonly = 1,
|
||||
config = "test_graph_tfvariable_readonly.config.pbtxt",
|
||||
cpp_class = "VariableReadonlyComp",
|
||||
graph = "test_graph_tfvariable_readonly.pb",
|
||||
mlir_components = "Bridge",
|
||||
tags = [
|
||||
"manual",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "tfcompile_test_mlir_bridge",
|
||||
srcs = ["tfcompile_test.cc"],
|
||||
@ -372,9 +458,15 @@ tf_cc_test(
|
||||
":test_graph_tfadd_mlir_bridge",
|
||||
":test_graph_tfadd_with_ckpt_mlir_bridge",
|
||||
":test_graph_tfadd_with_ckpt_saver_mlir_bridge",
|
||||
":test_graph_tfassert_eq_mlir_bridge",
|
||||
":test_graph_tfcond_mlir_bridge",
|
||||
":test_graph_tfgather_mlir_bridge",
|
||||
":test_graph_tfmatmul_mlir_bridge",
|
||||
":test_graph_tfmatmulandadd_mlir_bridge",
|
||||
":test_graph_tfmatmulandadd_with_profiling_mlir_bridge",
|
||||
":test_graph_tfsplits_mlir_bridge",
|
||||
":test_graph_tftop_k_mlir_bridge",
|
||||
":test_graph_tfvariable_readonly_mlir_bridge",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
|
@ -34,6 +34,7 @@ from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import variables
|
||||
@ -153,6 +154,14 @@ def tftop_k(_):
|
||||
array_ops.identity(output[1], name='indices')
|
||||
|
||||
|
||||
def tfvariable_readonly(_):
|
||||
x = variables.Variable(1000.0, name='x')
|
||||
old_x = x.value()
|
||||
with ops.control_dependencies([old_x]):
|
||||
new_value = math_ops.add(old_x, 42.0)
|
||||
array_ops.identity(new_value, name='result')
|
||||
|
||||
|
||||
def tfvariable(_):
|
||||
x = variables.Variable(1000.0, name='x')
|
||||
old_x = x.value()
|
||||
@ -184,6 +193,7 @@ def write_graph(build_graph, out_dir):
|
||||
|
||||
|
||||
def main(_):
|
||||
control_flow_util.enable_control_flow_v2()
|
||||
write_graph(tfadd, FLAGS.out_dir)
|
||||
write_graph(tfadd_with_ckpt, FLAGS.out_dir)
|
||||
write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir)
|
||||
@ -196,6 +206,7 @@ def main(_):
|
||||
write_graph(tfsplits, FLAGS.out_dir)
|
||||
write_graph(tftop_k, FLAGS.out_dir)
|
||||
write_graph(tfvariable, FLAGS.out_dir)
|
||||
write_graph(tfvariable_readonly, FLAGS.out_dir)
|
||||
write_graph(tfvariable_sequential_updates, FLAGS.out_dir)
|
||||
|
||||
|
||||
|
@ -0,0 +1,12 @@
|
||||
# Text form of tensorflow.tf2xla.Config proto.
|
||||
fetch {
|
||||
id { node_name: "result" }
|
||||
}
|
||||
|
||||
variable {
|
||||
node_name: "x"
|
||||
shape {
|
||||
}
|
||||
type: DT_FLOAT
|
||||
readonly: true
|
||||
}
|
@ -30,9 +30,15 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_mlir_bridge.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_mlir_bridge.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver_mlir_bridge.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq_mlir_bridge.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfcond_mlir_bridge.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfgather_mlir_bridge.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul_mlir_bridge.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_mlir_bridge.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling_mlir_bridge.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits_mlir_bridge.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tftop_k_mlir_bridge.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly_mlir_bridge.h"
|
||||
#else
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
|
||||
@ -47,6 +53,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tftop_k.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates.h"
|
||||
#endif
|
||||
|
||||
@ -167,8 +174,6 @@ TEST(TFCompileTest, AddWithCkptSaver) {
|
||||
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
|
||||
}
|
||||
|
||||
// TODO(bixia): the following tests failed with MLIR bridge.
|
||||
#if !defined(ENABLE_MLIR_BRIDGE_TEST)
|
||||
TEST(TFCompileTest, Cond) {
|
||||
CondComp cond;
|
||||
EXPECT_EQ(cond.arg0_data(), cond.arg_data(0));
|
||||
@ -233,7 +238,6 @@ TEST(TFCompileTest, Gather) {
|
||||
EXPECT_EQ(gather_const.result0_data(), gather.results()[0]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
TEST(TFCompileTest, MatMul2) {
|
||||
Eigen::ThreadPool tp(2);
|
||||
@ -439,6 +443,7 @@ TEST(TFCompileTest, Function) {
|
||||
EXPECT_EQ(add_fn.result0_data()[0], 3);
|
||||
EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]);
|
||||
}
|
||||
#endif
|
||||
|
||||
TEST(TFCompileTest, Splits) {
|
||||
Eigen::ThreadPool tp(1);
|
||||
@ -492,6 +497,22 @@ TEST(TFCompileTest, TopK) {
|
||||
EXPECT_EQ(expected_indices[1], fn.result1(1));
|
||||
}
|
||||
|
||||
// Runs the compiled readonly-variable graph with x = 23 and expects the
// result to be x + 42 while the variable itself is left at 23.
TEST(TFCompileTest, VariableReadonly) {
  Eigen::ThreadPool pool(1);
  Eigen::ThreadPoolDevice pool_device(&pool, pool.NumThreads());

  // The computation reads the variable through the pointer handed to it.
  float x = 23;
  VariableReadonlyComp computation;
  computation.set_var_x_data(&x);
  computation.set_thread_pool(&pool_device);

  computation.Run();
  EXPECT_EQ(computation.result0(), 65);
  EXPECT_EQ(computation.var_x(), 23);
}
|
||||
|
||||
// TODO(bixia): the following tests failed with MLIR bridge.
|
||||
#if !defined(ENABLE_MLIR_BRIDGE_TEST)
|
||||
TEST(TFCompileTest, Variable) {
|
||||
Eigen::ThreadPool tp(1);
|
||||
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
|
||||
@ -564,6 +585,7 @@ TEST(TFCompileTest, VariableSequentialUpdatesNoAlloc) {
|
||||
fn.Run();
|
||||
EXPECT_NEAR(x, 0.594322f, 1e-6);
|
||||
}
|
||||
#endif
|
||||
|
||||
TEST(TFCompileTest, AssertEqAndReturnDiff) {
|
||||
// Assert is converted into a no-op in XLA, so there is no failure even if the
|
||||
@ -665,6 +687,11 @@ TEST(TFCompileTest, HloProfiling) {
|
||||
/*clock_rate_ghz=*/1.0);
|
||||
VLOG(1) << "Original HLO profile string:\n" << hlo_profile_as_string;
|
||||
|
||||
// Replace Arg_n with argn when the MLIR bridge is used.
|
||||
#if defined(ENABLE_MLIR_BRIDGE_TEST)
|
||||
RE2::GlobalReplace(&hlo_profile_as_string, "(Arg_)([0-9].)", "arg\\2");
|
||||
#endif
|
||||
|
||||
// Strip away identifier details from the profile string to avoid this test
|
||||
// being a change detector for xla internals. Identifiers such as '%dot.0.7'
|
||||
// just become '%dot'.
|
||||
@ -690,7 +717,6 @@ TEST(TFCompileTest, HloProfiling) {
|
||||
IsSupersetOf({header, total_cycles_profile_line, dot_profile_line,
|
||||
add_profile_line, tuple_profile_line}));
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace
|
||||
} // namespace tfcompile
|
||||
|
@ -21,7 +21,6 @@ limitations under the License.
|
||||
#include "absl/strings/match.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "llvm-c/Target.h"
|
||||
#include "tensorflow/compiler/aot/codegen.h"
|
||||
#include "tensorflow/compiler/aot/compile.h"
|
||||
#include "tensorflow/compiler/aot/flags.h"
|
||||
@ -56,88 +55,6 @@ const char kUsageHeader[] =
|
||||
"--cpp_class=\"mynamespace::MyComputation\"\n"
|
||||
"\n";
|
||||
|
||||
Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
|
||||
if (absl::EndsWith(fname, ".pbtxt")) {
|
||||
return ReadTextProto(Env::Default(), fname, proto);
|
||||
} else {
|
||||
return ReadBinaryProto(Env::Default(), fname, proto);
|
||||
}
|
||||
}
|
||||
|
||||
Status Main(const MainFlags& flags) {
|
||||
// Initialize all LLVM targets so we can cross compile.
|
||||
LLVMInitializeAArch64Target();
|
||||
LLVMInitializeAArch64TargetInfo();
|
||||
LLVMInitializeAArch64TargetMC();
|
||||
LLVMInitializeAArch64AsmPrinter();
|
||||
LLVMInitializeARMTarget();
|
||||
LLVMInitializeARMTargetInfo();
|
||||
LLVMInitializeARMTargetMC();
|
||||
LLVMInitializeARMAsmPrinter();
|
||||
LLVMInitializePowerPCTarget();
|
||||
LLVMInitializePowerPCTargetInfo();
|
||||
LLVMInitializePowerPCTargetMC();
|
||||
LLVMInitializePowerPCAsmPrinter();
|
||||
LLVMInitializeX86Target();
|
||||
LLVMInitializeX86TargetInfo();
|
||||
LLVMInitializeX86TargetMC();
|
||||
LLVMInitializeX86AsmPrinter();
|
||||
|
||||
// Process config.
|
||||
tf2xla::Config config;
|
||||
if (flags.config.empty()) {
|
||||
return errors::InvalidArgument("Must specify --config");
|
||||
}
|
||||
TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
|
||||
TF_RETURN_IF_ERROR(ValidateConfig(config));
|
||||
if (flags.dump_fetch_nodes) {
|
||||
std::set<string> nodes;
|
||||
for (const tf2xla::Fetch& fetch : config.fetch()) {
|
||||
nodes.insert(fetch.id().node_name());
|
||||
}
|
||||
std::cout << absl::StrJoin(nodes, ",");
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Read and initialize the graph.
|
||||
if (flags.graph.empty()) {
|
||||
return errors::InvalidArgument("Must specify --graph");
|
||||
}
|
||||
GraphDef graph_def;
|
||||
TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
|
||||
CompileResult compile_result;
|
||||
TF_RETURN_IF_ERROR(CompileGraph(graph_def, config, flags, &compile_result));
|
||||
|
||||
// Write output files.
|
||||
Env* env = Env::Default();
|
||||
const std::vector<char>& obj = compile_result.aot->object_file_data();
|
||||
TF_RETURN_IF_ERROR(
|
||||
WriteStringToFile(env, flags.out_function_object,
|
||||
absl::string_view(obj.data(), obj.size())));
|
||||
CodegenOpts codegen_opts;
|
||||
codegen_opts.gen_name_to_index = flags.gen_name_to_index;
|
||||
codegen_opts.gen_program_shape = flags.gen_program_shape;
|
||||
codegen_opts.target_triple = flags.target_triple;
|
||||
if (flags.cpp_class.empty()) {
|
||||
return errors::InvalidArgument("Must specify --cpp_class");
|
||||
}
|
||||
codegen_opts.gen_hlo_profile_printer_data =
|
||||
xla::GetDebugOptionsFromFlags().xla_hlo_profile();
|
||||
TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
|
||||
&codegen_opts.namespaces));
|
||||
|
||||
MetadataResult metadata_result;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GenerateMetadata(codegen_opts, compile_result, &metadata_result));
|
||||
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_metadata_object,
|
||||
metadata_result.object_file_data));
|
||||
string header;
|
||||
TF_RETURN_IF_ERROR(GenerateHeader(codegen_opts, config, compile_result,
|
||||
metadata_result, &header));
|
||||
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // end namespace tfcompile
|
||||
} // end namespace tensorflow
|
||||
|
||||
|
@ -4,12 +4,7 @@ load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library", "tf_jit_compilati
|
||||
load("//tensorflow/core/platform:build_config.bzl", "tf_additional_all_protos", "tf_proto_library")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
":internal",
|
||||
# BEGIN-GOOGLE-INTERNAL
|
||||
"//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__",
|
||||
# END-GOOGLE-INTERNAL
|
||||
],
|
||||
default_visibility = [":internal"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
@ -82,19 +77,6 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_mlir_gpu_jit",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = if_cuda_or_rocm([
|
||||
":jit_compilation_passes",
|
||||
"//tensorflow/compiler/jit/kernels:xla_ops",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
||||
"//tensorflow/compiler/xla/service:mlir_gpu_plugin",
|
||||
]),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_cpu_device",
|
||||
srcs = ["xla_cpu_device.cc"],
|
||||
@ -120,6 +102,7 @@ cc_library(
|
||||
srcs = ["xla_gpu_device.cc"],
|
||||
visibility = [":friends"],
|
||||
deps = [
|
||||
":flags",
|
||||
":jit_compilation_passes",
|
||||
":xla_device",
|
||||
":xla_kernel_creator", # buildcleaner: keep
|
||||
@ -128,6 +111,7 @@ cc_library(
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:gpu_init",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
|
@ -1584,7 +1584,6 @@ DeadnessAnalysis::~DeadnessAnalysis() {}
|
||||
absl::flat_hash_map<TensorId, string, TensorId::Hasher>
|
||||
DeadnessAnalysisImpl::PredicateMapAsString() const {
|
||||
absl::flat_hash_map<TensorId, string, TensorId::Hasher> result;
|
||||
std::vector<TensorId> tensor_ids;
|
||||
for (const auto& kv_pair : predicate_map_) {
|
||||
CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second);
|
||||
}
|
||||
|
@ -374,39 +374,6 @@ xla::StatusOr<NodeDef> BuildXlaHostComputeNodeDef(
|
||||
return new_def;
|
||||
}
|
||||
|
||||
TF_ATTRIBUTE_NOINLINE Status
|
||||
ValidateOutsideCompilationCallNode(Node* call_node) {
|
||||
// DT_INT64 as input/output for outside compilation is not supported yet:
|
||||
// b/120809951.
|
||||
for (const Edge* e : call_node->in_edges()) {
|
||||
if (e->IsControlEdge()) {
|
||||
continue;
|
||||
}
|
||||
DataType dtype = e->src()->output_type(e->src_output());
|
||||
if (dtype == DT_INT64) {
|
||||
return errors::Unimplemented(
|
||||
"int64 input for outside compilation is not supported yet: "
|
||||
"b/120809951. Please cast output of node ",
|
||||
e->src()->DebugString(),
|
||||
" to int32 before feeding it into outside compilation.");
|
||||
}
|
||||
}
|
||||
for (const Edge* e : call_node->out_edges()) {
|
||||
if (e->IsControlEdge()) {
|
||||
continue;
|
||||
}
|
||||
DataType dtype = e->dst()->input_type(e->dst_input());
|
||||
if (dtype == DT_INT64) {
|
||||
return errors::Unimplemented(
|
||||
"int64 output for outside compilation is not supported yet: "
|
||||
"b/120809951. Please cast input of node ",
|
||||
e->dst()->DebugString(),
|
||||
" to int32 before returning it from outside compilation.");
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Replace outside compilation function call node with XlaHostCompute node.
|
||||
TF_ATTRIBUTE_NOINLINE xla::StatusOr<Node*> ReplaceOutsideCompilationCallNode(
|
||||
Graph* g, Node* call_node, const std::map<string, int>& host_compute_core,
|
||||
@ -2384,7 +2351,6 @@ Status ExtractOutsideCompilationForFunction(
|
||||
}
|
||||
std::map<string, Node*> host_compute_nodes;
|
||||
for (Node* n : outside_compilation_nodes) {
|
||||
TF_RETURN_IF_ERROR(ValidateOutsideCompilationCallNode(n));
|
||||
auto host_compute_node_or = ReplaceOutsideCompilationCallNode(
|
||||
graph_out.get(), n, host_compute_core, *cluster_deps);
|
||||
TF_RETURN_IF_ERROR(host_compute_node_or.status());
|
||||
|
@ -155,6 +155,7 @@ void AllocateAndParseFlags() {
|
||||
|
||||
device_flags = new XlaDeviceFlags;
|
||||
device_flags->tf_xla_compile_on_demand = false;
|
||||
device_flags->tf_xla_enable_xla_devices = true;
|
||||
|
||||
ops_flags = new XlaOpsCommonFlags;
|
||||
ops_flags->tf_xla_always_defer_compilation = false;
|
||||
@ -187,6 +188,12 @@ void AllocateAndParseFlags() {
|
||||
"Switch a device into 'on-demand' mode, where instead of "
|
||||
"autoclustering ops are compiled one by one just-in-time."),
|
||||
|
||||
Flag("tf_xla_enable_xla_devices",
     &device_flags->tf_xla_enable_xla_devices,
     // Adjacent string literals are concatenated verbatim, so the space
     // between "device" and "forces" has to live inside a literal.
     "Generate XLA_* devices, where placing a computation on such a "
     "device forces compilation by XLA. Deprecated."),
|
||||
|
||||
Flag("tf_xla_always_defer_compilation",
|
||||
&ops_flags->tf_xla_always_defer_compilation, ""),
|
||||
|
||||
|
@ -87,6 +87,9 @@ struct XlaDeviceFlags {
|
||||
// Enabling this mode by a legacy flag is a temporary mechanism. When this
|
||||
// feature is battle-tested, we will switch this to be a session option.
|
||||
bool tf_xla_compile_on_demand;
|
||||
|
||||
// Enables "XLA" devices if this flag is set.
|
||||
bool tf_xla_enable_xla_devices;
|
||||
};
|
||||
|
||||
// Flags common to the _Xla* ops and their kernels.
|
||||
|
@ -1776,9 +1776,9 @@ absl::flat_hash_map<string, std::vector<string>>* GetWhitelistTable() {
|
||||
"Lgamma", "Digamma",
|
||||
// Binary
|
||||
"Add", "AddV2", "Sub", "Mul", "Div", "Atan2", "Complex", "DivNoNan",
|
||||
"MulNoNan", "FloorDiv", "Xlogy", "Xdivy", "FloorMod", "BitwiseAnd",
|
||||
"BitwiseOr", "BitwiseXor", "LeftShift", "RightShift", "LogicalAnd",
|
||||
"LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
|
||||
"MulNoNan", "FloorDiv", "Xlogy", "Xlog1py", "Xdivy", "FloorMod",
|
||||
"BitwiseAnd", "BitwiseOr", "BitwiseXor", "LeftShift", "RightShift",
|
||||
"LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
|
||||
"ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "TruncateDiv",
|
||||
"TruncateMod", "Equal", "NotEqual", "Greater", "GreaterEqual",
|
||||
"Less", "LessEqual", "SigmoidGrad", "SoftplusGrad", "SoftsignGrad",
|
||||
@ -1872,6 +1872,8 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
||||
"Einsum",
|
||||
"EmptyTensorList",
|
||||
"ExtractImagePatches",
|
||||
"Igamma",
|
||||
"Igammac",
|
||||
"FFT",
|
||||
"FFT2D",
|
||||
"FFT3D",
|
||||
|
@ -36,8 +36,13 @@ class XlaCpuDeviceFactory : public DeviceFactory {
|
||||
};
|
||||
|
||||
Status XlaCpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
|
||||
devices->push_back(absl::StrCat("/physical_device:", DEVICE_XLA_CPU, ":0"));
|
||||
XlaDeviceFlags* flags = GetXlaDeviceFlags();
|
||||
if (!flags->tf_xla_enable_xla_devices) {
|
||||
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
devices->push_back(absl::StrCat("/physical_device:", DEVICE_XLA_CPU, ":0"));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -45,6 +50,10 @@ Status XlaCpuDeviceFactory::CreateDevices(
|
||||
const SessionOptions& session_options, const string& name_prefix,
|
||||
std::vector<std::unique_ptr<Device>>* devices) {
|
||||
XlaDeviceFlags* flags = GetXlaDeviceFlags();
|
||||
if (!flags->tf_xla_enable_xla_devices) {
|
||||
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
|
||||
return Status::OK();
|
||||
}
|
||||
bool compile_on_demand = flags->tf_xla_compile_on_demand;
|
||||
|
||||
XlaOpRegistry::DeviceRegistration registration;
|
||||
|
@ -140,7 +140,6 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
|
||||
// The device tensor should always be fresh.
|
||||
TF_RET_CHECK(!xla_tensor->has_shaped_buffer());
|
||||
|
||||
xla_tensor->set_host_tensor(*cpu_tensor);
|
||||
TF_RETURN_IF_ERROR(
|
||||
xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
|
||||
stream_->parent()->device_ordinal()));
|
||||
|
@ -14,17 +14,20 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
|
||||
// operators using XLA via the XLA "CUDA" (GPU) backend.
|
||||
// operators using XLA via the XLA "CUDA" or "ROCM" (GPU) backend.
|
||||
|
||||
#include <set>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/numbers.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
|
||||
#include "tensorflow/compiler/jit/xla_device.h"
|
||||
#include "tensorflow/compiler/jit/xla_device_ops.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -61,7 +64,14 @@ class XlaGpuDeviceFactory : public DeviceFactory {
|
||||
};
|
||||
|
||||
Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
|
||||
auto platform = se::MultiPlatformManager::PlatformWithName("CUDA");
|
||||
XlaDeviceFlags* flags = GetXlaDeviceFlags();
|
||||
if (!flags->tf_xla_enable_xla_devices) {
|
||||
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
auto platform =
|
||||
se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
|
||||
if (!platform.ok()) {
|
||||
// Treat failures as non-fatal; there might not be a GPU in the machine.
|
||||
VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
|
||||
@ -84,6 +94,12 @@ Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
|
||||
Status XlaGpuDeviceFactory::CreateDevices(
|
||||
const SessionOptions& session_options, const string& name_prefix,
|
||||
std::vector<std::unique_ptr<Device>>* devices) {
|
||||
XlaDeviceFlags* flags = GetXlaDeviceFlags();
|
||||
if (!flags->tf_xla_enable_xla_devices) {
|
||||
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
XlaOpRegistry::DeviceRegistration registration;
|
||||
registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
|
||||
registration.autoclustering_policy =
|
||||
@ -103,7 +119,8 @@ Status XlaGpuDeviceFactory::CreateDevices(
|
||||
RegisterXlaDeviceKernels(DEVICE_XLA_GPU, DEVICE_GPU_XLA_JIT);
|
||||
(void)registrations;
|
||||
|
||||
auto platform = se::MultiPlatformManager::PlatformWithName("CUDA");
|
||||
auto platform =
|
||||
se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
|
||||
if (!platform.ok()) {
|
||||
// Treat failures as non-fatal; there might not be a GPU in the machine.
|
||||
VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
|
||||
|
@ -222,8 +222,9 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
|
||||
OpKernelConstruction construction(
|
||||
DeviceType(dev->device_type()), dev,
|
||||
dev->GetAllocator(AllocatorAttributes()), &node_def,
|
||||
&fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types,
|
||||
fbody->ret_types, output_memory_types, flr->graph_def_version(), &s);
|
||||
&fbody->fdef.signature(), flr, dev->resource_manager(), fbody->arg_types,
|
||||
input_memory_types, fbody->ret_types, output_memory_types,
|
||||
flr->graph_def_version(), &s);
|
||||
|
||||
*kernel = absl::make_unique<XlaLocalLaunchBase>(
|
||||
&construction, constant_arg_indices, resource_arg_indices, function);
|
||||
|
@ -44,8 +44,11 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:AffineDialectRegistration",
|
||||
"@llvm-project//mlir:LoopDialectRegistration",
|
||||
"@llvm-project//mlir:MlirOptLib",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:QuantOpsDialectRegistration",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir/test:TestTransforms",
|
||||
],
|
||||
@ -80,9 +83,10 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
|
||||
"//tensorflow/compiler/mlir/xla:xla_lower",
|
||||
"@llvm-project//mlir:AffineDialectRegistration",
|
||||
"//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts",
|
||||
"//tensorflow/compiler/mlir/xla:xla_test_passes",
|
||||
"@llvm-project//mlir:AffineOps",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:QuantOpsDialectRegistration",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -47,6 +47,14 @@ gentbl(
|
||||
"-gen-op-doc",
|
||||
"g3doc/tfl_ops.md",
|
||||
),
|
||||
(
|
||||
"-gen-op-interface-decls",
|
||||
"ir/tfl_ops_interface.h.inc",
|
||||
),
|
||||
(
|
||||
"-gen-op-interface-defs",
|
||||
"ir/tfl_ops_interface.cc.inc",
|
||||
),
|
||||
],
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "ir/tfl_ops.td",
|
||||
@ -177,11 +185,12 @@ cc_library(
|
||||
"ir/tfl_ops.cc",
|
||||
"ir/tfl_ops.cc.inc",
|
||||
"ir/tfl_ops.h.inc",
|
||||
"ir/tfl_ops_interface.cc.inc",
|
||||
"ir/tfl_ops_interface.h.inc",
|
||||
"utils/attribute_utils.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"ir/tfl_ops.h",
|
||||
"ir/tfl_traits.h",
|
||||
"transforms/passes.h",
|
||||
"utils/attribute_utils.h",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h",
|
||||
@ -330,6 +339,7 @@ cc_library(
|
||||
cc_library(
|
||||
name = "tensorflow_lite_quantize",
|
||||
srcs = [
|
||||
"transforms/default_quant_params.cc",
|
||||
"transforms/generated_post_quantize.inc",
|
||||
"transforms/generated_quantize.inc",
|
||||
"transforms/load_quantization_recipe.cc",
|
||||
@ -506,6 +516,7 @@ cc_library(
|
||||
"//tensorflow/lite:schema_fbs_version",
|
||||
"//tensorflow/lite:string_util",
|
||||
"//tensorflow/lite/delegates/flex:whitelisted_flex_ops_lib",
|
||||
"//tensorflow/lite/kernels/internal:kernel_utils",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"//tensorflow/lite/tools/versioning:op_version",
|
||||
"@com_google_absl//absl/base",
|
||||
@ -671,12 +682,16 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
exports_files(
|
||||
["transforms/passes.h"],
|
||||
cc_library(
|
||||
name = "empty_passes",
|
||||
hdrs = ["transforms/passes.h"],
|
||||
visibility = [
|
||||
"//configs/devtools/hawkeye/tflite:__subpackages__",
|
||||
"//learning/brain/models/app_benchmarks:__subpackages__",
|
||||
"//tensorflow/compiler/mlir/lite:friends",
|
||||
"//tensorflow/lite/experimental/mlir:__subpackages__",
|
||||
],
|
||||
deps = [
|
||||
"@llvm-project//llvm:support",
|
||||
],
|
||||
)
|
||||
|
@ -31,10 +31,11 @@ struct PassConfig {
|
||||
: emit_builtin_tflite_ops(true),
|
||||
lower_tensor_list_ops(false),
|
||||
trim_functions_whitelist({}),
|
||||
quant_specs(specs),
|
||||
quant_specs(std::move(specs)),
|
||||
skip_control_dialect(false),
|
||||
form_clusters(false),
|
||||
inline_functions(false) {}
|
||||
inline_functions(false),
|
||||
unfold_batch_matmul(true) {}
|
||||
|
||||
// If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
|
||||
// added, which produces TF Lite ops.
|
||||
@ -57,6 +58,9 @@ struct PassConfig {
|
||||
// Inline function calls within the main function in the MLIR module, prior
|
||||
// to legalization to TFLite.
|
||||
bool inline_functions;
|
||||
// If `unfold_batch_matmul` is true, the tf.BatchMatMul is unfolded to a set
|
||||
// of tfl.fully_connected ops.
|
||||
bool unfold_batch_matmul;
|
||||
};
|
||||
|
||||
} // namespace TFL
|
||||
|
@ -389,7 +389,6 @@ StatusOr<mlir::ElementsAttr> ConvertIntBuffer(
|
||||
mlir::RankedTensorType shaped_type, mlir::Type elem_type,
|
||||
const std::vector<uint8_t>& buffer) {
|
||||
unsigned bit_width;
|
||||
mlir::RankedTensorType buffer_type;
|
||||
if (auto itype = elem_type.dyn_cast<mlir::IntegerType>()) {
|
||||
bit_width = itype.getWidth();
|
||||
} else if (auto qtype = elem_type.dyn_cast<QuantizedType>()) {
|
||||
@ -920,15 +919,13 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
// represents TFLite, this entry point must be called "main"
|
||||
// TODO(b/131175224,b/132239787) Support multiple entry points
|
||||
std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) {
|
||||
if (subgraph.name.empty()) {
|
||||
if (index == 0) {
|
||||
return "main";
|
||||
} else {
|
||||
return llvm::formatv("fn_{0}", index).str();
|
||||
}
|
||||
} else {
|
||||
return subgraph.name;
|
||||
if (index == 0) {
|
||||
return "main";
|
||||
}
|
||||
if (subgraph.name.empty()) {
|
||||
return llvm::formatv("fn_{0}", index).str();
|
||||
}
|
||||
return subgraph.name;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
@ -259,9 +259,9 @@ Status mlir::CustomOptionsToAttributes(
|
||||
attributes->emplace_back(builder.getNamedAttr(
|
||||
"stride_w", builder.getI32IntegerAttr(pool_params->stride_width)));
|
||||
attributes->emplace_back(builder.getNamedAttr(
|
||||
"filter_w", builder.getI32IntegerAttr(pool_params->filter_height)));
|
||||
"filter_h", builder.getI32IntegerAttr(pool_params->filter_height)));
|
||||
attributes->emplace_back(builder.getNamedAttr(
|
||||
"filter_h", builder.getI32IntegerAttr(pool_params->filter_width)));
|
||||
"filter_w", builder.getI32IntegerAttr(pool_params->filter_width)));
|
||||
return Status::OK();
|
||||
|
||||
} else if (op_name == "tfl.convolution_2d_transpose_bias") {
|
||||
|
@ -71,6 +71,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h"
|
||||
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
#include "tensorflow/lite/tools/versioning/op_version.h"
|
||||
@ -218,6 +219,13 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
|
||||
auto qtype = type.cast<mlir::quant::UniformQuantizedPerAxisType>();
|
||||
return GetTFLiteType(qtype.getStorageType(), qtype.isSigned());
|
||||
}
|
||||
case mlir::TF::TensorFlowTypes::RESOURCE: {
|
||||
// Treat tf.resource values as integer values in flatbuffer.
|
||||
// TODO(b/146131919): Maybe need to have a detailed design for supporting
|
||||
// other resource types beyond hash table resources and resource
|
||||
// variables.
|
||||
return tflite::TensorType_INT32;
|
||||
}
|
||||
default:
|
||||
// TFLite export fills FLOAT32 for unknown data types. Returning an error
|
||||
// for now for safety and this could be revisited when required.
|
||||
@ -317,6 +325,48 @@ static std::unique_ptr<::tensorflow::NodeDef> getTensorFlowNodeDef(
|
||||
return std::move(status_or_node_def.ValueOrDie());
|
||||
}
|
||||
|
||||
// Converts a mlir padding StringRef to TfLitePadding.
|
||||
// Returns llvm::None if conversion fails.
|
||||
static Optional<TfLitePadding> GetTflitePadding(Operation* inst,
|
||||
llvm::StringRef padding) {
|
||||
const tflite::Padding padding_attr =
|
||||
std::move(llvm::StringSwitch<tflite::Padding>(padding)
|
||||
.Case("SAME", tflite::Padding_SAME)
|
||||
.Case("VALID", tflite::Padding_VALID));
|
||||
if (padding_attr == tflite::Padding_SAME) {
|
||||
return kTfLitePaddingSame;
|
||||
}
|
||||
if (padding_attr == tflite::Padding_VALID) {
|
||||
return kTfLitePaddingValid;
|
||||
}
|
||||
|
||||
return inst->emitOpError() << "Invalid padding attribute: " << padding,
|
||||
llvm::None;
|
||||
}
|
||||
|
||||
// Extracts TfLitePoolParams from a TFL custom op.
|
||||
// Template parameter, TFLOp, should be a TFL custom op containing attributes
|
||||
// generated from TfLitePoolParams.
|
||||
// Returns llvm::None if conversion fails.
|
||||
template <typename TFLOp>
|
||||
static Optional<TfLitePoolParams> GetTflitePoolParams(Operation* inst,
|
||||
TFLOp op) {
|
||||
TfLitePoolParams pool_params;
|
||||
pool_params.stride_height = op.stride_h().getSExtValue();
|
||||
pool_params.stride_width = op.stride_w().getSExtValue();
|
||||
pool_params.filter_height = op.filter_h().getSExtValue();
|
||||
pool_params.filter_width = op.filter_w().getSExtValue();
|
||||
const auto padding = GetTflitePadding(inst, op.padding());
|
||||
if (padding) {
|
||||
pool_params.padding = *padding;
|
||||
pool_params.activation = kTfLiteActNone;
|
||||
pool_params.computed.padding = TfLitePaddingValues{0, 0, 0, 0};
|
||||
return pool_params;
|
||||
}
|
||||
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Translates an MLIR module in TFLite dialect to TFLite FlatBuffer.
|
||||
@ -375,9 +425,31 @@ class Translator {
|
||||
mlir::TF::WhileOp op, const std::vector<int32_t>& operands,
|
||||
const std::vector<int32_t>& results);
|
||||
|
||||
// Builds custom operators.
|
||||
// Templated on a) data type of custom_option to be stored into flatbuffer,
|
||||
// and b) TFL custom op type.
|
||||
template <typename CustomOptionType, typename TFLOp>
|
||||
BufferOffset<tflite::Operator> BuildCustomOperator(
|
||||
const CustomOptionType& custom_option, const std::string& opcode_name,
|
||||
TFLOp op, const std::vector<int32_t>& operands,
|
||||
const std::vector<int32_t>& results);
|
||||
|
||||
BufferOffset<tflite::Operator> BuildNumericVerifyOperator(
|
||||
mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
|
||||
const std::vector<int32_t>& results);
|
||||
Optional<BufferOffset<tflite::Operator>>
|
||||
BuildConvolution2DTransposeBiasOperator(
|
||||
Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op,
|
||||
const std::vector<int32_t>& operands,
|
||||
const std::vector<int32_t>& results);
|
||||
Optional<BufferOffset<tflite::Operator>> BuildMaxPoolingWithArgMax2DOperator(
|
||||
Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op,
|
||||
const std::vector<int32_t>& operands,
|
||||
const std::vector<int32_t>& results);
|
||||
Optional<BufferOffset<tflite::Operator>> BuildMaxUnpooling2DOperator(
|
||||
Operation* inst, mlir::TFL::MaxUnpooling2DOp op,
|
||||
const std::vector<int32_t>& operands,
|
||||
const std::vector<int32_t>& results);
|
||||
|
||||
Optional<CustomOptionsOffset> CreateFlexOpCustomOptions(
|
||||
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);
|
||||
@ -615,19 +687,72 @@ BufferOffset<tflite::Operator> Translator::BuildWhileOperator(
|
||||
builtin_options);
|
||||
}
|
||||
|
||||
template <typename CustomOptionType, typename TFLOp>
|
||||
BufferOffset<tflite::Operator> Translator::BuildCustomOperator(
|
||||
const CustomOptionType& custom_option, const std::string& opcode_name,
|
||||
TFLOp op, const std::vector<int32_t>& operands,
|
||||
const std::vector<int32_t>& results) {
|
||||
std::vector<uint8_t> custom_option_vector(sizeof(CustomOptionType));
|
||||
memcpy(custom_option_vector.data(), &custom_option, sizeof(CustomOptionType));
|
||||
auto opcode_index =
|
||||
GetOpcodeIndex(opcode_name, tflite::BuiltinOperator_CUSTOM);
|
||||
return tflite::CreateOperator(
|
||||
builder_, opcode_index, builder_.CreateVector(operands),
|
||||
builder_.CreateVector(results), tflite::BuiltinOptions_NONE,
|
||||
/*builtin_options=*/0,
|
||||
builder_.CreateVector<uint8_t>(custom_option_vector),
|
||||
tflite::CustomOptionsFormat_FLEXBUFFERS);
|
||||
}
|
||||
|
||||
BufferOffset<tflite::Operator> Translator::BuildNumericVerifyOperator(
|
||||
mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
|
||||
const std::vector<int32_t>& results) {
|
||||
float tolerance = op.tolerance().convertToFloat();
|
||||
std::vector<uint8_t> custom_options(sizeof(float));
|
||||
memcpy(custom_options.data(), &tolerance, sizeof(float));
|
||||
auto opcode_index =
|
||||
GetOpcodeIndex("NumericVerify", tflite::BuiltinOperator_CUSTOM);
|
||||
return tflite::CreateOperator(
|
||||
builder_, opcode_index, builder_.CreateVector(operands),
|
||||
builder_.CreateVector(results), tflite::BuiltinOptions_NONE,
|
||||
/*builtin_options=*/0, builder_.CreateVector<uint8_t>(custom_options),
|
||||
tflite::CustomOptionsFormat_FLEXBUFFERS);
|
||||
return BuildCustomOperator(tolerance, "NumericVerify", op, operands, results);
|
||||
}
|
||||
|
||||
// Builds the custom "Convolution2DTransposeBias" operator from `op`.
//
// The strides and padding of `op` are packed into a TfLiteTransposeConvParams
// that becomes the operator's custom options. Returns llvm::None if the
// padding attribute cannot be converted.
Optional<BufferOffset<tflite::Operator>>
Translator::BuildConvolution2DTransposeBiasOperator(
    Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op,
    const std::vector<int32_t>& operands, const std::vector<int32_t>& results) {
  TfLiteTransposeConvParams transpose_conv_params;
  transpose_conv_params.stride_height = op.stride_h().getSExtValue();
  transpose_conv_params.stride_width = op.stride_w().getSExtValue();

  const auto converted_padding = GetTflitePadding(inst, op.padding());
  if (!converted_padding) {
    return llvm::None;
  }

  transpose_conv_params.padding = *converted_padding;
  return BuildCustomOperator(transpose_conv_params,
                             "Convolution2DTransposeBias", op, operands,
                             results);
}
|
||||
|
||||
// Builds the custom "MaxPoolingWithArgmax2D" operator from `op`, using the
// pooling parameters extracted by GetTflitePoolParams as custom options.
// Returns llvm::None if those parameters cannot be extracted.
Optional<BufferOffset<tflite::Operator>>
Translator::BuildMaxPoolingWithArgMax2DOperator(
    Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op,
    const std::vector<int32_t>& operands, const std::vector<int32_t>& results) {
  const auto extracted_params = GetTflitePoolParams(inst, op);
  if (!extracted_params) {
    return llvm::None;
  }

  return BuildCustomOperator(*extracted_params, "MaxPoolingWithArgmax2D", op,
                             operands, results);
}
|
||||
|
||||
// Builds the custom "MaxUnpooling2D" operator from `op`, using the pooling
// parameters extracted by GetTflitePoolParams as custom options.
// Returns llvm::None if those parameters cannot be extracted.
Optional<BufferOffset<tflite::Operator>>
Translator::BuildMaxUnpooling2DOperator(Operation* inst,
                                        mlir::TFL::MaxUnpooling2DOp op,
                                        const std::vector<int32_t>& operands,
                                        const std::vector<int32_t>& results) {
  const auto extracted_params = GetTflitePoolParams(inst, op);
  if (!extracted_params) {
    return llvm::None;
  }

  return BuildCustomOperator(*extracted_params, "MaxUnpooling2D", op, operands,
                             results);
}
|
||||
|
||||
Optional<CustomOptionsOffset> Translator::CreateFlexOpCustomOptions(
|
||||
@ -769,6 +894,20 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
|
||||
if (auto verify_op = dyn_cast<mlir::TFL::NumericVerifyOp>(inst)) {
|
||||
return BuildNumericVerifyOperator(verify_op, operands, results);
|
||||
}
|
||||
if (auto conv_transpose_bias_op =
|
||||
dyn_cast<mlir::TFL::Convolution2DTransposeBiasOp>(inst)) {
|
||||
return BuildConvolution2DTransposeBiasOperator(
|
||||
inst, conv_transpose_bias_op, operands, results);
|
||||
}
|
||||
if (auto max_pooling_with_arg_max_op =
|
||||
dyn_cast<mlir::TFL::MaxPoolingWithArgMax2DOp>(inst)) {
|
||||
return BuildMaxPoolingWithArgMax2DOperator(
|
||||
inst, max_pooling_with_arg_max_op, operands, results);
|
||||
}
|
||||
if (auto max_unpooling_op = dyn_cast<mlir::TFL::MaxUnpooling2DOp>(inst)) {
|
||||
return BuildMaxUnpooling2DOperator(inst, max_unpooling_op, operands,
|
||||
results);
|
||||
}
|
||||
inst->emitOpError("is not a supported TFLite op");
|
||||
return llvm::None;
|
||||
}
|
||||
@ -904,11 +1043,6 @@ void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) {
|
||||
|
||||
bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) {
|
||||
std::vector<int> operand_indices;
|
||||
// TODO(b/138254427): When the bug is addressed, we'll be able to inspect
|
||||
// for the presence of a specific OpTrait using mlir::Operation, without
|
||||
// having to cast it to specific ops like below.
|
||||
// Until then, when a new RNN/LSTM op is added to TFLite and has stateful
|
||||
// tensors as operands, they will need to be added here as well.
|
||||
if (!mlir::TFL::IsStatefulOp(op, &operand_indices)) return false;
|
||||
return absl::c_find(operand_indices, operand_index) != operand_indices.end();
|
||||
}
|
||||
|
@ -1728,6 +1728,7 @@ static LogicalResult Verify(TransposeOp op) {
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.cc.inc"
|
||||
#define GET_OP_CLASSES
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
|
||||
|
||||
|
@ -27,7 +27,6 @@ limitations under the License.
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/Support/Functional.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
@ -44,6 +43,7 @@ class TensorFlowLiteDialect : public Dialect {
|
||||
Location loc) override;
|
||||
};
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.h.inc"
|
||||
#define GET_OP_CLASSES
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc"
|
||||
|
||||
|
@ -249,14 +249,39 @@ def TFL_ComparisonBinaryBuilder : OpBuilder<
|
||||
}]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TFL native op trait for stateful operands and channel indices.
|
||||
// TFL op interface for stateful operands.
|
||||
|
||||
class StatefulOperands<list<int> operands>
|
||||
: ParamNativeOpTrait<"TFL::StatefulOperands", StrJoinInt<operands>.result>;
|
||||
def TFL_StatefulOp : OpInterface<"StatefulOpInterface"> {
|
||||
let description = [{
|
||||
Interface for ops that are stateful and need to identify stateful operands.
|
||||
|
||||
Stateful operands correspond to TF's variables semantics. An op that has 1
|
||||
or more stateful operands is a stateful op.
|
||||
}];
|
||||
|
||||
class ChannelDimIndex<int index>
|
||||
: ParamNativeOpTrait<"TFL::ChannelDimIndex", !cast<string>(index)>;
|
||||
let methods = [
|
||||
InterfaceMethod<
|
||||
[{Returns the indices of stateful operands.}],
|
||||
"std::vector<int>", "GetStatefulOperands", (ins)
|
||||
>,
|
||||
];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TFL op interface for output channel index.
|
||||
|
||||
def TFL_ChannelDimIndexInterface : OpInterface<"ChannelDimIndexInterface"> {
|
||||
let description = [{
|
||||
Interface for defining the dimension index of the output channels.
|
||||
}];
|
||||
|
||||
let methods = [
|
||||
InterfaceMethod<
|
||||
[{Returns the dimension index of the output channels.}],
|
||||
"int", "GetChannelDimIndex", (ins)
|
||||
>,
|
||||
];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TFL op base class.
|
||||
@ -285,7 +310,7 @@ class TFL_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
|
||||
class TFL_ConvOp<string mnemonic, string opSummary, int index> :
|
||||
TFL_Op<mnemonic, [NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
|
||||
ChannelDimIndex<index>, AffineOpCoefficient<index, 1>]> {
|
||||
TFL_ChannelDimIndexInterface, AffineOpCoefficient<index, 1>]> {
|
||||
let summary = opSummary # " operator";
|
||||
|
||||
let description = [{
|
||||
@ -486,8 +511,7 @@ def TFL_ArgMaxOp : TFL_Op<"arg_max", [NoSideEffect]> {
|
||||
}];
|
||||
|
||||
let arguments = (
|
||||
// TODO: Add support for uint8.
|
||||
ins TensorOf<[F32, I32, I8]>:$input,
|
||||
ins TensorOf<[F32, I32, I8, TFL_Uint8, QI8, QUI8]>:$input,
|
||||
TFL_I32OrI64Tensor:$dim
|
||||
);
|
||||
|
||||
@ -515,8 +539,7 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> {
|
||||
}];
|
||||
|
||||
let arguments = (
|
||||
// TODO(pkanwar): Add support for uint8.
|
||||
ins TensorOf<[F32, I32, I8]>:$input,
|
||||
ins TensorOf<[F32, I32, I8, TFL_Uint8, QI8, QUI8]>:$input,
|
||||
TFL_I32OrI64Tensor:$dim
|
||||
);
|
||||
|
||||
@ -617,7 +640,12 @@ def TFL_ExternalConstOp : Op<TFL_Dialect, "external_const", [NoSideEffect]> {
|
||||
let results = (outs AnyTensor:$output);
|
||||
}
|
||||
|
||||
def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0>;
|
||||
def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0> {
|
||||
let extraClassDeclaration = [{
|
||||
// ChannelDimIndexInterface:
|
||||
int GetChannelDimIndex() { return 0; }
|
||||
}];
|
||||
}
|
||||
|
||||
def TFL_CosOp: TFL_Op<"cos", [
|
||||
NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
|
||||
@ -637,6 +665,11 @@ def TFL_CosOp: TFL_Op<"cos", [
|
||||
def TFL_DepthwiseConv2DOp :
|
||||
TFL_ConvOp<"depthwise_conv_2d", "Depthwise-separable convolution", 3> {
|
||||
let arguments = !con(TFL_Conv2DOp.arguments, (ins I32Attr:$depth_multiplier));
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// ChannelDimIndexInterface:
|
||||
int GetChannelDimIndex() { return 3; }
|
||||
}];
|
||||
}
|
||||
|
||||
def TFL_FCWO_Default : StrEnumAttrCase<"DEFAULT">;
|
||||
@ -650,7 +683,8 @@ def TFL_FullyConnectedOptionsWeightFormatAttr :
|
||||
|
||||
// TODO(jpienaar): Update post discussion on semantics of FC OP.
|
||||
def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
|
||||
NoSideEffect, AccumulatorUniformScale<2, 0, 1>, ChannelDimIndex<0>,
|
||||
NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
|
||||
TFL_ChannelDimIndexInterface,
|
||||
AffineOpCoefficient<-1, 1>]> {
|
||||
let summary = "Fully connected op";
|
||||
|
||||
@ -672,6 +706,11 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
|
||||
let hasOptions = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// ChannelDimIndexInterface:
|
||||
int GetChannelDimIndex() { return 0; }
|
||||
}];
|
||||
}
|
||||
|
||||
def TFL_GatherOp : TFL_Op<"gather", [
|
||||
@ -1208,7 +1247,8 @@ def TFL_FloorModOp : TFL_Op<"floor_mod", [Broadcastable, NoSideEffect]> {
|
||||
let builders = [TFL_BroadcastableBinaryBuilder];
|
||||
}
|
||||
|
||||
def TFL_GreaterOp : TFL_Op<"greater", [NoSideEffect, NoQuantizableResult]> {
|
||||
def TFL_GreaterOp : TFL_Op<"greater", [
|
||||
Broadcastable, NoSideEffect, NoQuantizableResult]> {
|
||||
let summary = "Greater operator";
|
||||
|
||||
let description = [{
|
||||
@ -1221,6 +1261,8 @@ def TFL_GreaterOp : TFL_Op<"greater", [NoSideEffect, NoQuantizableResult]> {
|
||||
|
||||
let results = (outs AnyTensor:$output);
|
||||
|
||||
let builders = [TFL_ComparisonBinaryBuilder];
|
||||
|
||||
let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }];
|
||||
|
||||
let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
|
||||
@ -1287,7 +1329,8 @@ def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [NoSideEffect, SameOperandsAndResultTy
|
||||
let hasOptions = 0b1;
|
||||
}
|
||||
|
||||
def TFL_LessOp : TFL_Op<"less", [NoSideEffect, NoQuantizableResult]> {
|
||||
def TFL_LessOp : TFL_Op<"less", [
|
||||
Broadcastable, NoSideEffect, NoQuantizableResult]> {
|
||||
let summary = "Less operator";
|
||||
|
||||
let description = [{
|
||||
@ -2123,7 +2166,7 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2",
|
||||
|
||||
Args:
|
||||
tensor: A Tensor. Must be one of the following types:
|
||||
int16, int32, int64, float32 Up to 8-D.
|
||||
uint8, int16, int32, int64, float32, bool Up to 8-D.
|
||||
|
||||
axis: A Tensor. Must be one of the following types: int32, int64.
|
||||
with only 1 element which is the axis index.
|
||||
@ -2132,12 +2175,12 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2",
|
||||
|
||||
let arguments = (
|
||||
ins
|
||||
TensorOf<[F32, I16, I32, I64]>:$input,
|
||||
TensorOf<[F32, I16, I32, I64, TFL_Uint8, I1]>:$input,
|
||||
TensorOf<[I32, I64]>:$axis
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[F32, I16, I32, I64, I8]>:$output
|
||||
TensorOf<[F32, I16, I32, I64, TFL_Uint8, I1]>:$output
|
||||
);
|
||||
}
|
||||
|
||||
@ -2341,9 +2384,9 @@ def TFL_TanhOp: TFL_Op<"tanh", [
|
||||
let results = (outs TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$y);
|
||||
}
|
||||
|
||||
def TFL_TileOp: TFL_Op<"tile", [NoSideEffect,
|
||||
def TFL_TileOp: TFL_Op<"tile", [NoSideEffect, SameOperandsAndResultsScale,
|
||||
PredOpTrait<"resultant element type needs to match first operand type",
|
||||
TCresVTEtIsSameAsOp<0,0>>]> {
|
||||
TFL_TCresVTEtIsSameAsOp<0,0>>]> {
|
||||
let summary = "Tile operator.";
|
||||
let description = [{
|
||||
Constructs a tensor by tiling a given tensor.
|
||||
@ -2356,10 +2399,11 @@ def TFL_TileOp: TFL_Op<"tile", [NoSideEffect,
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[F32, I1, I32, I64, TFL_Uint8]>:$input,
|
||||
TensorOf<[F32, I1, I32, I64, TFL_Uint8, QUI8]>:$input,
|
||||
TFL_I32OrI64Tensor:$multiples);
|
||||
|
||||
let results = (outs TensorOf<[F32, I1, I32, I64, TFL_Uint8]>:$output);
|
||||
let results = (outs
|
||||
TensorOf<[F32, I1, I32, I64, TFL_Uint8, QUI8]>:$output);
|
||||
|
||||
let hasOptions = 0;
|
||||
}
|
||||
@ -2369,7 +2413,7 @@ def TFL_TileOp: TFL_Op<"tile", [NoSideEffect,
|
||||
// TODO(jpienaar): Check that k is less or equal the internal dimension
|
||||
def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>,
|
||||
PredOpTrait<"result and input element type match",
|
||||
TCresVTEtIsSameAsOp<0,0>>]> {
|
||||
TCresVTEtIsSameAsOp<0,0>>, SameOperandsAndResultsScale]> {
|
||||
let summary = "TopK operator";
|
||||
|
||||
let description = [{
|
||||
@ -2379,11 +2423,11 @@ def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>,
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[F32, I8, I32, I64, TFL_Uint8]>:$input,
|
||||
TensorOf<[F32, I8, I32, I64, TFL_Uint8, QI8, QUI8]>:$input,
|
||||
I32Tensor:$k);
|
||||
|
||||
let results = (outs
|
||||
AnyTensor:$values,
|
||||
TensorOf<[F32, I8, I32, I64, TFL_Uint8, QI8, QUI8]>:$values,
|
||||
I32Tensor:$indices);
|
||||
|
||||
let builders = [OpBuilder<"Builder *builder, OperationState &result, "
|
||||
@ -2907,6 +2951,20 @@ def TFL_QuantizeOp: TFL_Op<"quantize", [
|
||||
let results = (outs AnyTensor:$output);
|
||||
}
|
||||
|
||||
def TFL_DensifyOp: TFL_Op<"densify", [NoSideEffect,
|
||||
SameOperandsAndResultType,
|
||||
NoQuantizableResult]> {
|
||||
let summary = "Densify operator";
|
||||
|
||||
let description = [{
|
||||
Converts sparse tensor to dense format.
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyTensor:$input);
|
||||
|
||||
let results = (outs AnyTensor:$output);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LSTM Ops
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -2996,7 +3054,7 @@ def TFL_LSTMOp :
|
||||
LstmOptionalPeepholeWeightConstraint,
|
||||
LstmProjectionWeightBiasConstraint,
|
||||
LstmResultConstraint,
|
||||
StatefulOperands<[18, 19]>]> {
|
||||
TFL_StatefulOp]> {
|
||||
let summary = "The full lstm operator";
|
||||
|
||||
let description = [{
|
||||
@ -3080,6 +3138,11 @@ Ba et al. “Layer Normalization”
|
||||
let hasOptions = 1;
|
||||
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// StatefulOpInterface:
|
||||
std::vector<int> GetStatefulOperands() { return {18, 19}; }
|
||||
}];
|
||||
}
|
||||
|
||||
// UnidirectionalSequenceLstm op.
|
||||
@ -3091,7 +3154,7 @@ def TFL_UnidirectionalSequenceLSTMOp :
|
||||
LstmOptionalPeepholeWeightConstraint,
|
||||
LstmProjectionWeightBiasConstraint,
|
||||
LstmResultConstraint,
|
||||
StatefulOperands<[18, 19]>]> {
|
||||
TFL_StatefulOp]> {
|
||||
let summary = "Unidirectional sequence lstm operator";
|
||||
|
||||
let description = [{
|
||||
@ -3160,6 +3223,11 @@ def TFL_UnidirectionalSequenceLSTMOp :
|
||||
let hasOptions = 1;
|
||||
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// StatefulOpInterface:
|
||||
std::vector<int> GetStatefulOperands() { return {18, 19}; }
|
||||
}];
|
||||
}
|
||||
|
||||
def RnnResultConstraint : PredOpTrait<
|
||||
@ -3169,7 +3237,7 @@ def RnnResultConstraint : PredOpTrait<
|
||||
// UnidirectionalSequenceRNN op.
|
||||
def TFL_UnidirectionalSequenceRNNOp :
|
||||
TFL_Op<"unidirectional_sequence_rnn",
|
||||
[RnnResultConstraint, StatefulOperands<[4]>]> {
|
||||
[RnnResultConstraint, TFL_StatefulOp]> {
|
||||
|
||||
let summary = "Unidirectional sequence rnn operator";
|
||||
|
||||
@ -3213,6 +3281,11 @@ def TFL_UnidirectionalSequenceRNNOp :
|
||||
let customOption = "SequenceRNNOptions";
|
||||
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// StatefulOpInterface:
|
||||
std::vector<int> GetStatefulOperands() { return {4}; }
|
||||
}];
|
||||
}
|
||||
|
||||
def TFL_WhereOp : TFL_Op<"where", [NoSideEffect]> {
|
||||
@ -3264,7 +3337,7 @@ def SVDFResultConstraint: PredOpTrait<
|
||||
// SVDF op.
|
||||
def TFL_SVDFOp :
|
||||
TFL_Op<"svdf",
|
||||
[SVDFResultConstraint, StatefulOperands<[4]>]> {
|
||||
[SVDFResultConstraint, TFL_StatefulOp]> {
|
||||
|
||||
let summary = "Single value decomposition filter operator";
|
||||
|
||||
@ -3300,6 +3373,25 @@ def TFL_SVDFOp :
|
||||
let hasOptions = 1;
|
||||
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// StatefulOpInterface:
|
||||
std::vector<int> GetStatefulOperands() { return {4}; }
|
||||
}];
|
||||
}
|
||||
|
||||
def TFL_SegmentSumOp: TFL_Op<"segment_sum", [NoSideEffect]> {
|
||||
let summary = "SegmentSum operator";
|
||||
|
||||
let description = [{
|
||||
Computes the sum along segments of a tensor.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[F32, I32]>:$data,
|
||||
I32Tensor:$segment_ids
|
||||
);
|
||||
let results = (outs TensorOf<[F32, I32]>:$output);
|
||||
}
|
||||
|
||||
#endif // TFL_OPS
|
||||
|
@ -1,67 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This file defines the op traits used in the MLIR TensorFlow Lite dialect.
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace OpTrait {
|
||||
namespace TFL {
|
||||
|
||||
// The trait to specify that the specified operands of the TFL op are stateful.
|
||||
// This is used as a trait like this:
|
||||
//
|
||||
// class LSTMOp
|
||||
// : public Op<LSTMOp, OpTrait::TFL::StatefulOperands<18, 19>::Impl> {
|
||||
//
|
||||
template <int... Operands>
|
||||
class StatefulOperands {
|
||||
public:
|
||||
template <typename ConcreteType>
|
||||
class Impl
|
||||
: public TraitBase<ConcreteType, StatefulOperands<Operands...>::Impl> {
|
||||
public:
|
||||
static std::vector<int> GetStatefulOperands() {
|
||||
return std::vector<int>({Operands...});
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
// The trait to specify the channel dimension index of the input (first operand)
|
||||
// of an affine TFL op (Conv2D, DepthwiseConv2D, FullyConnected).
|
||||
//
|
||||
// class Conv2DOp
|
||||
// : public Op<Conv2DOp, OpTrait::TFL::ChannelDimIndex<0>::Impl> {
|
||||
//
|
||||
template <int Index>
|
||||
class ChannelDimIndex {
|
||||
public:
|
||||
template <typename ConcreteType>
|
||||
class Impl : public TraitBase<ConcreteType, ChannelDimIndex<Index>::Impl> {
|
||||
public:
|
||||
static int GetChannelDimIndex() { return Index; }
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace OpTrait
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
|
@ -32,6 +32,6 @@ cc_library(
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:ViewOpGraph",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
)
|
||||
|
@ -107,9 +107,6 @@ void WarningUnusedFlags(const toco::ModelFlags& model_flags,
|
||||
if (toco_flags.output_format()) {
|
||||
LOG(WARNING) << "Ignored output_format.";
|
||||
}
|
||||
if (toco_flags.default_ranges_min() || toco_flags.default_ranges_max()) {
|
||||
LOG(WARNING) << "Ignored default_ranges_stats.";
|
||||
}
|
||||
if (toco_flags.drop_control_dependency()) {
|
||||
LOG(WARNING) << "Ignored drop_control_dependency.";
|
||||
}
|
||||
@ -242,6 +239,13 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
tensorflow::ParseOutputArrayInfo(output_arrays, &specs.outputs));
|
||||
|
||||
// Other flags.
|
||||
if (toco_flags.has_default_ranges_min()) {
|
||||
quant_specs.default_ranges.first = toco_flags.default_ranges_min();
|
||||
}
|
||||
if (toco_flags.has_default_ranges_max()) {
|
||||
quant_specs.default_ranges.second = toco_flags.default_ranges_max();
|
||||
}
|
||||
|
||||
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
|
||||
bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
|
||||
bool emit_custom_ops = toco_flags.allow_custom_ops();
|
||||
|
@ -206,10 +206,17 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
|
||||
std::unique_ptr<OpPassBase<FuncOp>>
|
||||
CreateImportQuantStatsPassForTFControlDialect(const std::string &stats_str) {
|
||||
auto get_name_func = [](Operation *op) {
|
||||
if (auto name = op->getAttrOfType<StringAttr>("name"))
|
||||
return name.getValue();
|
||||
else
|
||||
return llvm::StringRef("");
|
||||
Location loc = op->getLoc();
|
||||
if (auto name = loc.dyn_cast<NameLoc>()) {
|
||||
return name.getName().strref();
|
||||
} else if (auto fused_name = loc.dyn_cast<FusedLoc>()) {
|
||||
for (auto sub_loc : fused_name.getLocations()) {
|
||||
if (auto named_sub_loc = sub_loc.dyn_cast<NameLoc>()) {
|
||||
return named_sub_loc.getName().strref();
|
||||
}
|
||||
}
|
||||
}
|
||||
return llvm::StringRef("");
|
||||
};
|
||||
|
||||
return CreateImportQuantStatsPass(get_name_func, stats_str);
|
||||
|
@ -23,7 +23,6 @@ cc_library(
|
||||
],
|
||||
hdrs = [
|
||||
"quantize_model.h",
|
||||
"//tensorflow/compiler/mlir/lite:transforms/passes.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir/lite:common",
|
||||
|
@ -73,19 +73,19 @@ TfLiteStatus QuantizeModel(
|
||||
|
||||
// Apply quantization passes
|
||||
PassManager pm(module->getContext());
|
||||
TFL::QuantizationSpecs pass_config;
|
||||
pass_config.inference_type = tensorflow::DT_QINT8;
|
||||
pass_config.post_training_quantization = true;
|
||||
TFL::QuantizationSpecs quant_specs;
|
||||
quant_specs.inference_type = tensorflow::DT_QINT8;
|
||||
quant_specs.post_training_quantization = true;
|
||||
|
||||
bool emit_adaptor = false;
|
||||
auto input_tf_type = tflite::TflTypeToTfType(input_type);
|
||||
if (input_tf_type == tensorflow::DT_FLOAT) {
|
||||
emit_adaptor = true;
|
||||
} else if (input_tf_type == tensorflow::DT_UINT8) {
|
||||
pass_config.inference_type = tensorflow::DT_QUINT8;
|
||||
quant_specs.inference_type = tensorflow::DT_QUINT8;
|
||||
}
|
||||
|
||||
pm.addPass(TFL::CreatePrepareQuantizePass(pass_config));
|
||||
pm.addPass(TFL::CreatePrepareQuantizePass(quant_specs));
|
||||
pm.addPass(TFL::CreateQuantizePass());
|
||||
pm.addPass(TFL::CreatePostQuantizePass(emit_adaptor));
|
||||
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
|
||||
@ -64,6 +65,10 @@ struct QuantizationSpecs {
|
||||
// quantization aware training or calibration, for the remaining tensors.
|
||||
std::vector<std::pair<double, double>> input_ranges;
|
||||
|
||||
// The default ranges can be used when a tensor doesn't have quantization
|
||||
// parameters and couldn't be quantized. Used only for latency tests.
|
||||
std::pair<llvm::Optional<double>, llvm::Optional<double>> default_ranges;
|
||||
|
||||
// A serialized "QuantizationInfo" object to specify value ranges for some of
|
||||
// the tensors with known names.
|
||||
std::string serialized_quant_stats = "";
|
||||
|
@ -35,7 +35,6 @@ limitations under the License.
|
||||
#include "mlir/IR/Value.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
@ -3,7 +3,8 @@
|
||||
|
||||
// CHECK-LABEL: import_stats_skip
|
||||
func @import_stats_skip(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
|
||||
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "skip"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
||||
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
||||
loc(fused["skip1", "skip2.cc":10:8, callsite("op" at "skip3.cc":10:8)])
|
||||
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
|
||||
|
||||
// CHECK-NEXT: "tfl.split"
|
||||
@ -12,7 +13,8 @@ func @import_stats_skip(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf3
|
||||
|
||||
// CHECK-LABEL: import_stats_name
|
||||
func @import_stats_name(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
|
||||
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "op"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
||||
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
||||
loc(fused["skip1.cc":10:8, "op", callsite("skip2" at "skip3.cc":10:8)])
|
||||
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
|
||||
|
||||
// CHECK-NEXT: %[[split:.*]]:2 = "tfl.split"
|
||||
@ -23,7 +25,8 @@ func @import_stats_name(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf3
|
||||
|
||||
// CHECK-LABEL: import_stats_name_port
|
||||
func @import_stats_name_port(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
|
||||
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "op_0"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
||||
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
||||
loc(fused["skip1.cc":10:8, "op_0", callsite("skip2" at "skip3.cc":10:8)])
|
||||
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
|
||||
|
||||
// CHECK-NEXT: %[[split:.*]]:2 = "tfl.split"
|
||||
@ -34,6 +37,7 @@ func @import_stats_name_port(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor
|
||||
// CHECK-LABEL: import_stats_name_regex
|
||||
func @import_stats_name_regex(%arg0: tensor<4xf32>, %cst: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
|
||||
%0:2 = "tfl.split"(%cst, %arg0) {num_splits = 2 : i32, name = "op_regex"} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
||||
loc(fused["skip1.cc":10:8, "op_regex", callsite("skip2" at "skip3.cc":10:8)])
|
||||
return %0#0, %0#1 : tensor<2xf32>, tensor<2xf32>
|
||||
|
||||
// CHECK-NEXT: %[[split:.*]]:2 = "tfl.split"
|
||||
|
@ -0,0 +1,89 @@
|
||||
// RUN: tf-opt %s --tfl-default-quant --tfl-quantize | FileCheck %s
|
||||
|
||||
// Neither operand of the add carries quantization parameters. The expectations
// below require both arguments and the add result to use the u8 type with scale
// 0.0078431372549019607 and zero point 128, and the result to be dequantized
// before it is returned.
// CHECK-LABEL: hardcode_all
|
||||
func @hardcode_all(%arg0: tensor<2x2xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x2xf32> {
|
||||
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function="NONE"}: (tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
|
||||
// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<2x1x!quant.uniform<u8:f32, 0.0078431372549019607:128>>}
|
||||
// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>}
|
||||
// Quantized tfl.add
|
||||
// CHECK: %[[add:.*]] = "tfl.add"(%[[q1]], %[[q0]]) {fused_activation_function = "NONE"} : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
|
||||
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[add]]) : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>)
|
||||
// CHECK: return %[[dq]]
|
||||
}
|
||||
|
||||
// %arg0 already goes through a quantize/dequantize pair (scale 1.0, zero point
// 128) while %arg1 has no quantization parameters. The expectations below keep
// the existing parameters for %arg0 and require scale 0.0078431372549019607,
// zero point 128 for %arg1 and for the type being dequantized at the end.
// CHECK-LABEL: hardcode_input
|
||||
func @hardcode_input(%arg0: tensor<2x2xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x2xf32> {
|
||||
%0 = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>}: (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>
|
||||
%1 = "tfl.dequantize"(%0) : (tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>) -> tensor<2x2xf32>
|
||||
%4 = "tfl.add"(%1, %arg1) {fused_activation_function="NONE"}: (tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32>
|
||||
return %4 : tensor<2x2xf32>
|
||||
|
||||
// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<2x1x!quant.uniform<u8:f32, 0.0078431372549019607:128>>}
|
||||
// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e+00:128>>}
|
||||
// CHECK: %[[add:.*]] = "tfl.add"(%[[q1]], %[[q0]]) {fused_activation_function = "NONE"} : (tensor<2x2x!quant.uniform<u8:f32, 1.000000e+00:128>>
|
||||
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[add]]) : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>)
|
||||
// CHECK: return %[[dq]]
|
||||
}
|
||||
|
||||
// %arg0 arrives already quantized (u8, scale 1.0) and is dequantized before the
// add; %arg1 is a plain float tensor. The expectations below require %arg1 to be
// quantized with scale 0.0078431372549019607 and zero point 128, the add to
// consume %arg0 directly, and the result to be dequantized from that same
// 0.0078431372549019607 / 128 type.
// CHECK-LABEL: hardcode_input_deq
|
||||
func @hardcode_input_deq(%arg0: tensor<2x2x!quant.uniform<u8:f32, 1.0>>, %arg1: tensor<2x1xf32>) -> tensor<2x2xf32> {
|
||||
%1 = "tfl.dequantize"(%arg0) : (tensor<2x2x!quant.uniform<u8:f32, 1.0>>) -> tensor<2x2xf32>
|
||||
%4 = "tfl.add"(%1, %arg1) {fused_activation_function="NONE"}: (tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32>
|
||||
return %4 : tensor<2x2xf32>
|
||||
|
||||
// CHECK: %[[q:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<2x1x!quant.uniform<u8:f32, 0.0078431372549019607:128>>}
|
||||
// CHECK: %[[add:.*]] = "tfl.add"(%arg0, %[[q]]) {fused_activation_function = "NONE"} : (tensor<2x2x!quant.uniform<u8:f32, 1.000000e+00>>
|
||||
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[add]]) : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>)
|
||||
// CHECK: return %[[dq]]
|
||||
}
|
||||
|
||||
// Both add operands already go through quantize/dequantize pairs (scale 1.0,
// zero point 128); only the add result lacks quantization parameters. The
// expectations below keep the operand types and require the result to be
// dequantized from the u8 type with scale 0.0078431372549019607, zero point 128.
// CHECK-LABEL: hardcode_output
|
||||
func @hardcode_output(%arg0: tensor<2x2xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x2xf32> {
|
||||
%0 = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>}: (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>
|
||||
%1 = "tfl.quantize"(%arg1) {qtype = tensor<2x1x!quant.uniform<u8:f32, 1.0:128>>}: (tensor<2x1xf32>) -> tensor<2x1x!quant.uniform<u8:f32, 1.0:128>>
|
||||
%2 = "tfl.dequantize"(%0) : (tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>) -> tensor<2x2xf32>
|
||||
%3 = "tfl.dequantize"(%1) : (tensor<2x1x!quant.uniform<u8:f32, 1.0:128>>) -> tensor<2x1xf32>
|
||||
%4 = "tfl.add"(%2, %3) {fused_activation_function="NONE"}: (tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32>
|
||||
return %4 : tensor<2x2xf32>
|
||||
|
||||
// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e+00:128>>}
|
||||
// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<2x1x!quant.uniform<u8:f32, 1.000000e+00:128>>}
|
||||
// CHECK: %[[add:.*]] = "tfl.add"(%[[q0]], %[[q1]]) {fused_activation_function = "NONE"} : (tensor<2x2x!quant.uniform<u8:f32, 1.000000e+00:128>>
|
||||
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[add]]) : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>)
|
||||
// CHECK: return %[[dq]]
|
||||
}
|
||||
|
||||
// Conv2D followed by an add with a quantized constant. All function inputs and
// the final result are already quantized (scale 1.0); only the intermediate
// conv result has no quantization parameters. The expectations below require
// the conv result to use scale 0.0078431372549019607, zero point 128, while the
// add result keeps the scale 1.0 type and is returned without a dequantize.
// CHECK-LABEL: test_conv_2d_add
|
||||
func @test_conv_2d_add(%arg0: tensor<1x224x224x3x!quant.uniform<u8:f32, 1.0>>, %arg1: tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 1.0>>, %arg2: tensor<32x!quant.uniform<i32:f32, 1.0>>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>> {
|
||||
%0 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 1.0>>) -> tensor<1x224x224x3xf32>
|
||||
%1 = "tfl.dequantize"(%arg1) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 1.0>>) -> tensor<32x3x3x3xf32>
|
||||
%2 = "tfl.dequantize"(%arg2) : (tensor<32x!quant.uniform<i32:f32, 1.0>>) -> tensor<32xf32>
|
||||
%3 = "tfl.conv_2d"(%0, %1, %2) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
|
||||
%4 = "tfl.pseudo_qconst"() {qtype = tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>, value = dense<1> : tensor<1x112x112x32xi8>} : () -> tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>
|
||||
%5 = "tfl.dequantize"(%4) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>) -> tensor<1x112x112x32xf32>
|
||||
%6 = "tfl.add"(%3, %5) {fused_activation_function="NONE"}: (tensor<1x112x112x32xf32>, tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32>
|
||||
%7 = "tfl.quantize"(%6) {qtype = tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>
|
||||
return %7 : tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>
|
||||
|
||||
// CHECK: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %arg1, %arg2)
|
||||
// CHECK-SAME: -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
|
||||
// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"()
|
||||
// CHECK: %[[add:.*]] = "tfl.add"(%[[conv]], %[[cst]])
|
||||
// CHECK-SAME: -> tensor<1x112x112x32x!quant.uniform<u8:f32, 1.000000e+00>>
|
||||
// CHECK: return %[[add]]
|
||||
}
|
||||
|
||||
// Conv2D whose filter is already quantized while the input activation and the
// bias are plain float tensors. The expectations below require the bias to be
// quantized as i32 with scale 0.0078431372549019607 (no zero point printed), the
// input and the conv result to use the u8 type with scale 0.0078431372549019607
// and zero point 128, and the result to be dequantized before the return.
// CHECK-LABEL: test_conv_2d_activation_and_bias
|
||||
func @test_conv_2d_activation_and_bias(%arg0: tensor<1x224x224x3xf32>, %arg1: tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 1.0>>, %arg2: tensor<32xf32>) -> tensor<1x112x112x32xf32> {
|
||||
%0 = "tfl.dequantize"(%arg1) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 1.0>>) -> tensor<32x3x3x3xf32>
|
||||
%1 = "tfl.conv_2d"(%arg0, %0, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
|
||||
return %1 : tensor<1x112x112x32xf32>
|
||||
|
||||
// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg2) {qtype = tensor<32x!quant.uniform<i32:f32, 0.0078431372549019607>>}
|
||||
// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x224x224x3x!quant.uniform<u8:f32, 0.0078431372549019607:128>>}
|
||||
// CHECK: %[[conv:.*]] = "tfl.conv_2d"(%[[q1]], %arg1, %[[q0]])
|
||||
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[conv]]) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 0.0078431372549019607:128>>)
|
||||
// CHECK: return %[[dq]]
|
||||
}
|
@ -0,0 +1,76 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --check-prefix=MLIR %s
|
||||
|
||||
|
||||
// Golden export test for a single tfl.convolution_2d_transpose_bias op. The
// first RUN line of this file prints the serialized flatbuffer as text and the
// expectations inside the function pin that text line by line, including the
// CUSTOM operator code "Convolution2DTransposeBias". The second RUN line
// converts the flatbuffer back and expects the op to reappear with the same
// padding and stride attributes.
// NOTE(review): custom_options looks like three little-endian 32-bit values
// (1, 2, 1), presumably padding, stride_w and stride_h - confirm against the
// encoder for this op.
func @main(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, %arg2: tensor<4xi32>) -> tensor<1x64x84x32xf32> {
|
||||
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: CUSTOM,
|
||||
// CHECK-NEXT: custom_code: "Convolution2DTransposeBias"
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
// CHECK-NEXT: shape: [ 32, 4, 4, 128 ],
|
||||
// CHECK-NEXT: buffer: 1,
|
||||
// CHECK-NEXT: name: "arg0",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 1, 32, 42, 128 ],
|
||||
// CHECK-NEXT: buffer: 2,
|
||||
// CHECK-NEXT: name: "arg1",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: type: INT32,
|
||||
// CHECK-NEXT: buffer: 3,
|
||||
// CHECK-NEXT: name: "arg2",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 1, 64, 84, 32 ],
|
||||
// CHECK-NEXT: buffer: 4,
|
||||
// CHECK-NEXT: name: "tfl.convolution_2d_transpose_bias",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: inputs: [ 0, 1, 2 ],
|
||||
// CHECK-NEXT: outputs: [ 3 ],
|
||||
// CHECK-NEXT: operators: [ {
|
||||
// CHECK-NEXT: inputs: [ 0, 1, 2 ],
|
||||
// CHECK-NEXT: outputs: [ 3 ],
|
||||
// CHECK-NEXT: custom_options: [ 1, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: name: "main"
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: description: "MLIR Converted.",
|
||||
// CHECK-NEXT: buffers: [ {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
// MLIR-LABEL: func @main(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, %arg2: tensor<4xi32>)
|
||||
// MLIR-SAME: -> tensor<1x64x84x32xf32>
|
||||
// MLIR: %0 = "tfl.convolution_2d_transpose_bias"(%arg0, %arg1, %arg2)
|
||||
// MLIR-SAME: {padding = "SAME", stride_h = 1 : i32, stride_w = 2 : i32}
|
||||
// MLIR-SAME: (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32>
|
||||
// MLIR-NEXT: return %0 : tensor<1x64x84x32xf32>
|
||||
|
||||
%0 = "tfl.convolution_2d_transpose_bias"(%arg0, %arg1, %arg2) {padding = "SAME", stride_h = 1 : i32, stride_w = 2 : i32} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32>
|
||||
return %0 : tensor<1x64x84x32xf32>
|
||||
}
|
@ -0,0 +1,39 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -emit-builtin-tflite-ops=false -o - | flatbuffer_to_string - | FileCheck %s
|
||||
|
||||
// CHECK: {
|
||||
// CHECK: version: 3,
|
||||
// CHECK: operator_codes: [ {
|
||||
// CHECK: builtin_code: CUSTOM,
|
||||
// CHECK: custom_code: "HashTableV2"
|
||||
// CHECK: } ],
|
||||
// CHECK: subgraphs: [ {
|
||||
// CHECK: tensors: [ {
|
||||
// CHECK: shape: [ ],
|
||||
// CHECK: type: INT32,
|
||||
// CHECK: buffer: 1,
|
||||
// CHECK: name: "tf.HashTableV2",
|
||||
// CHECK: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK: }
|
||||
// CHECK: } ],
|
||||
// CHECK: inputs: [ ],
|
||||
// CHECK: outputs: [ 0 ],
|
||||
// CHECK: operators: [ {
|
||||
// CHECK: inputs: [ ],
|
||||
// CHECK: outputs: [ 0 ],
|
||||
// CHECK: custom_options:
|
||||
// CHECK: name: "main"
|
||||
// CHECK: } ],
|
||||
// CHECK: description: "MLIR Converted.",
|
||||
// CHECK: buffers: [ {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK: } ]
|
||||
// CHECK: }
|
||||
|
||||
// Input for the export test in this file: a function with no arguments that
// creates a single tf.HashTableV2 op and returns its unranked resource tensor.
// The RUN line of this file exports it with custom ops enabled and builtin
// TFLite ops disabled; the expectations earlier in the file look for a CUSTOM
// operator code named "HashTableV2".
func @main() -> tensor<*x!tf.resource> {
|
||||
%0 = "tf.HashTableV2"() {container = "" , shared_name= "table", use_node_name_sharing = false, key_dtype = i32, value_dtype = i32 } : () -> tensor<*x!tf.resource>
|
||||
return %0 : tensor<*x!tf.resource>
|
||||
}
|
||||
|
@ -0,0 +1,65 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --check-prefix=MLIR %s
|
||||
|
||||
// Golden export test for a single tfl.max_pooling_with_argmax_2d op with two
// results. The first RUN line of this file prints the serialized flatbuffer as
// text and the expectations inside the function pin that text line by line,
// including the CUSTOM operator code "MaxPoolingWithArgmax2D" and the ":1"
// suffix on the second result tensor's name. The second RUN line converts the
// flatbuffer back and expects the op with the same filter, padding and stride
// attributes.
// NOTE(review): custom_options starts with what look like little-endian 32-bit
// values 1, 1, 2, 2, 4, presumably padding, stride_w, stride_h, filter_w and
// filter_h, followed by zero bytes - confirm against the encoder for this op.
func @main(%arg0: tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) {
|
||||
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: CUSTOM,
|
||||
// CHECK-NEXT: custom_code: "MaxPoolingWithArgmax2D"
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
// CHECK-NEXT: shape: [ 1, 64, 64, 32 ],
|
||||
// CHECK-NEXT: buffer: 1,
|
||||
// CHECK-NEXT: name: "arg0",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 1, 32, 32, 32 ],
|
||||
// CHECK-NEXT: buffer: 2,
|
||||
// CHECK-NEXT: name: "tfl.max_pooling_with_argmax_2d",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 1, 32, 32, 32 ],
|
||||
// CHECK-NEXT: buffer: 3,
|
||||
// CHECK-NEXT: name: "tfl.max_pooling_with_argmax_2d:1",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: inputs: [ 0 ],
|
||||
// CHECK-NEXT: outputs: [ 1, 2 ],
|
||||
// CHECK-NEXT: operators: [ {
|
||||
// CHECK-NEXT: inputs: [ 0 ],
|
||||
// CHECK-NEXT: outputs: [ 1, 2 ],
|
||||
// CHECK-NEXT: custom_options: [ 1, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: name: "main"
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: description: "MLIR Converted.",
|
||||
// CHECK-NEXT: buffers: [ {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
// MLIR-LABEL: func @main(%arg0: tensor<1x64x64x32xf32>)
|
||||
// MLIR-SAME: -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
|
||||
// MLIR: %value, %indices = "tfl.max_pooling_with_argmax_2d"(%arg0)
|
||||
// MLIR-SAME: {filter_h = 4 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 1 : i32}
|
||||
// MLIR-SAME: (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
|
||||
// MLIR-NEXT: return %value, %indices : tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>
|
||||
|
||||
%0, %1 = "tfl.max_pooling_with_argmax_2d"(%arg0) {filter_h = 4 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 1 : i32} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
|
||||
return %0, %1 : tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>
|
||||
}
|
@ -0,0 +1,65 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --check-prefix=MLIR %s
|
||||
|
||||
// Golden export test for a single tfl.max_unpooling_2d op. The first RUN line
// of this file prints the serialized flatbuffer as text and the expectations
// inside the function pin that text line by line, including the CUSTOM
// operator code "MaxUnpooling2D". The second RUN line converts the flatbuffer
// back and expects the op with the same filter, padding and stride attributes.
// NOTE(review): custom_options starts with what look like little-endian 32-bit
// values 1, 2, 4, 2, 1, presumably padding, stride_w, stride_h, filter_w and
// filter_h, followed by zero bytes - confirm against the encoder for this op.
func @main(%arg0: tensor<1x8x8x128xf32>, %arg1: tensor<1x8x8x128xf32>) -> tensor<1x8x8x128xf32> {
|
||||
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: CUSTOM,
|
||||
// CHECK-NEXT: custom_code: "MaxUnpooling2D"
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
// CHECK-NEXT: shape: [ 1, 8, 8, 128 ],
|
||||
// CHECK-NEXT: buffer: 1,
|
||||
// CHECK-NEXT: name: "arg0",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 1, 8, 8, 128 ],
|
||||
// CHECK-NEXT: buffer: 2,
|
||||
// CHECK-NEXT: name: "arg1",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 1, 8, 8, 128 ],
|
||||
// CHECK-NEXT: buffer: 3,
|
||||
// CHECK-NEXT: name: "tfl.max_unpooling_2d",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: inputs: [ 0, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 2 ],
|
||||
// CHECK-NEXT: operators: [ {
|
||||
// CHECK-NEXT: inputs: [ 0, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 2 ],
|
||||
// CHECK-NEXT: custom_options: [ 1, 0, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: name: "main"
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: description: "MLIR Converted.",
|
||||
// CHECK-NEXT: buffers: [ {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
// MLIR-LABEL: func @main(%arg0: tensor<1x8x8x128xf32>, %arg1: tensor<1x8x8x128xf32>)
|
||||
// MLIR-SAME: -> tensor<1x8x8x128xf32>
|
||||
// MLIR: %0 = "tfl.max_unpooling_2d"(%arg0, %arg1)
|
||||
// MLIR-SAME: {filter_h = 1 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 4 : i32, stride_w = 2 : i32}
|
||||
// MLIR-SAME: (tensor<1x8x8x128xf32>, tensor<1x8x8x128xf32>) -> tensor<1x8x8x128xf32>
|
||||
// MLIR-NEXT: return %0 : tensor<1x8x8x128xf32>
|
||||
|
||||
%0 = "tfl.max_unpooling_2d"(%arg0, %arg1) {filter_h = 1 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 4 : i32, stride_w = 2 : i32} : (tensor<1x8x8x128xf32>, tensor<1x8x8x128xf32>) -> (tensor<1x8x8x128xf32>)
|
||||
return %0 : tensor<1x8x8x128xf32>
|
||||
}
|
@ -1977,3 +1977,12 @@ func @testTransposeConvBadOutputShape(%arg1: tensor<32x4x4x128xf32>, %arg2: tens
|
||||
%0 = "tfl.transpose_conv"(%cst, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>) -> tensor<1x64x84x31xf32>
|
||||
return %0 : tensor<1x64x84x31xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// A tfl.densify op on a dynamically sized 1-D f32 tensor. The expectation is
// that the op is printed back with the same operand and result types; the
// spaced spelling `? x f32` used in the input appears as `?xf32` in the output.
// CHECK-LABEL: testDensify
|
||||
func @testDensify(%arg0: tensor<? x f32>) -> tensor<? x f32> {
|
||||
// CHECK: "tfl.densify"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tfl.densify"(%arg0): (tensor<? x f32>) -> tensor<? x f32>
|
||||
return %0 : tensor<? x f32>
|
||||
}
|
||||
|
@ -1,4 +1,7 @@
|
||||
// Run optimize pass only and check the results.
|
||||
// RUN: tf-opt %s -tfl-optimize | FileCheck %s
|
||||
// Run optimize pass and then canonicalize pass, and make sure some folding is applied.
|
||||
// RUN: tf-opt %s -tfl-optimize -canonicalize | FileCheck --check-prefix=FOLD %s
|
||||
|
||||
// CHECK-LABEL: fusedConv2dRelu
|
||||
func @fusedConv2dRelu(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<256x30x30x16xf32> {
|
||||
@ -75,10 +78,10 @@ func @fuseSubIntoFollowingConv2d(%arg0: tensor<256x32x32x3xf32>) -> tensor<256x3
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @fuseAddIntoDepthwiseConv2d
|
||||
func @fuseAddIntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x30x30x16xf32> {
|
||||
func @fuseAddIntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32> {
|
||||
%cst = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
|
||||
%cst_0 = constant dense<1.5> : tensor<16xf32>
|
||||
%0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
%0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
%1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
return %1 : tensor<256x30x30x16xf32>
|
||||
|
||||
@ -87,10 +90,10 @@ func @fuseAddIntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fuseSubIntoDepthwiseConv2d
|
||||
func @fuseSubIntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x30x30x16xf32> {
|
||||
func @fuseSubIntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32> {
|
||||
%cst = constant dense<0.5> : tensor<16xf32>
|
||||
%cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
|
||||
%0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
%0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
%1 = "tfl.sub"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
return %1 : tensor<256x30x30x16xf32>
|
||||
|
||||
@ -128,10 +131,10 @@ func @fuseAddWithRelu6IntoConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @fuseAddWithRelu6IntoDepthwiseConv2d
|
||||
func @fuseAddWithRelu6IntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x30x30x16xf32> {
|
||||
func @fuseAddWithRelu6IntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32> {
|
||||
%cst = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
|
||||
%cst_0 = constant dense<1.5> : tensor<16xf32>
|
||||
%0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
%0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
%1 = "tfl.add"(%0, %cst) {fused_activation_function = "RELU6"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
return %1 : tensor<256x30x30x16xf32>
|
||||
|
||||
@ -302,6 +305,58 @@ func @FuseFullyConnectedAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf
|
||||
// CHECK: return %[[fc]]
|
||||
}
|
||||
|
||||
// A fully_connected op (splat bias 3.0) whose result is reshaped to 1x40x40,
// added to a splat 2.0 constant, and reshaped back to 40x40. With the optimize
// pass alone, the expectations require the add to disappear into a single 5.0
// bias constant on the fully_connected op while both reshapes remain. With the
// canonicalize pass added (second RUN line of this file), the expectations
// require the fully_connected result to be returned directly.
// CHECK-LABEL: @FuseFullyConnectedReshapeAddConst
|
||||
// FOLD-LABEL: @FuseFullyConnectedReshapeAddConst
|
||||
func @FuseFullyConnectedReshapeAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
|
||||
%cst = constant dense<3.0> : tensor<40x40xf32>
|
||||
%cst2 = constant dense<2.0> : tensor<40xf32>
|
||||
%shape1 = constant dense<[1, 40, 40]> : tensor<3xi32>
|
||||
%shape2 = constant dense<[40, 40]> : tensor<2xi32>
|
||||
|
||||
%0 = "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, tensor<40x40xf32>) -> (tensor<40x40xf32>)
|
||||
%1 = "tfl.reshape"(%0, %shape1) : (tensor<40x40xf32>, tensor<3xi32>) -> tensor<1x40x40xf32>
|
||||
%2 = "tfl.add"(%1, %cst2) {fused_activation_function = "NONE"} : (tensor<1x40x40xf32>, tensor<40xf32>) -> tensor<1x40x40xf32>
|
||||
%3 = "tfl.reshape"(%2, %shape2) : (tensor<1x40x40xf32>, tensor<2xi32>) -> tensor<40x40xf32>
|
||||
|
||||
return %3 : tensor<40x40xf32>
|
||||
|
||||
// CHECK: %[[cst:.*]] = constant dense<5.000000e+00> : tensor<40x40xf32>
|
||||
// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]])
|
||||
// CHECK: %[[rs1:.*]] = "tfl.reshape"(%[[fc]]
|
||||
// CHECK: %[[rs2:.*]] = "tfl.reshape"(%[[rs1]]
|
||||
// CHECK: return %[[rs2]]
|
||||
|
||||
// FOLD: %[[cst:.*]] = constant dense<5.000000e+00> : tensor<40x40xf32>
|
||||
// FOLD: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]])
|
||||
// FOLD: return %[[fc]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @NotReorderReshapeAddIfNotBroadcastable
|
||||
func @NotReorderReshapeAddIfNotBroadcastable(%arg0: tensor<40x10x4xf32>) -> tensor<40x40xf32> {
|
||||
%cst = constant dense<2.0> : tensor<40xf32>
|
||||
%shape = constant dense<[40, 40]> : tensor<2xi32>
|
||||
%1 = "tfl.reshape"(%arg0, %shape) : (tensor<40x10x4xf32>, tensor<2xi32>) -> tensor<40x40xf32>
|
||||
%2 = "tfl.add"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<40xf32>) -> tensor<40x40xf32>
|
||||
return %2 : tensor<40x40xf32>
|
||||
|
||||
// CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0
|
||||
// CHECK: %[[rs2:.*]] = "tfl.add"(%[[rs1]]
|
||||
// CHECK: return %[[rs2]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @NotReorderReshapeAddIfNotTailingDim
|
||||
func @NotReorderReshapeAddIfNotTailingDim(%arg0: tensor<40x40x1xf32>) -> tensor<40x40xf32> {
|
||||
%cst = constant dense<2.0> : tensor<1x40xf32>
|
||||
%shape = constant dense<[40, 40]> : tensor<2xi32>
|
||||
%1 = "tfl.reshape"(%arg0, %shape) : (tensor<40x40x1xf32>, tensor<2xi32>) -> tensor<40x40xf32>
|
||||
%2 = "tfl.add"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<1x40xf32>) -> tensor<40x40xf32>
|
||||
return %2 : tensor<40x40xf32>
|
||||
|
||||
// CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0
|
||||
// CHECK: %[[rs2:.*]] = "tfl.add"(%[[rs1]]
|
||||
// CHECK: return %[[rs2]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @FuseFullyConnectedRelu
|
||||
func @FuseFullyConnectedRelu(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> {
|
||||
%0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32>
|
||||
|
@ -1,5 +1,6 @@
|
||||
// RUN: tf-opt -tfl-prepare-composite-funcs-tf %s | FileCheck %s --dump-input-on-failure
|
||||
// RUN: tf-opt -tfl-prepare-composite-funcs-tf %s -split-input-file | FileCheck %s --dump-input-on-failure
|
||||
|
||||
module{
|
||||
func @embedding(%arg0: tensor<*xf32>, %arg1: tensor<*xi32>) -> tensor<*xf32> attributes {tf._implements = "embedding_matmul", tf._reference = "mlir"} {
|
||||
%0 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
%1 = "tf.ExpandDims"(%arg1, %0) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
||||
@ -148,3 +149,39 @@ func @layernormalizedlstmcellsimple(%arg0: tensor<1x?xf32>, %arg1: tensor<3x4xf3
|
||||
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x?xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<3x1xf32>, tensor<3xf32>, tensor<1x3xf32>, tensor<1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x3xf32>
|
||||
// CHECK: [[VAL_104:%.*]] = tensor_cast [[VAL_105:%.*]] : tensor<1x3xf32> to tensor<1x?xf32>
|
||||
// CHECK: return [[VAL_104]] : tensor<1x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
func @inference_standard_lstm_7410(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.signature.is_stateful} {
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
|
||||
%1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
|
||||
%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
|
||||
%3 = "tf.Add"(%2, %arg1) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
|
||||
%4 = "tf.Add"(%2, %arg2) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x?x10xf32>
|
||||
%5 = "tf.Add"(%arg1, %arg2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
|
||||
%6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK: func @inference_standard_lstm_7410([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<?x8x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.signature.is_stateful} {
|
||||
// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
|
||||
// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32>
|
||||
// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
|
||||
// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<10x40xf32>, tensor<2xi64>) -> tensor<40x10xf32>
|
||||
// CHECK: [[VAL_10:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
// CHECK: [[VAL_11:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: [[VAL_12:%.*]]:4 = "tf.SplitV"([[VAL_7]], [[VAL_10]], [[VAL_11]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
|
||||
// CHECK: [[VAL_13:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
// CHECK: [[VAL_14:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: [[VAL_15:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_13]], [[VAL_14]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
|
||||
// CHECK: [[VAL_16:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
// CHECK: [[VAL_17:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
|
||||
// CHECK: [[VAL_19:%.*]] = constant unit
|
||||
// CHECK: [[VAL_20:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) ( {
|
||||
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<?x8x10xf32>
|
||||
// CHECK: return [[VAL_21:%.*]] : tensor<?x8x10xf32>
|
||||
|
||||
}
|
||||
|
@ -414,6 +414,14 @@ func @CheckNumerics(%arg0: tensor<3xf32>) -> tensor<3xf32> {
|
||||
// CHECK: return %arg0 : tensor<3xf32>
|
||||
}
|
||||
|
||||
func @placeholder_with_default(%arg0: tensor<3xf32>) -> tensor<3xf32> {
|
||||
%0 = "tf.PlaceholderWithDefault"(%arg0): (tensor<3xf32>) -> tensor<3xf32>
|
||||
return %0 : tensor<3xf32>
|
||||
// Should be converted to Identity and then from Identity to value
|
||||
// CHECK-LABEL: placeholder_with_default
|
||||
// CHECK: return %arg0 : tensor<3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @NoPadStridedSliceNonNewAxisMask
|
||||
func @NoPadStridedSliceNonNewAxisMask(%arg0: tensor<1x2x3x1xf32>) -> tensor<1x2x3x1xf32> {
|
||||
%cst = constant dense<0> : tensor<4xi32>
|
||||
@ -426,8 +434,8 @@ func @NoPadStridedSliceNonNewAxisMask(%arg0: tensor<1x2x3x1xf32>) -> tensor<1x2x
|
||||
// CHECK: %0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) {begin_mask = 15 : i64, ellipsis_mask = 0 : i64, end_mask = 15 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1x2x3x1xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @PadStridedSliceNewAxisMask
|
||||
func @PadStridedSliceNewAxisMask(%arg0: tensor<2x3xf32>) -> tensor<1x2x3x1xf32> {
|
||||
// CHECK-LABEL: @PadStridedSliceNewAxisMask1
|
||||
func @PadStridedSliceNewAxisMask1(%arg0: tensor<2x3xf32>) -> tensor<1x2x3x1xf32> {
|
||||
%cst = constant dense<0> : tensor<4xi32>
|
||||
%cst_0 = constant dense<1> : tensor<4xi32>
|
||||
%0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 9 : i64, shrink_axis_mask = 0 : i64} : (tensor<2x3xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
|
||||
@ -439,3 +447,12 @@ func @PadStridedSliceNewAxisMask(%arg0: tensor<2x3xf32>) -> tensor<1x2x3x1xf32>
|
||||
// CHECK: %0 = "tf.Reshape"(%arg0, %[[cst_1]]) : (tensor<2x3xf32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
|
||||
// CHECK: %1 = "tf.StridedSlice"(%0, %cst, %cst, %cst_0) {begin_mask = 15 : i64, ellipsis_mask = 0 : i64, end_mask = 15 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1x2x3x1xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @PadStridedSliceNewAxisMask2
|
||||
func @PadStridedSliceNewAxisMask2(%arg0: tensor<4x64x64x1xf32>) -> tensor<1x4x64x64xf32> {
|
||||
%cst = constant dense<0> : tensor<3xi32>
|
||||
%cst_0 = constant dense<1> : tensor<3xi32>
|
||||
%0 = "tf.Squeeze"(%arg0) {T = f32, _output_shapes = ["tfshape$dim { size: 4 } dim { size: 64 } dim { size: 64 }"], device = "", squeeze_dims = []} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32>
|
||||
%1 = "tf.StridedSlice"(%0, %cst, %cst, %cst_0) {Index = i32, T = f32, _output_shapes = ["tfshape$dim { size: 1 } dim { size: 4 } dim { size: 64 } dim { size: 64 }"], begin_mask = 6 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 1 : i64, shrink_axis_mask = 0 : i64} : (tensor<4x64x64xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x4x64x64xf32>
|
||||
return %1 : tensor<1x4x64x64xf32>
|
||||
}
|
||||
|
@ -43,6 +43,16 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
|
||||
quant_specs.inference_type != quant_specs.inference_input_type;
|
||||
pass_manager->addPass(
|
||||
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
|
||||
|
||||
if (quant_specs.default_ranges.first.hasValue() ||
|
||||
quant_specs.default_ranges.second.hasValue()) {
|
||||
pass_manager->addPass(mlir::TFL::CreateDefaultQuantParamsPass(
|
||||
quant_specs.default_ranges.first.getValueOr(0.0),
|
||||
quant_specs.default_ranges.second.getValueOr(0.0)));
|
||||
pass_manager->addPass(mlir::TFL::CreateQuantizePass());
|
||||
pass_manager->addPass(
|
||||
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
|
||||
}
|
||||
}
|
||||
|
||||
void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
@ -115,7 +125,8 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
if (pass_config.emit_builtin_tflite_ops) {
|
||||
// Prepare for TFLite dialect, rerun canonicalization, and then legalize to
|
||||
// the TFLite dialect.
|
||||
pass_manager->addPass(mlir::TFL::CreatePrepareTFPass());
|
||||
pass_manager->addPass(
|
||||
mlir::TFL::CreatePrepareTFPass(pass_config.unfold_batch_matmul));
|
||||
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
||||
pass_manager->addPass(mlir::TFL::CreateLegalizeTFPass());
|
||||
pass_manager->addPass(mlir::TFL::CreateOptimizePass());
|
||||
|
@ -86,15 +86,15 @@ StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
|
||||
if (use_splatted_constant) {
|
||||
return tensorflow::GraphdefToSplattedMlirTranslateFunction(
|
||||
file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
|
||||
input_shapes, output_arrays, prune_unused_nodes,
|
||||
/*convert_legacy_fed_inputs=*/true,
|
||||
input_shapes, output_arrays, /*control_output_arrays=*/"",
|
||||
prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
|
||||
/*graph_as_function=*/false, /*upgrade_legacy=*/true, context);
|
||||
}
|
||||
return tensorflow::GraphdefToMlirTranslateFunction(
|
||||
file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
|
||||
input_shapes, output_arrays, prune_unused_nodes,
|
||||
/*convert_legacy_fed_inputs=*/true, /*graph_as_function=*/false,
|
||||
/*upgrade_legacy=*/true, context);
|
||||
input_shapes, output_arrays, /*control_output_arrays=*/"",
|
||||
prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
|
||||
/*graph_as_function=*/false, /*upgrade_legacy=*/true, context);
|
||||
}
|
||||
|
||||
Status ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
|
234
tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc
Normal file
234
tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc
Normal file
@ -0,0 +1,234 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/Functional.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
|
||||
#include "mlir/IR/Location.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// The Pass to add default quantization parameters for the activations which
|
||||
// don't have quantization information. These default parameters are usually
|
||||
// not from real measurements, so this pass is only for testing purposes.
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
// Includes an auto-generated function, which can retrieve the quantization
|
||||
// specification for a TFL operation. The signature of the function is
|
||||
// std::unique_ptr<OpQuantSpec> TFL::GetOpQuantSpec(Operation *)
|
||||
#include "tensorflow/compiler/mlir/lite/utils/generated_op_quant_spec_getters.inc"
|
||||
|
||||
namespace {
|
||||
// Function pass that wraps unquantized f32 tensor values in a
// tfl.quantize / tfl.dequantize pair. Activations get parameters built from
// the [default_min, default_max] range supplied at construction; biases get
// parameters derived from the other operands of the op consuming them.
class DefaultQuantParamsPass : public FunctionPass<DefaultQuantParamsPass> {
 public:
  explicit DefaultQuantParamsPass(double default_min, double default_max)
      : default_min_(default_min), default_max_(default_max) {}

  void runOnFunction() override;

 private:
  // Returns true when at least one use of `value` sits at an operand index
  // that the using op's quantization spec lists as a bias. The bias is assumed
  // to feed its user directly, which holds once constants have been folded.
  bool UsedAsBias(Value value) {
    for (auto &operand_use : value.getUses()) {
      // Hold the spec in a local so its bias table can be inspected in place.
      auto user_spec = TFL::GetOpQuantSpec(operand_use.getOwner());
      auto &bias_table = user_spec->biases_params;
      if (bias_table.find(operand_use.getOperandNumber()) != bias_table.end())
        return true;
    }
    return false;
  }

  // Inserts a tfl.quantize followed by a tfl.dequantize after `value`, using
  // `quant_params` as the quantized element type, and redirects the existing
  // users of `value` to the dequantized result.
  void QuantizeValue(OpBuilder builder, Value value,
                     TFL::QuantParams quant_params);

  // Appends `value` to `values` unless it is not an f32 tensor or it is
  // already consumed solely by a tfl.quantize op.
  void AddToWorkListIfUnquantized(Value value, std::vector<Value> *values);

  // Returns the quantization parameters built from the default min/max.
  TFL::QuantParams GetDefaultQuantParams(Builder builder);

  // Computes the parameters for a bias operand of `op` by applying `func` to
  // the quantized types of the operands listed in `non_biases`.
  TFL::QuantParams GetQuantParamsForBias(Operation *op, int bias,
                                         const std::vector<int> &non_biases,
                                         TFL::AccumulatorScaleFunc func);

  // Range used to build the default activation parameters.
  double default_min_;
  double default_max_;
  // Cache filled on the first call to GetDefaultQuantParams.
  TFL::QuantParams default_quant_params_;
};
|
||||
} // namespace
|
||||
|
||||
void DefaultQuantParamsPass::runOnFunction() {
|
||||
FuncOp func = getFunction();
|
||||
OpBuilder builder(func);
|
||||
|
||||
std::vector<Value> activation_values;
|
||||
std::vector<Value> bias_values;
|
||||
|
||||
// First of all, collect all the values (block arguments and op results) which
|
||||
// are required to be quantized.
|
||||
for (auto arg : func.getBody().begin()->getArguments()) {
|
||||
if (UsedAsBias(arg)) {
|
||||
AddToWorkListIfUnquantized(arg, &bias_values);
|
||||
} else {
|
||||
AddToWorkListIfUnquantized(arg, &activation_values);
|
||||
}
|
||||
}
|
||||
|
||||
func.walk([&](Operation *op) {
|
||||
if (op->isKnownTerminator() ||
|
||||
op->hasTrait<OpTrait::quant::NoQuantizableResult>())
|
||||
return;
|
||||
|
||||
for (auto res : op->getResults()) {
|
||||
if (UsedAsBias(res)) {
|
||||
AddToWorkListIfUnquantized(res, &bias_values);
|
||||
} else {
|
||||
AddToWorkListIfUnquantized(res, &activation_values);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Apply the default quantization parameters for these activation values.
|
||||
TFL::QuantParams default_params = GetDefaultQuantParams(builder);
|
||||
for (Value value : activation_values) {
|
||||
QuantizeValue(builder, value, default_params);
|
||||
}
|
||||
|
||||
// Since all the non-biases operands have quantization parameters now, we
|
||||
// should be able to propagate them to the bias operand.
|
||||
for (Value bias : bias_values) {
|
||||
Operation *op = *bias.user_begin();
|
||||
auto spec = TFL::GetOpQuantSpec(op);
|
||||
for (auto &it : spec->biases_params) {
|
||||
TFL::QuantParams bias_params = GetQuantParamsForBias(
|
||||
op, it.first, it.second.first, it.second.second);
|
||||
if (!bias_params) continue;
|
||||
QuantizeValue(builder, bias, bias_params);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pushes `value` onto `values` when it is an f32 tensor that has not been
// quantized yet; otherwise leaves `values` untouched.
void DefaultQuantParamsPass::AddToWorkListIfUnquantized(
    Value value, std::vector<Value> *values) {
  // Non-tensor values (e.g. none-typed ones) and tensors whose element type is
  // not f32 do not need quantization.
  auto tensor_type = value.getType().dyn_cast<TensorType>();
  const bool is_f32_tensor =
      tensor_type && tensor_type.getElementType().isF32();
  if (!is_f32_tensor) return;

  // A value whose single user is a quantize op has been handled already.
  const bool already_quantized =
      value.hasOneUse() &&
      llvm::isa<TFL::QuantizeOp>(*value.getUsers().begin());
  if (already_quantized) return;

  // Everything else is queued to receive quantization parameters.
  values->push_back(value);
}
|
||||
|
||||
// Rewrites `value` -> users into
// `value` -> tfl.quantize -> tfl.dequantize -> users, where the quantized
// type is obtained by applying `quant_params` to the type of `value`.
void DefaultQuantParamsPass::QuantizeValue(OpBuilder builder, Value value,
                                           TFL::QuantParams quant_params) {
  Type float_type = value.getType();
  Type quantized_type = quant_params.castFromExpressedType(float_type);
  // A null result means the value is not of the expressed (float) type.
  if (!quantized_type) return;

  // Insert right after the defining op, or at the start of the first block of
  // the parent region when the value has no defining op.
  Block &first_block = value.getParentRegion()->front();
  if (Operation *producer = value.getDefiningOp()) {
    builder.setInsertionPoint(&first_block, ++Block::iterator(producer));
  } else {
    builder.setInsertionPointToStart(&first_block);
  }

  auto loc = value.getLoc();
  TypeAttr quantized_type_attr = TypeAttr::get(quantized_type);
  auto quantize = builder.create<TFL::QuantizeOp>(loc, quantized_type, value,
                                                  quantized_type_attr);
  auto dequantize =
      builder.create<TFL::DequantizeOp>(loc, float_type, quantize.output());

  // Redirect every user of `value` to the dequantized result...
  value.replaceAllUsesWith(dequantize);

  // ...which also rewired the new quantize op itself, so point its operand
  // back at the original `value`.
  quantize.getOperation()->replaceUsesOfWith(dequantize, value);
}
|
||||
|
||||
// Derives the quantization parameters for a bias operand of `op`.
//
// `non_biases` lists the operand indices of `op` whose quantized element types
// are fed to `func` to compute the bias parameters. Each of those operands
// must be produced by a tfl.dequantize op; its input element type is what gets
// collected. `bias` (the bias operand index) is not consulted here.
//
// Returns a null QuantParams when any listed operand is not the result of a
// tfl.dequantize op, including operands with no defining op at all (block
// arguments), so callers can skip the bias.
TFL::QuantParams DefaultQuantParamsPass::GetQuantParamsForBias(
    Operation *op, int bias, const std::vector<int> &non_biases,
    TFL::AccumulatorScaleFunc func) {
  std::vector<quant::QuantizedType> non_bias_types;
  non_bias_types.reserve(non_biases.size());
  for (int non_bias : non_biases) {
    Operation *non_bias_define = op->getOperand(non_bias).getDefiningOp();
    // getDefiningOp() is null for block arguments; dyn_cast_or_null tolerates
    // that, whereas plain dyn_cast asserts on a null pointer.
    auto dequant = llvm::dyn_cast_or_null<TFL::DequantizeOp>(non_bias_define);
    // The non-bias hasn't been quantized, let's skip this bias.
    if (!dequant) return {};

    auto non_bias_type = dequant.input().getType().cast<TensorType>();
    auto non_bias_ele_type =
        non_bias_type.getElementType().cast<quant::QuantizedType>();
    non_bias_types.push_back(non_bias_ele_type);
  }

  return func(non_bias_types);
}
|
||||
|
||||
// Returns the default quantization parameters, building and caching them on
// first use: an 8-bit, non-narrow-range type over [default_min_, default_max_]
// with f32 as the expressed type.
TFL::QuantParams DefaultQuantParamsPass::GetDefaultQuantParams(
    Builder builder) {
  // Reuse the cached parameters once they have been computed.
  if (default_quant_params_) return default_quant_params_;

  default_quant_params_ = quant::fakeQuantAttrsToType(
      builder.getUnknownLoc(),
      /*numBits=*/8, default_min_, default_max_, /*narrowRange=*/false,
      builder.getF32Type());
  return default_quant_params_;
}
|
||||
|
||||
// Creates an instance of the default quant parameters pass.
// `default_min` and `default_max` are forwarded to the pass constructor and
// define the range from which the default quantization parameters are built.
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreateDefaultQuantParamsPass(
|
||||
double default_min, double default_max) {
|
||||
return absl::make_unique<DefaultQuantParamsPass>(default_min, default_max);
|
||||
}
|
||||
|
||||
// Registers this pass under the "tfl-default-quant" flag for testing only.
// The registered factory always uses the fixed default range [-1.0, 1.0].
|
||||
static PassRegistration<DefaultQuantParamsPass> pass(
|
||||
"tfl-default-quant",
|
||||
"Apply quantization with default quantization parameter", [] {
|
||||
return CreateDefaultQuantParamsPass(/*default_min=*/-1.0,
|
||||
/*default_max=*/1.0);
|
||||
});
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace mlir
|
@ -150,6 +150,7 @@ def : Pat<(TF_RoundOp $arg), (TFL_RoundOp $arg)>;
|
||||
def : Pat<(TF_RsqrtOp $arg), (TFL_RsqrtOp $arg)>;
|
||||
def : Pat<(TF_SqrtOp $arg), (TFL_SqrtOp $arg)>;
|
||||
def : Pat<(TF_SquareOp $arg), (TFL_SquareOp $arg)>;
|
||||
def : Pat<(TF_SegmentSumOp $data, I32Tensor:$segment_ids), (TFL_SegmentSumOp $data, $segment_ids)>;
|
||||
def : Pat<(TF_SelectOp $cond, $x, $y), (TFL_SelectOp $cond, $x, $y)>;
|
||||
def : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y), (TFL_SelectOp $cond, $x, $y), [(HasSameStaticShapes $src_op)]>;
|
||||
def : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y), (TFL_SelectV2Op $cond, $x, $y), [(HasNotSameStaticShapes $src_op)]>;
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user