Merge branch 'master' into interface_16x8

commit e7b615dc2e

.bazelrc (12 changes)
@@ -73,6 +73,10 @@
 # tensorflow_testing_rbe_linux: RBE options to use RBE with tensorflow-testing project on linux
 # tensorflow_testing_rbe_win: RBE options to use RBE with tensorflow-testing project on windows
 #
+# Embedded Linux options (experimental and only tested with TFLite build yet)
+# elinux: General Embedded Linux options shared by all flavors.
+# elinux_aarch64: Embedded Linux options for aarch64 (ARM64) CPU support.
+# elinux_armhf: Embedded Linux options for armhf (ARMv7) CPU support.
@@ -432,6 +436,14 @@ build:tensorflow_testing_rbe_linux --config=rbe_linux
 common:tensorflow_testing_rbe_win --remote_instance_name=projects/tensorflow-testing/instances/windows
 build:tensorflow_testing_rbe_win --config=tensorflow_testing_rbe

+# TFLite build configs for generic embedded Linux
+build:elinux --crosstool_top=@local_config_embedded_arm//:toolchain
+build:elinux --host_crosstool_top=@bazel_tools//tools/cpp:toolchain
+build:elinux_aarch64 --config=elinux
+build:elinux_aarch64 --cpu=aarch64
+build:elinux_armhf --config=elinux
+build:elinux_armhf --cpu=armhf
 # END TF REMOTE BUILD EXECUTION OPTIONS
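With these configs in place, a TFLite cross-compile is selected entirely from the command line, e.g. `bazel build --config=elinux_aarch64 //tensorflow/lite:libtensorflowlite.so` (the target label here is illustrative, not part of the diff).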
 # Default options should come above this line

@@ -1 +1 @@
-2.0.0
+3.0.0
.github/ISSUE_TEMPLATE/00-bug-issue.md (34 changes)
@@ -10,32 +10,30 @@ labels: 'type:bug'
 we only address code/doc bugs, performance issues, feature requests and
 build/installation issues on GitHub. tag:bug_template</em>

-**System information**
-- Have I written custom code (as opposed to using a stock
-example script provided in TensorFlow):
-- OS Platform and Distribution (e.g.,
-Linux Ubuntu 16.04):
-- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
-the issue happens on mobile device:
-- TensorFlow installed from (source or
-binary): - TensorFlow version (use command below):
-- Python version: - Bazel
-version (if compiling from source):
-- GCC/Compiler version (if compiling from
-source):
-- CUDA/cuDNN version: - GPU model and memory:
+**System information**
+- Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
+- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
+- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
+- TensorFlow installed from (source or binary):
+- TensorFlow version (use command below):
+- Python version:
+- Bazel version (if compiling from source):
+- GCC/Compiler version (if compiling from source):
+- CUDA/cuDNN version:
+- GPU model and memory:

 You can collect some of this information using our environment capture
 [script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
-You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
-tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
-"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
+You can also obtain the TensorFlow version with:
+1. TF 1.0: `python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"`
+2. TF 2.0: `python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`

 **Describe the current behavior**

 **Describe the expected behavior**

 **Standalone code to reproduce the issue**
 Provide a reproducible test case that is the bare minimum necessary to generate
 the problem. If possible, please share a link to Colab/Jupyter/any notebook.
.github/ISSUE_TEMPLATE/80-performance-issue.md (33 changes)
@@ -11,32 +11,29 @@ As per our
 we only address code/doc bugs, performance issues, feature requests and
 build/installation issues on GitHub. tag:performance_template</em>

-**System information**
-- Have I written custom code (as opposed to using a stock
-example script provided in TensorFlow):
-- OS Platform and Distribution (e.g.,
-Linux Ubuntu 16.04):
-- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
-the issue happens on mobile device:
-- TensorFlow installed from (source or
-binary): - TensorFlow version (use command below):
-- Python version: - Bazel
-version (if compiling from source):
-- GCC/Compiler version (if compiling from
-source):
-- CUDA/cuDNN version: - GPU model and memory:
+**System information**
+- Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
+- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
+- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
+- TensorFlow installed from (source or binary):
+- TensorFlow version (use command below):
+- Python version:
+- Bazel version (if compiling from source):
+- GCC/Compiler version (if compiling from source):
+- CUDA/cuDNN version:
+- GPU model and memory:

 You can collect some of this information using our environment capture
 [script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
-You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
-tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
-"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
+You can also obtain the TensorFlow version with:
+1. TF 1.0: `python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"`
+2. TF 2.0: `python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`

 **Describe the current behavior**

 **Describe the expected behavior**

 **Standalone code to reproduce the issue**
 Provide a reproducible test case that is the bare minimum necessary to generate
 the problem. If possible, please share a link to Colab/Jupyter/any notebook.
@@ -2,6 +2,10 @@
 <img src="https://www.tensorflow.org/images/tf_logo_social.png">
 </div>

 [](https://badge.fury.io/py/tensorflow)

 **`Documentation`** |
 ------------------- |
 [](https://www.tensorflow.org/api_docs/) |
@@ -2,58 +2,42 @@ package(default_visibility = ["//visibility:public"])

 filegroup(
     name = "gcc",
-    srcs = [
-        "bin/arm-rpi-linux-gnueabihf-gcc",
-    ],
+    srcs = glob(["bin/*-gcc"]),
 )

 filegroup(
     name = "ar",
-    srcs = [
-        "bin/arm-rpi-linux-gnueabihf-ar",
-    ],
+    srcs = glob(["bin/*-ar"]),
 )

 filegroup(
     name = "ld",
-    srcs = [
-        "bin/arm-rpi-linux-gnueabihf-ld",
-    ],
+    srcs = glob(["bin/*-ld"]),
 )

 filegroup(
     name = "nm",
-    srcs = [
-        "bin/arm-rpi-linux-gnueabihf-nm",
-    ],
+    srcs = glob(["bin/*-nm"]),
 )

 filegroup(
     name = "objcopy",
-    srcs = [
-        "bin/arm-rpi-linux-gnueabihf-objcopy",
-    ],
+    srcs = glob(["bin/*-objcopy"]),
 )

 filegroup(
     name = "objdump",
-    srcs = [
-        "bin/arm-rpi-linux-gnueabihf-objdump",
-    ],
+    srcs = glob(["bin/*-objdump"]),
 )

 filegroup(
     name = "strip",
-    srcs = [
-        "bin/arm-rpi-linux-gnueabihf-strip",
-    ],
+    srcs = glob(["bin/*-strip"]),
 )

 filegroup(
     name = "as",
-    srcs = [
-        "bin/arm-rpi-linux-gnueabihf-as",
-    ],
+    srcs = glob(["bin/*-as"]),
 )
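Replacing the hard-coded `arm-rpi-linux-gnueabihf-*` paths with `glob(["bin/*-gcc"])` and friends means these filegroups no longer assume a single target triple; the same BUILD file can front the armhf toolchain or the aarch64 toolchain added below, whichever binaries the repository actually contains.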
 filegroup(
@@ -66,6 +50,16 @@ filegroup(
     ]),
 )

+filegroup(
+    name = "aarch64_compiler_pieces",
+    srcs = glob([
+        "aarch64-none-linux-gnu/**",
+        "libexec/**",
+        "lib/gcc/aarch64-none-linux-gnu/**",
+        "include/**",
+    ]),
+)

 filegroup(
     name = "compiler_components",
     srcs = [
configure (2 changes)
@@ -4,7 +4,7 @@ set -e
 set -o pipefail

 if [ -z "$PYTHON_BIN_PATH" ]; then
-  PYTHON_BIN_PATH=$(which python || which python3 || true)
+  PYTHON_BIN_PATH=$(which python3 || which python || true)
 fi

 # Set all env variables

@@ -50,7 +50,7 @@ _TF_WORKSPACE_ROOT = ''
 _TF_BAZELRC = ''
 _TF_CURRENT_BAZEL_VERSION = None
 _TF_MIN_BAZEL_VERSION = '2.0.0'
-_TF_MAX_BAZEL_VERSION = '2.0.0'
+_TF_MAX_BAZEL_VERSION = '3.99.0'
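Note how the two version bumps travel together: `.bazelversion` moves to 3.0.0 above, and `_TF_MAX_BAZEL_VERSION` is widened to '3.99.0' so the configure script accepts any Bazel 2.x or 3.x instead of pinning exactly 2.0.0.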
 NCCL_LIB_PATHS = [
     'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''
@@ -58,8 +58,6 @@ NCCL_LIB_PATHS = [

 # List of files to configure when building Bazel on Apple platforms.
 APPLE_BAZEL_FILES = [
-    'tensorflow/lite/experimental/delegates/coreml/BUILD',
-    'tensorflow/lite/experimental/delegates/coreml/builders/BUILD',
     'tensorflow/lite/experimental/ios/BUILD',
     'tensorflow/lite/experimental/objc/BUILD',
     'tensorflow/lite/experimental/swift/BUILD',

@@ -214,6 +214,12 @@ config_setting(
     visibility = ["//visibility:public"],
 )

+config_setting(
+    name = "linux_armhf",
+    values = {"cpu": "armhf"},
+    visibility = ["//visibility:public"],
+)
+
 config_setting(
     name = "linux_x86_64",
     values = {"cpu": "k8"},
@@ -703,8 +709,8 @@ tf_cc_shared_object(
         "//tensorflow/c:version_script.lds",
         "//tensorflow/c/eager:c_api",
         "//tensorflow/c/eager:c_api_experimental",
+        "//tensorflow/core:distributed_tensorflow_dependencies",
         "//tensorflow/core:tensorflow",
-        "//tensorflow/core/distributed_runtime/rpc:grpc_session",
     ],
 )

@@ -186,10 +186,6 @@ struct TF_Server {

 namespace tensorflow {

-Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
-
-TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
-
 Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
                        TF_Buffer* out);
@@ -240,11 +240,6 @@ tf_cuda_cc_test(
         "c_api_remote_test.cc",
     ],
     extra_copts = tfe_xla_copts(),
-    tags = [
-        "guitar",
-        "multi_gpu",
-        "no_oss",
-    ],
     deps = [
         ":c_api",
         ":c_api_experimental",
@@ -372,6 +367,22 @@ tf_cuda_cc_test(
     ],
 )

+cc_library(
+    name = "custom_device_testutil",
+    testonly = True,
+    srcs = ["custom_device_testutil.cc"],
+    hdrs = ["custom_device_testutil.h"],
+    visibility = ["//tensorflow:internal"],
+    deps = [
+        ":c_api",
+        ":c_api_experimental",
+        ":c_api_test_util",
+        "//tensorflow/c:c_api",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+    ],
+)
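Factoring the helpers into a `testonly` cc_library lets `custom_device_test` below (and any other target under `//tensorflow:internal`) link the same custom-device scaffolding instead of recompiling the .cc file into each test.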

 tf_cc_test(
     name = "custom_device_test",
     size = "small",
@@ -382,6 +393,7 @@ tf_cc_test(
         ":c_api",
         ":c_api_experimental",
         ":c_api_test_util",
+        ":custom_device_testutil",
         "//tensorflow/c:c_api",
         "//tensorflow/c:c_test_util",
         "//tensorflow/cc/profiler",
@@ -1587,6 +1587,7 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
       // require TFE_Op* and just convert it internally a NameAttrValue, so
       // consider adding an overload to the C API to make this case easier.
       TFE_OpSetAttrFunction(op, attr_name, func_op);
+      TFE_DeleteOp(func_op);
    } break;
    case tensorflow::AttrValue::kList:
      TF_FALLTHROUGH_INTENDED;
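Since `TFE_OpSetAttrFunction` does not take ownership of `func_op`, the temporary op appears to have been leaked on every call before this change; the added `TFE_DeleteOp(func_op)` releases it once the attribute has been set.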
@@ -1684,6 +1685,8 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
 };
 }  // namespace

+extern "C" {
+
 void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
                               const char* device_name, void* device_info,
                               TF_Status* status) {
@@ -1694,3 +1697,5 @@ void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
   status->status =
       context->RegisterCustomDevice(device_name, std::move(custom_device));
 }
+
+}  // extern "C"

@@ -515,9 +515,11 @@ typedef struct TFE_CustomDevice {
 // This API is highly experimental, and in particular is expected to change when
 // it starts supporting operations with attributes and when tf.function support
 // is added.
-void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
-                              const char* device_name, void* device_info,
-                              TF_Status* status);
+TF_CAPI_EXPORT extern void TFE_RegisterCustomDevice(TFE_Context* ctx,
+                                                    TFE_CustomDevice device,
+                                                    const char* device_name,
+                                                    void* device_info,
+                                                    TF_Status* status);

 TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx,
                                                      const char* function_name,
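`TF_CAPI_EXPORT` is what makes the declaration usable from outside the shared library: it marks the symbol for export on Windows and gives it default visibility elsewhere. A sketch of how such a macro is typically defined, modeled on TensorFlow's c_api_macros.h (treat the exact spelling as an assumption):

#if defined(_WIN32)
#ifdef TF_COMPILE_LIBRARY
#define TF_CAPI_EXPORT __declspec(dllexport)  // Building the TF DLL itself.
#else
#define TF_CAPI_EXPORT __declspec(dllimport)  // Consuming the DLL.
#endif  // TF_COMPILE_LIBRARY
#else
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
#endif  // _WIN32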
@@ -129,7 +129,45 @@ void TestRemoteExecute(bool async) {
 TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
 TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }

-void TestRemoteExecuteSilentCopies(bool async, bool remote) {
+string MatMulFunction() {
+  tensorflow::FunctionDef def;
+  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
+      "    signature {"
+      "      name: 'MatMulFunction'"
+      "      input_arg {"
+      "        name: 'a'"
+      "        type: DT_FLOAT"
+      "      }"
+      "      input_arg {"
+      "        name: 'b'"
+      "        type: DT_FLOAT"
+      "      }"
+      "      output_arg {"
+      "        name: 'm'"
+      "        type: DT_FLOAT"
+      "      }"
+      "    }"
+      "    node_def {"
+      "      name: 'matmul'"
+      "      op: 'MatMul'"
+      "      input: 'a'"
+      "      input: 'b'"
+      "      attr {"
+      "        key: 'T'"
+      "        value {"
+      "          type: DT_FLOAT"
+      "        }"
+      "      }"
+      "    }"
+      "    ret {"
+      "      key: 'm'"
+      "      value: 'matmul:product'"
+      "    }",
+      &def));
+  return def.SerializeAsString();
+}
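The textual `FunctionDef` above declares `MatMulFunction(a, b) -> m` backed by a single MatMul node; serializing it to a string is what lets the test hand it to `TFE_ContextAddFunctionDef` below and then instantiate it by name with `TFE_NewOp`, exactly like a primitive op.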
+
+void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) {
   tensorflow::ServerDef server_def = GetServerDef(3);

   // This server def has the task index set to 0.
@@ -169,12 +207,36 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
       TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

-  // Handles are on task0 (local), and task2, but op is on task1.
-  TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task2);
+  TFE_Op* matmul = nullptr;
+  if (func) {
+    string function_def = MatMulFunction();
+    TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
+                              status);
+    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+    matmul = TFE_NewOp(ctx, "MatMulFunction", status);
+    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    TFE_OpAddInput(matmul, h0_task0, status);
+    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    TFE_OpAddInput(matmul, h1_task2, status);
+    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  } else {
+    // Handles are on task0 (local), and task2, but op is on task1.
+    matmul = MatMulOp(ctx, h0_task0, h1_task2);
+  }
   if (remote) {
     TFE_OpSetDevice(matmul, task1_name, status);
     EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  } else if (!async) {
+    // Set the local device to CPU to easily validate mirroring
+    string cpu_device_name;
+    ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
+    TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status);
+    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle);
+    // The input handles should never change since they have been mirrored.
+    ASSERT_FALSE(remote_arg->HasLocalMirror(nullptr));
   }
   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

   TFE_TensorHandle* retvals[1];
   int num_retvals = 1;
@@ -182,12 +244,10 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

   // TODO(gjn): Add support for waiting on async local mirrors
-  if (!async) {
+  if (!remote && !async) {
     auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle);
     tensorflow::EagerOperation* op =
         tensorflow::OperationFromInterface(matmul->operation);
     // The input handles should never change since they have been mirrored.
     ASSERT_EQ(op->Inputs()[1], remote_arg);
-    ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr));
   }

   auto* retval_task0 = TFE_TensorHandleCopyToDevice(
@@ -217,6 +277,9 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
   TFE_ExecutorWaitForAllPendingNodes(executor, status);
   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   TFE_DeleteExecutor(executor);
+  if (func) {
+    TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
+  }
   TFE_DeleteContext(ctx);

   TF_DeleteStatus(status);
@@ -227,16 +290,22 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
 }

 TEST(CAPI, RemoteExecuteSilentCopies) {
-  TestRemoteExecuteSilentCopies(false, true);
+  TestRemoteExecuteSilentCopies(false, true, false);
 }
 TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
-  TestRemoteExecuteSilentCopies(true, true);
+  TestRemoteExecuteSilentCopies(true, true, false);
 }
+TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) {
+  TestRemoteExecuteSilentCopies(true, true, true);
+}
 TEST(CAPI, RemoteExecuteSilentCopiesLocal) {
-  TestRemoteExecuteSilentCopies(false, false);
+  TestRemoteExecuteSilentCopies(false, false, false);
 }
 TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) {
-  TestRemoteExecuteSilentCopies(true, false);
+  TestRemoteExecuteSilentCopies(true, false, false);
 }
+TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) {
+  TestRemoteExecuteSilentCopies(true, false, true);
+}

 void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
@@ -78,11 +78,18 @@ void BM_Execute(int iters, int async) {
   TFE_DeleteContextOptions(opts);

   TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
-  TFE_Op* matmul = MatMulOp(ctx, m, m);
+  TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   TFE_TensorHandle* retvals[1];
   int num_retvals = 1;
   tensorflow::testing::StartTiming();
   for (int i = 0; i < iters; ++i) {
+    TFE_OpReset(matmul, "MatMul", nullptr, status);
+    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    TFE_OpAddInput(matmul, m, status);
+    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    TFE_OpAddInput(matmul, m, status);
+    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
     TFE_Execute(matmul, &retvals[0], &num_retvals, status);
     CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   }
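The benchmark now allocates the op once with `TFE_NewOp` and recycles it with `TFE_OpReset` on every iteration, re-adding inputs each time, so the loop measures execution rather than per-iteration op construction; the same pattern is applied to `BM_Execute_Identity` below.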
@@ -113,11 +120,15 @@ void BM_Execute_Identity(int iters, int async) {
   TFE_DeleteContextOptions(opts);

   TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
-  TFE_Op* identity = IdentityOp(ctx, m);
+  TFE_Op* identity = TFE_NewOp(ctx, "Identity", status);
   TFE_TensorHandle* retvals[1];
   int num_retvals = 1;
   tensorflow::testing::StartTiming();
   for (int i = 0; i < iters; ++i) {
+    TFE_OpReset(identity, "Identity", nullptr, status);
+    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    TFE_OpAddInput(identity, m, status);
+    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
     TFE_Execute(identity, &retvals[0], &num_retvals, status);
     CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   }
@@ -405,6 +416,11 @@ void TensorHandleSilentCopy(bool async,
       hcpu, ctx, gpu_device_name.c_str(), status.get());
   ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

+  auto cpu_arg = tensorflow::TensorHandleFromInterface(hcpu->handle);
+  auto gpu_arg = tensorflow::TensorHandleFromInterface(hgpu->handle);
+  auto gpu_device = absl::get<tensorflow::Device*>(gpu_arg->device());
+  ASSERT_FALSE(cpu_arg->HasLocalMirror(gpu_device));
+
   TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
   if (cpu_op) {
     string cpu_device_name;
@@ -420,15 +436,8 @@ void TensorHandleSilentCopy(bool async,
   TFE_Execute(matmul, &retvals[0], &num_retvals, status.get());
   ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

-  // Validate if the input was replaced with a different TensorHandle
-  auto arg0 = tensorflow::TensorHandleFromInterface(hcpu->handle);
-  auto arg1 = tensorflow::TensorHandleFromInterface(hgpu->handle);
-  tensorflow::EagerOperation* op =
-      tensorflow::OperationFromInterface(matmul->operation);
-
-  // The input handles should never change since they have been mirrored.
-  EXPECT_EQ(op->Inputs()[0], arg0);
-  EXPECT_EQ(op->Inputs()[1], arg1);
+  // The CPU handle should have been copied and have a mirror on the GPU
+  ASSERT_TRUE(cpu_arg->HasLocalMirror(gpu_device));

   TFE_DeleteOp(matmul);
   TFE_DeleteTensorHandle(retvals[0]);
@@ -626,17 +635,6 @@ void ExecuteAdd(bool async, bool forward_input, bool tfrt) {
   }

   int num_retvals = 1;
-
-  if (async) {
-    // Enqueue dummy ops so we backlog async execution & actually test async.
-    for (int i = 0; i < 10000; ++i) {
-      TFE_TensorHandle* dummy = nullptr;
-      TFE_Execute(add_op, &dummy, &num_retvals, status);
-      ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-      TFE_DeleteTensorHandle(dummy);
-    }
-  }

   TFE_TensorHandle* retval = nullptr;
   TFE_Execute(add_op, &retval, &num_retvals, status);
   EXPECT_EQ(1, num_retvals);
@@ -38,96 +38,159 @@ typedef void (*ExecuteOperation)(TF_AbstractOp* op, int num_inputs,
-                                 TF_AbstractTensor* const* inputs,
-                                 TF_OutputList* o, TF_ExecutionContext* ctx,
-                                 TF_Status* s);
-
 struct TF_ExecutionContext {
-  explicit TF_ExecutionContext() {}
-  absl::variant<TFE_Context*, TF_GraphContext*> ctx;
-  ExecuteOperation execution_callback;
+  // Needed to implement our own version of RTTI since dynamic_cast is not
+  // supported in mobile builds.
+  enum ExecutionContextKind { GraphContext, EagerContext };
+  explicit TF_ExecutionContext(ExecutionContextKind kind) : k(kind) {}
+  ExecutionContextKind getKind() const { return k; }
+
+  virtual void ExecuteOperation(TF_AbstractOp* op, int num_inputs,
+                                TF_AbstractTensor* const* inputs,
+                                TF_OutputList* o, TF_Status* s) = 0;
+  virtual TF_AbstractOp* CreateOperation() = 0;
+  virtual void RegisterFunction(TF_AbstractFunction* func, TF_Status* s) = 0;
+  virtual ~TF_ExecutionContext() {}
+
+ private:
+  const ExecutionContextKind k;
 };

-struct TF_AbstractTensor {
-  absl::variant<TFE_TensorHandle*, TF_GraphTensor*> t;
-};
-
-struct TF_AbstractOp {
-  string op_type;
-  string op_name;
-};
-
-TF_ExecutionContext* TF_NewExecutionContext() {
-  return new TF_ExecutionContext();
-}
-
 void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete c; }

-TF_AbstractOp* TF_NewAbstractOp() {
-  TF_AbstractOp* op = new TF_AbstractOp;
-  return op;
-}
+template <typename T, typename S>
+T* dynamic_cast_helper(S source) {
+  if (source->getKind() != T::kKind) {
+    return nullptr;
+  }
+  return tensorflow::down_cast<T*>(source);
+}
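`dynamic_cast_helper` plus the per-class `kKind` constant is the usual LLVM-style substitute for RTTI. A minimal self-contained sketch of the same pattern (all names below are illustrative, not from the diff):

#include <cassert>

class Shape {
 public:
  // Each concrete subclass stamps its kind into the base at construction.
  enum Kind { kCircle, kSquare };
  explicit Shape(Kind kind) : kind_(kind) {}
  Kind getKind() const { return kind_; }
  virtual ~Shape() = default;

 private:
  const Kind kind_;
};

class Circle : public Shape {
 public:
  Circle() : Shape(kKind) {}
  static constexpr Kind kKind = kCircle;  // Consulted by kind_cast below.
};

// Mirrors dynamic_cast_helper: compare kind tags, then do an unchecked
// downcast (tensorflow::down_cast is a static_cast with a debug-mode check).
template <typename T, typename S>
T* kind_cast(S* source) {
  if (source->getKind() != T::kKind) return nullptr;
  return static_cast<T*>(source);
}

int main() {
  Circle c;
  Shape* s = &c;
  assert(kind_cast<Circle>(s) != nullptr);  // Tag matches, cast succeeds.
  return 0;
}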

 void TF_DeleteAbstractOp(TF_AbstractOp* op) { delete op; }

-TF_AbstractTensor* TF_NewAbstractTensor() {
-  TF_AbstractTensor* t = new TF_AbstractTensor;
-  return t;
-}
-
-void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { delete t; }
-
-struct TF_GraphContext {
-  TF_Graph* graph;
-  // TODO(srbs): Handle captures.
-};
-
-TF_GraphContext* TF_NewGraphContext(TF_Graph* g) {
-  auto ctx = new TF_GraphContext;
-  ctx->graph = g;
-  return ctx;
-}
-
-void TF_DeleteGraphContext(TF_GraphContext* ctx) { delete ctx; }
+class TF_GraphContext;
+class TF_EagerContext;

 struct TF_GraphTensor {
   TF_Output output;
   TF_GraphContext* ctx;
 };
-TF_GraphTensor* TF_NewGraphTensor(TF_GraphContext* ctx, TF_Output output,
-                                  TF_Status* s) {
-  TF_GraphTensor* t = new TF_GraphTensor;
-  t->output = output;
-  t->ctx = ctx;
-  return t;
-}
-TF_Output TF_GraphTensorToOutput(const TF_GraphTensor* const t, TF_Status* s) {
-  return t->output;
-}
-void TF_DeleteGraphTensor(TF_GraphTensor* t) { delete t; }
-void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t,
-                                     TF_Status* s) {
-  at->t = t;
-}
-TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
-                                                  TF_Status* s) {
-  if (!absl::holds_alternative<TFE_TensorHandle*>(at->t)) {
-    string msg = absl::StrCat("Not an eager tensor handle.",
-                              reinterpret_cast<uintptr_t>(at));
-    TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
-    return nullptr;
-  }
-  return absl::get<TFE_TensorHandle*>(at->t);
-}

+struct TF_AbstractTensor {
+  absl::variant<TFE_TensorHandle*, TF_GraphTensor*> t;
+
+  ~TF_AbstractTensor() {
+    if (absl::holds_alternative<TFE_TensorHandle*>(t)) {
+      TFE_DeleteTensorHandle(absl::get<TFE_TensorHandle*>(t));
+    } else if (absl::holds_alternative<TF_GraphTensor*>(t)) {
+      delete absl::get<TF_GraphTensor*>(t);
+    }
+  }
+};
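Note the ownership change: `TF_AbstractTensor` now deletes whichever payload it holds, which is why the header below annotates `TF_AbstractTensorSetEagerTensor` with "`at` takes ownership of `t`" and why `TestBasicEager` no longer calls `TFE_DeleteTensorHandle(t)` itself.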

+struct TF_AbstractOp {
+  // Needed to implement our own version of RTTI since dynamic_cast is not
+  // supported in mobile builds.
+  enum AbstractOpKind { GraphOp, EagerOp };
+  explicit TF_AbstractOp(AbstractOpKind kind) : k(kind) {}
+  AbstractOpKind getKind() const { return k; }
+  virtual void SetOpType(const char* const op_type, TF_Status* s) = 0;
+  virtual void SetOpName(const char* const op_name, TF_Status* s) = 0;
+  virtual void SetAttrType(const char* const attr_name, TF_DataType value,
+                           TF_Status* s) = 0;
+  virtual ~TF_AbstractOp() {}
+
+ private:
+  const AbstractOpKind k;
+};
+
+TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
+  return c->CreateOperation();
+}

-void TF_AbstractTensorSetGraphTensor(TF_AbstractTensor* at, TF_GraphTensor* t,
-                                     TF_Status* s) {
-  at->t = t;
-}
-TF_GraphTensor* TF_AbstractTensorGetGraphTensor(TF_AbstractTensor* at,
-                                                TF_Status* s) {
-  if (!absl::holds_alternative<TF_GraphTensor*>(at->t)) {
-    string msg = absl::StrCat("Not an graph tensor handle.");
-    TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
-    return nullptr;
-  }
-  return absl::get<TF_GraphTensor*>(at->t);
-}

+class TF_GraphOp : public TF_AbstractOp {
+ public:
+  explicit TF_GraphOp(TF_Graph* g) : TF_AbstractOp(kKind), g_(g) {}
+  void SetOpType(const char* const op_type, TF_Status* s) override {
+    if (op_) {
+      TF_SetStatus(
+          s, TF_FAILED_PRECONDITION,
+          absl::StrCat("SetOpType called on already built op.").c_str());
+      return;
+    }
+    if (op_name_ != nullptr) {
+      op_.reset(TF_NewOperation(g_, op_type, op_name_));
+      op_name_ = nullptr;
+    } else {
+      op_type_ = op_type;
+    }
+  }
+  void SetOpName(const char* const op_name, TF_Status* s) override {
+    if (op_) {
+      TF_SetStatus(
+          s, TF_FAILED_PRECONDITION,
+          absl::StrCat("SetOpName called on already built op.").c_str());
+      return;
+    }
+    if (op_type_ != nullptr) {
+      op_.reset(TF_NewOperation(g_, op_type_, op_name));
+      op_type_ = nullptr;
+    } else {
+      op_name_ = op_name;
+    }
+  }
+  void SetAttrType(const char* const attr_name, TF_DataType value,
+                   TF_Status* s) override {
+    if (!op_) {
+      TF_SetStatus(
+          s, TF_FAILED_PRECONDITION,
+          "op_type and op_name must be specified before specifying attrs.");
+      return;
+    }
+    TF_SetAttrType(op_.get(), attr_name, value);
+  }
+  ~TF_GraphOp() override {}
+
+  static constexpr AbstractOpKind kKind = GraphOp;
+
+ private:
+  friend class TF_GraphContext;  // For access to op_.
+  TF_Graph* g_;
+  std::unique_ptr<TF_OperationDescription> op_;
+  // Hold `op_type` and `op_name` till both are available since we need both
+  // to build a graph operation.
+  const char* op_type_ = nullptr;
+  const char* op_name_ = nullptr;
+};
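`TF_NewOperation` needs an op type and a node name at the same time, but the abstract API sets them through two separate calls, so `TF_GraphOp` buffers whichever string arrives first and only creates the `TF_OperationDescription` once the second call supplies the missing half.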

+class TF_EagerOp : public TF_AbstractOp {
+ public:
+  explicit TF_EagerOp(TFE_Context* ctx) : TF_AbstractOp(kKind), ctx_(ctx) {}
+  void SetOpType(const char* const op_type, TF_Status* s) override {
+    op_ = TFE_NewOp(ctx_, op_type, s);
+  }
+  void SetOpName(const char* const op_name, TF_Status* s) override {
+    // Name is ignored in eager mode.
+  }
+  void SetAttrType(const char* const attr_name, TF_DataType value,
+                   TF_Status* s) override {
+    if (op_ == nullptr) {
+      TF_SetStatus(s, TF_FAILED_PRECONDITION,
+                   "op_type must be specified before specifying attrs.");
+      return;
+    }
+    TFE_OpSetAttrType(op_, attr_name, value);
+  }
+
+  ~TF_EagerOp() override { TFE_DeleteOp(op_); }
+  static constexpr AbstractOpKind kKind = EagerOp;
+
+ private:
+  friend class TF_EagerContext;  // For access to op_.
+  TFE_Op* op_ = nullptr;
+  TFE_Context* ctx_;
+};

 bool IsEagerTensor(const TF_AbstractTensor* const t) {
   return absl::holds_alternative<TFE_TensorHandle*>(t->t);
 }
@@ -138,6 +201,221 @@ struct TF_OutputList {
   int expected_num_outputs = -1;
 };

+struct TF_AbstractFunction {
+  TF_Function* func = nullptr;
+
+  ~TF_AbstractFunction() { TF_DeleteFunction(func); }
+};

+class TF_EagerContext : public TF_ExecutionContext {
+ public:
+  TF_EagerContext() : TF_ExecutionContext(kKind) {}
+
+  void Build(TFE_ContextOptions* options, TF_Status* status) {
+    eager_ctx_ = TFE_NewContext(options, status);
+  }
+
+  TF_AbstractOp* CreateOperation() override {
+    // TODO(srbs): Should the lifetime of this op be tied to the context.
+    return new TF_EagerOp(eager_ctx_);
+  }
+
+  void ExecuteOperation(TF_AbstractOp* op, int num_inputs,
+                        TF_AbstractTensor* const* inputs, TF_OutputList* o,
+                        TF_Status* s) override {
+    auto* eager_op = dynamic_cast_helper<TF_EagerOp>(op);
+    if (eager_op == nullptr) {
+      TF_SetStatus(s, TF_INVALID_ARGUMENT,
+                   "Unable to cast TF_AbstractOp to TF_EagerOp.");
+      return;
+    }
+    auto* tfe_op = eager_op->op_;
+    if (TF_GetCode(s) != TF_OK) return;
+    for (int i = 0; i < num_inputs; ++i) {
+      if (!IsEagerTensor(inputs[i])) {
+        TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor.");
+        return;
+      }
+      TFE_OpAddInput(tfe_op, absl::get<TFE_TensorHandle*>(inputs[i]->t), s);
+      if (TF_GetCode(s) != TF_OK) return;
+    }
+    if (o->expected_num_outputs == -1) {
+      string msg =
+          "The number of outputs must be provided in eager mode. Use "
+          "TF_OutputListSetNumOutputs.";
+      TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
+      return;
+    }
+    tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals;
+    int num_retvals = o->expected_num_outputs;
+    retvals.resize(num_retvals);
+    TFE_Execute(tfe_op, retvals.data(), &num_retvals, s);
+    if (TF_GetCode(s) != TF_OK) {
+      return;
+    }
+    o->outputs.clear();
+    o->outputs.reserve(num_retvals);
+    for (int i = 0; i < num_retvals; ++i) {
+      auto* t = new TF_AbstractTensor();
+      t->t = retvals[i];
+      o->outputs.push_back(t);
+    }
+  }
+
+  void RegisterFunction(TF_AbstractFunction* func, TF_Status* s) override {
+    TFE_ContextAddFunction(eager_ctx_, func->func, s);
+  }
+
+  ~TF_EagerContext() override { TFE_DeleteContext(eager_ctx_); }
+
+  static constexpr ExecutionContextKind kKind = EagerContext;
+
+ private:
+  friend TFE_Context* TF_ExecutionContextGetTFEContext(
+      TF_ExecutionContext* ctx);
+  TFE_Context* eager_ctx_;
+};

+void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { delete t; }

+TF_GraphContext* GetGraphContext(TF_AbstractTensor const* t) {
+  return absl::get<TF_GraphTensor*>(t->t)->ctx;
+}

+class TF_GraphContext : public TF_ExecutionContext {
+ public:
+  TF_GraphContext()
+      : TF_ExecutionContext(kKind), graph_(new TF_Graph(), TF_DeleteGraph) {}
+
+  TF_AbstractOp* CreateOperation() override {
+    // TODO(srbs): Should the lifetime of this op be tied to the context.
+    return new TF_GraphOp(graph_.get());
+  }
+
+  void ExecuteOperation(TF_AbstractOp* op, int num_inputs,
+                        TF_AbstractTensor* const* inputs, TF_OutputList* o,
+                        TF_Status* s) override {
+    auto* graph_op = dynamic_cast_helper<TF_GraphOp>(op);
+    if (graph_op == nullptr) {
+      TF_SetStatus(s, TF_INVALID_ARGUMENT,
+                   "Unable to cast TF_AbstractOp to TF_GraphOp.");
+      return;
+    }
+    auto* tf_opdesc = graph_op->op_.release();
+    for (int i = 0; i < num_inputs; ++i) {
+      auto* input = inputs[i];
+      if (IsEagerTensor(input)) {
+        TF_SetStatus(s, TF_INVALID_ARGUMENT,
+                     "Capturing eager tensors is not supported yet.");
+        return;
+      } else {
+        if (GetGraphContext(input) != this) {
+          TF_SetStatus(
+              s, TF_INVALID_ARGUMENT,
+              "Capturing tensors from other graphs is not supported yet.");
+          return;
+        }
+        TF_AddInput(tf_opdesc, absl::get<TF_GraphTensor*>(input->t)->output);
+      }
+    }
+    auto* operation = TF_FinishOperation(tf_opdesc, s);
+    // TF_FinishOperation deletes `tf_opdesc` so clear its reference.
+    graph_op->op_ = nullptr;
+    if (TF_GetCode(s) != TF_OK) return;
+    int num_outputs = TF_OperationNumOutputs(operation);
+    o->outputs.clear();
+    o->outputs.reserve(num_outputs);
+    for (int i = 0; i < num_outputs; ++i) {
+      auto* t = new TF_AbstractTensor;
+      TF_GraphTensor* graph_t = new TF_GraphTensor;
+      graph_t->ctx = this;
+      graph_t->output = {operation, i};
+      t->t = graph_t;
+      o->outputs.push_back(t);
+    }
+  }
+
+  TF_Function* ToFunction(const char* fn_name, int num_inputs,
+                          const TF_AbstractTensor* inputs, int num_outputs,
+                          const TF_AbstractTensor* outputs,
+                          TF_Status* status) const {
+    std::vector<TF_Output> graph_inputs;
+    graph_inputs.resize(num_inputs);
+    std::vector<TF_Output> graph_outputs;
+    graph_outputs.resize(num_outputs);
+    for (int i = 0; i < num_inputs; i++) {
+      graph_inputs[i] = absl::get<TF_GraphTensor*>(inputs[i].t)->output;
+    }
+    for (int i = 0; i < num_outputs; i++) {
+      graph_outputs[i] = absl::get<TF_GraphTensor*>(outputs[i].t)->output;
+    }
+
+    return TF_GraphToFunction(graph_.get(), fn_name, 0, -1, nullptr,
+                              graph_inputs.size(), graph_inputs.data(),
+                              graph_outputs.size(), graph_outputs.data(),
+                              nullptr, nullptr, fn_name, status);
+  }
+
+  void RegisterFunction(TF_AbstractFunction* func, TF_Status* s) override {
+    TF_SetStatus(s, TF_UNIMPLEMENTED,
+                 "Registering graph functions has not been implemented yet.");
+  }
+
+  ~TF_GraphContext() override {}
+
+  static constexpr ExecutionContextKind kKind = GraphContext;
+
+ private:
+  std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
+};

+struct TF_GraphContextOptions {};
+struct TF_EagerContextOptions {
+  explicit TF_EagerContextOptions(TFE_ContextOptions* options)
+      : options(options) {}
+  TFE_ContextOptions* options;  // Not owned.
+};
+
+struct TF_ExecutionContextOptions {
+  absl::variant<TF_GraphContextOptions*, TF_EagerContextOptions*> options;
+  ~TF_ExecutionContextOptions() {
+    if (absl::holds_alternative<TF_GraphContextOptions*>(options)) {
+      delete absl::get<TF_GraphContextOptions*>(options);
+    } else if (absl::holds_alternative<TF_EagerContextOptions*>(options)) {
+      delete absl::get<TF_EagerContextOptions*>(options);
+    }
+  }
+};
+
+TF_ExecutionContextOptions* TF_NewGraphContextOptions() {
+  auto* options = new TF_ExecutionContextOptions();
+  options->options = new TF_GraphContextOptions();
+  return options;
+}
+
+void TF_DeleteExecutionContextOptions(TF_ExecutionContextOptions* options) {
+  delete options;
+}
+
+TF_ExecutionContextOptions* TF_NewEagerContextOptions(
+    TFE_ContextOptions* tfe_options) {
+  auto* options = new TF_ExecutionContextOptions();
+  options->options = new TF_EagerContextOptions(tfe_options);
+  return options;
+}
+
+TF_ExecutionContext* TF_NewExecutionContext(TF_ExecutionContextOptions* options,
+                                            TF_Status* s) {
+  if (absl::holds_alternative<TF_EagerContextOptions*>(options->options)) {
+    auto* ctx = new TF_EagerContext();
+    ctx->Build(absl::get<TF_EagerContextOptions*>(options->options)->options,
+               s);
+    return ctx;
+  } else {
+    return new TF_GraphContext();
+  }
+}
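`TF_NewExecutionContext` is now a factory: the options variant decides whether a `TF_EagerContext` or a `TF_GraphContext` is built, replacing the old two-step of creating an empty context and then calling `TF_ExecutionContextSetEagerContext` or `TF_ExecutionContextSetGraphContext` (both removed further down).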

 TF_OutputList* TF_NewOutputList() { return new TF_OutputList; }
 void TF_DeleteOutputList(TF_OutputList* o) { delete o; }
 void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs,
@@ -149,113 +427,74 @@ TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) {
   return o->outputs[i];
 }

-void ExecuteOperationEager(TF_AbstractOp* op, int num_inputs,
-                           TF_AbstractTensor* const* inputs, TF_OutputList* o,
-                           TF_ExecutionContext* ctx, TF_Status* s) {
-  auto* tfe_op =
-      TFE_NewOp(absl::get<TFE_Context*>(ctx->ctx), op->op_type.c_str(), s);
-  if (TF_GetCode(s) != TF_OK) return;
-  for (int i = 0; i < num_inputs; ++i) {
-    if (!IsEagerTensor(inputs[i])) {
-      TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor.");
-      return;
-    }
-    TFE_OpAddInput(tfe_op, absl::get<TFE_TensorHandle*>(inputs[i]->t), s);
-    if (TF_GetCode(s) != TF_OK) return;
-  }
-  if (o->expected_num_outputs == -1) {
-    string msg =
-        "The number of outputs must be provided in eager mode. Use "
-        "TF_OutputListSetNumOutputs.";
-    TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
-    return;
-  }
-  tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals;
-  int num_retvals = o->expected_num_outputs;
-  retvals.resize(num_retvals);
-  TFE_Execute(tfe_op, retvals.data(), &num_retvals, s);
-  TFE_DeleteOp(tfe_op);
-  if (TF_GetCode(s) != TF_OK) {
-    return;
-  }
-  o->outputs.clear();
-  o->outputs.reserve(num_retvals);
-  for (int i = 0; i < num_retvals; ++i) {
-    auto* t = TF_NewAbstractTensor();
-    t->t = retvals[i];
-    o->outputs.push_back(t);
-  }
-}
-
-TF_GraphContext* GetGraphContext(TF_AbstractTensor const* t) {
-  return absl::get<TF_GraphTensor*>(t->t)->ctx;
-}
-
-void ExecuteOperationGraph(TF_AbstractOp* op, int num_inputs,
-                           TF_AbstractTensor* const* inputs, TF_OutputList* o,
-                           TF_ExecutionContext* ctx, TF_Status* s) {
-  TF_GraphContext* graph_ctx = absl::get<TF_GraphContext*>(ctx->ctx);
-  TF_Graph* g = graph_ctx->graph;
-  auto* tf_opdesc =
-      TF_NewOperation(g, op->op_type.c_str(), op->op_name.c_str());
-  for (int i = 0; i < num_inputs; ++i) {
-    auto* input = inputs[i];
-    if (IsEagerTensor(input)) {
-      TF_SetStatus(s, TF_INVALID_ARGUMENT,
-                   "Capturing eager tensors is not supported yet.");
-      return;
-    } else {
-      if (GetGraphContext(input) != graph_ctx) {
-        TF_SetStatus(
-            s, TF_INVALID_ARGUMENT,
-            "Capturing tensors from other graphs is not supported yet.");
-        return;
-      }
-      TF_AddInput(tf_opdesc, absl::get<TF_GraphTensor*>(input->t)->output);
-    }
-  }
-  auto* operation = TF_FinishOperation(tf_opdesc, s);
-  if (TF_GetCode(s) != TF_OK) return;
-  int num_outputs = TF_OperationNumOutputs(operation);
-  o->outputs.clear();
-  o->outputs.reserve(num_outputs);
-  for (int i = 0; i < num_outputs; ++i) {
-    auto* t = TF_NewAbstractTensor();
-    TF_GraphTensor* output_t = TF_NewGraphTensor(graph_ctx, {operation, i}, s);
-    if (TF_GetCode(s) != TF_OK) {
-      return;
-    }
-    t->t = output_t;
-    o->outputs.push_back(t);
-  }
-}
-
-void TF_ExecutionContextSetEagerContext(TF_ExecutionContext* context,
-                                        TFE_Context* eager_context,
-                                        TF_Status* s) {
-  context->ctx = eager_context;
-  context->execution_callback = &ExecuteOperationEager;
-}
-
-void TF_ExecutionContextSetGraphContext(TF_ExecutionContext* context,
-                                        TF_GraphContext* graph_context,
-                                        TF_Status* s) {
-  context->ctx = graph_context;
-  context->execution_callback = &ExecuteOperationGraph;
-}
-
 void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
                             TF_Status* s) {
-  op->op_type = op_type;
+  op->SetOpType(op_type, s);
 }

 void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
                             TF_Status* s) {
-  op->op_name = op_name;
+  op->SetOpName(op_name, s);
 }

+void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
+                              TF_DataType value, TF_Status* s) {
+  op->SetAttrType(attr_name, value, s);
+}
+
 void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
                          TF_AbstractTensor* const* inputs, TF_OutputList* o,
                          TF_ExecutionContext* ctx, TF_Status* s) {
-  ctx->execution_callback(op, num_inputs, inputs, o, ctx, s);
+  ctx->ExecuteOperation(op, num_inputs, inputs, o, s);
 }

+TF_AbstractFunction* TF_ExecutionContextToFunction(
+    const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
+    const TF_AbstractTensor* inputs, int num_outputs,
+    const TF_AbstractTensor* outputs, TF_Status* status) {
+  auto* graph_ctx = dynamic_cast_helper<const TF_GraphContext>(fn_body);
+  if (graph_ctx == nullptr) {
+    TF_SetStatus(status, TF_INVALID_ARGUMENT,
+                 "fn_body is not a TF_GraphContext.");
+    return nullptr;
+  }
+  TF_AbstractFunction* func = new TF_AbstractFunction;
+  func->func = graph_ctx->ToFunction(fn_name, num_inputs, inputs, num_outputs,
+                                     outputs, status);
+  return func;
+}
+
+void TF_DeleteAbstractFunction(TF_AbstractFunction* func) { delete func; }
+
+void TF_ExecutionContextRegisterFunction(TF_ExecutionContext* ctx,
+                                         TF_AbstractFunction* func,
+                                         TF_Status* s) {
+  ctx->RegisterFunction(func, s);
+}
+
+// Temporary APIs till we figure out how to create scalar valued Eager
+// tensors and how to get value out of eager abstract tensors.
+TF_AbstractTensor* TF_NewAbstractTensor() {
+  TF_AbstractTensor* t = new TF_AbstractTensor;
+  return t;
+}
+
+void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t,
+                                     TF_Status* s) {
+  at->t = t;
+}
+
+TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
+                                                  TF_Status* s) {
+  if (!absl::holds_alternative<TFE_TensorHandle*>(at->t)) {
+    string msg = absl::StrCat("Not an eager tensor handle.",
+                              reinterpret_cast<uintptr_t>(at));
+    TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
+    return nullptr;
+  }
+  return absl::get<TFE_TensorHandle*>(at->t);
+}
+
+TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext* ctx) {
+  return dynamic_cast_helper<TF_EagerContext>(ctx)->eager_ctx_;
+}

@@ -15,8 +15,8 @@ limitations under the License.
 #ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_
 #define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_

+#include "tensorflow/c/c_api.h"
 #include "tensorflow/c/eager/c_api.h"
 #include "tensorflow/c/tf_status.h"

 #ifdef __cplusplus
 extern "C" {
@@ -41,32 +41,19 @@ typedef struct TF_AbstractTensor TF_AbstractTensor;
 // could contain the op type and other attributes.
 typedef struct TF_AbstractOp TF_AbstractOp;

-TF_ExecutionContext* TF_NewExecutionContext();
+// `TF_ExecutionContextOptions` define what type of `TF_ExecutionContext` is
+// created. It can be used to pass context specific params.
+typedef struct TF_ExecutionContextOptions TF_ExecutionContextOptions;
+void TF_DeleteExecutionContextOptions(TF_ExecutionContextOptions*);
+
+TF_ExecutionContext* TF_NewExecutionContext(TF_ExecutionContextOptions*,
+                                            TF_Status* s);
 void TF_DeleteExecutionContext(TF_ExecutionContext*);

-TF_AbstractOp* TF_NewAbstractOp();
+TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* ctx);
 void TF_DeleteAbstractOp(TF_AbstractOp*);

-TF_AbstractTensor* TF_NewAbstractTensor();
 void TF_DeleteAbstractTensor(TF_AbstractTensor*);

-// -----------------------------------------------------------------------------
-// APIs for Eager and graph modes
-// -----------------------------------------------------------------------------
-
-// Keeps track of the current graph and other state e.g. captures etc.
-typedef struct TF_GraphContext TF_GraphContext;
-TF_GraphContext* TF_NewGraphContext(TF_Graph*);
-void TF_DeleteGraphContext(TF_GraphContext*);
-
-// `eager_context` must outlive `context`.
-void TF_ExecutionContextSetEagerContext(TF_ExecutionContext* context,
-                                        TFE_Context* eager_context, TF_Status*);
-// `graph_context` must outlive `context`.
-void TF_ExecutionContextSetGraphContext(TF_ExecutionContext* context,
-                                        TF_GraphContext* graph_context,
-                                        TF_Status*);
-
 // TODO(srbs): Add APIs for specifying attrs etc.
 // `op_type` must outlive `op`.
 void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
@@ -74,25 +61,9 @@ void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
 // `op_name` must outlive `op`.
 void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
                             TF_Status* s);

-// Wrapper for TF_Output but contains a pointer to TF_GraphContext as well.
-typedef struct TF_GraphTensor TF_GraphTensor;
-TF_GraphTensor* TF_NewGraphTensor(TF_GraphContext* c, TF_Output t,
-                                  TF_Status* s);
-TF_Output TF_GraphTensorToOutput(const TF_GraphTensor* const t, TF_Status* s);
-void TF_DeleteGraphTensor(TF_GraphTensor* t);
-
-// `t` must outlive `at`.
-void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t,
-                                     TF_Status* s);
-TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
-                                                  TF_Status* s);
-
-// `t` must outlive `at`.
-void TF_AbstractTensorSetGraphTensor(TF_AbstractTensor* at, TF_GraphTensor* t,
-                                     TF_Status* s);
-TF_GraphTensor* TF_AbstractTensorGetGraphTensor(TF_AbstractTensor* at,
-                                                TF_Status* s);
+// `attr_name` must outlive `op`.
+void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
+                              TF_DataType value, TF_Status* s);

 // TF_OutputList just lets us not specify the number of outputs of an operation
 // beforehand. This forces a memory allocation in the runtime, which is bad, but
@@ -104,6 +75,17 @@ void TF_OutputListSetNumOutputs(TF_OutputList* o, int, TF_Status*);
 int TF_OutputListNumOutputs(TF_OutputList* o);
 TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i);

+// Stores a function representation that can be used for execution or for
+// setting functional attributes of other composite ops e.g. control flow.
+typedef struct TF_AbstractFunction TF_AbstractFunction;
+TF_AbstractFunction* TF_ExecutionContextToFunction(
+    const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
+    const TF_AbstractTensor* inputs, int num_outputs,
+    const TF_AbstractTensor* outputs, TF_Status* status);
+void TF_DeleteAbstractFunction(TF_AbstractFunction*);
+void TF_ExecutionContextRegisterFunction(TF_ExecutionContext*,
+                                         TF_AbstractFunction*, TF_Status*);

 // TF_ExecuteOperation will, if in eager mode, execute, if in graph mode, maybe
 // capture some inputs and then add a node in the graph, and after
 // execution/node creation it'll go and record things that happened in any tape
@@ -112,6 +94,23 @@ void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
                          TF_AbstractTensor* const* inputs, TF_OutputList* o,
                          TF_ExecutionContext* ctx, TF_Status* s);

+// -----------------------------------------------------------------------------
+// APIs specific to Eager and graph modes
+// -----------------------------------------------------------------------------
+
+TF_ExecutionContextOptions* TF_NewGraphContextOptions();
+TF_ExecutionContextOptions* TF_NewEagerContextOptions(TFE_ContextOptions*);
+
+// Temporary APIs till we figure out how to create scalar valued Eager
+// tensors and how to get value out of eager abstract tensors.
+TF_AbstractTensor* TF_NewAbstractTensor();
+void TF_AbstractTensorSetEagerTensor(
+    TF_AbstractTensor* at, TFE_TensorHandle* t,
+    TF_Status* s);  // `at` takes ownership of `t`.
+TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
+                                                  TF_Status* s);
+TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext*);

 #ifdef __cplusplus
 } /* end extern "C" */
 #endif
@@ -33,26 +33,25 @@ namespace tensorflow {
 namespace {

 TEST(UnifedCAPI, TestBasicEager) {
-  TF_ExecutionContext* ctx = TF_NewExecutionContext();
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   TFE_ContextOptions* opts = TFE_NewContextOptions();
-  TFE_Context* eager_ctx = TFE_NewContext(opts, status.get());
+  TF_ExecutionContextOptions* options = TF_NewEagerContextOptions(opts);
+  TF_ExecutionContext* ctx = TF_NewExecutionContext(options, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
   TFE_DeleteContextOptions(opts);

-  // Enter the eager context.
-  TF_ExecutionContextSetEagerContext(ctx, eager_ctx, status.get());
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-
   // Build an abstract input tensor.
+  TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx);
   TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f);
   TF_AbstractTensor* at = TF_NewAbstractTensor();
   TF_AbstractTensorSetEagerTensor(at, t, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

   // Build an abstract operation.
-  auto* op = TF_NewAbstractOp();
+  auto* op = TF_NewAbstractOp(ctx);
   TF_AbstractOpSetOpType(op, "Add", status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

@@ -69,7 +68,6 @@ TEST(UnifedCAPI, TestBasicEager) {
   // Clean up operation and inputs.
   TF_DeleteAbstractOp(op);
   TF_DeleteAbstractTensor(at);
-  TFE_DeleteTensorHandle(t);

   // Verify the results.
   ASSERT_EQ(1, TF_OutputListNumOutputs(o));
@@ -83,100 +81,98 @@ TEST(UnifedCAPI, TestBasicEager) {

   TF_DeleteTensor(result_tensor);
   TF_DeleteAbstractTensor(result);
   TFE_DeleteTensorHandle(result_t);
   TF_DeleteOutputList(o);
-  TFE_DeleteContext(eager_ctx);
   TF_DeleteExecutionContext(ctx);
+  TF_DeleteExecutionContextOptions(options);
 }
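Read end to end, the eager path of the unified API is: wrap a `TFE_TensorHandle` in a `TF_AbstractTensor`, create a `TF_AbstractOp` from the context, set its type, run it through `TF_ExecuteOperation` with an output list of known size, and unwrap the result with `TF_AbstractTensorGetEagerTensor`. The graph test below follows the same shape, except that executing an op appends a node to the graph instead of running a kernel.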
|
||||
|
||||
TEST(UnifedCAPI, TestBasicGraph) {
TF_ExecutionContext* ctx = TF_NewExecutionContext();
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);

// Enter a graph context.
TF_Graph* g = TF_NewGraph();
TF_GraphContext* graph_context = TF_NewGraphContext(g);
TF_ExecutionContextSetGraphContext(ctx, graph_context, status.get());
TF_ExecutionContextOptions* options = TF_NewGraphContextOptions();
TF_ExecutionContext* graph_ctx =
TF_NewExecutionContext(options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

// Add a placeholder to the graph.
auto* placeholder_op = TF_NewOperation(g, "Placeholder", "Placeholder");
TF_SetAttrType(placeholder_op, "dtype", TF_FLOAT);
auto* operation = TF_FinishOperation(placeholder_op, status.get());
auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_Output placeholder_t = {operation, 0};
TF_GraphTensor* graph_t =
TF_NewGraphTensor(graph_context, placeholder_t, status.get());
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractTensor* t = TF_NewAbstractTensor();
TF_AbstractTensorSetGraphTensor(t, graph_t, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

// Build an abstract operation.
auto* op = TF_NewAbstractOp();
TF_AbstractOpSetOpType(op, "Add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(op, "my_add", status.get());
TF_AbstractOpSetAttrType(placeholder_op, "dtype", TF_FLOAT, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {t, t};
TF_OutputList* o = TF_NewOutputList();
TF_OutputList* placeholder_outputs = TF_NewOutputList();

// Execute.
TF_ExecuteOperation(op, 2, inputs, o, ctx, status.get());
TF_ExecuteOperation(placeholder_op, 0, nullptr, placeholder_outputs,
graph_ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(1, TF_OutputListNumOutputs(placeholder_outputs));
TF_AbstractTensor* placeholder_t = TF_OutputListGet(placeholder_outputs, 0);

// Delete placeholder op.
TF_DeleteAbstractOp(placeholder_op);

// Build an abstract operation.
auto* add_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(add_op, "Add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(add_op, "my_add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {placeholder_t, placeholder_t};
TF_OutputList* add_outputs = TF_NewOutputList();

// Execute.
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractTensor* output_t = TF_OutputListGet(add_outputs, 0);

// Clean up operation and inputs.
TF_DeleteAbstractOp(op);
TF_DeleteAbstractTensor(t);
TF_DeleteGraphTensor(graph_t);
TF_DeleteAbstractOp(add_op);

TF_AbstractTensor* result = TF_OutputListGet(o, 0);
TF_GraphTensor* result_graph_tensor =
TF_AbstractTensorGetGraphTensor(result, status.get());
TF_DeleteAbstractTensor(result);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_Output result_output =
TF_GraphTensorToOutput(result_graph_tensor, status.get());
TF_DeleteGraphTensor(result_graph_tensor);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
string fn_name = "double";
TF_Function* f = TF_GraphToFunction(
g, fn_name.c_str(), 0, -1, nullptr, 1, &placeholder_t, 1, &result_output,
nullptr, nullptr, fn_name.c_str(), status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractFunction* func = TF_ExecutionContextToFunction(
graph_ctx, fn_name.c_str(), 1, placeholder_t, 1, output_t, status.get());
TF_DeleteAbstractTensor(placeholder_t);
TF_DeleteAbstractTensor(output_t);

// Build an eager context to run the function.
// Build eager context.
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* eager_ctx = TFE_NewContext(opts, status.get());
TF_ExecutionContextOptions* eager_ctx_options =
TF_NewEagerContextOptions(opts);
TF_ExecutionContext* eager_execution_ctx =
TF_NewExecutionContext(eager_ctx_options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);

// Build the abstract op to run the function.
TFE_ContextAddFunction(eager_ctx, f, status.get());
TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOp* fn_op = TF_NewAbstractOp();
// Build the abstract op to run the function.
TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx);
TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

// Build an abstract input tensor.
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f);
TF_AbstractTensor* input_t = TF_NewAbstractTensor();
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(eager_execution_ctx);
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f);
TF_AbstractTensorSetEagerTensor(input_t, input_eager, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

// Enter the eager context.
TF_ExecutionContextSetEagerContext(ctx, eager_ctx, status.get());
TF_OutputListSetNumOutputs(add_outputs, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_OutputListSetNumOutputs(o, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_ExecuteOperation(fn_op, 1, &input_t, o, ctx, status.get());
TF_ExecuteOperation(fn_op, 1, &input_t, add_outputs, eager_execution_ctx,
status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

ASSERT_EQ(1, TF_OutputListNumOutputs(o));
TF_AbstractTensor* final_result = TF_OutputListGet(o, 0);
ASSERT_EQ(1, TF_OutputListNumOutputs(add_outputs));
TF_AbstractTensor* final_result = TF_OutputListGet(add_outputs, 0);
TFE_TensorHandle* final =
TF_AbstractTensorGetEagerTensor(final_result, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
@@ -185,19 +181,201 @@ TEST(UnifedCAPI, TestBasicGraph) {
float* f_value = static_cast<float*>(TF_TensorData(f_t));
ASSERT_EQ(*f_value, 4.0);

TF_DeleteOutputList(o);
TF_DeleteOutputList(add_outputs);
TF_DeleteOutputList(placeholder_outputs);
TF_DeleteAbstractOp(fn_op);
TF_DeleteAbstractTensor(input_t);
TFE_DeleteTensorHandle(input_eager);
TF_DeleteAbstractTensor(final_result);
TFE_DeleteTensorHandle(final);
TF_DeleteTensor(f_t);
TF_DeleteFunction(f);
TF_DeleteAbstractFunction(func);

TF_DeleteExecutionContext(graph_ctx);
TF_DeleteExecutionContext(eager_execution_ctx);
TF_DeleteExecutionContextOptions(eager_ctx_options);
TF_DeleteExecutionContextOptions(options);
}

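Taken together, the updated TestBasicGraph exercises a trace-then-call pattern: ops are built and run under a graph execution context, the traced region is converted to a TF_AbstractFunction, and that function is registered with an eager execution context and invoked like any other op. A condensed sketch using only calls that appear in this diff (status checks and cleanup elided; `s`, `input`, and `outputs` stand in for the status object, input tensor, and output list built as above):

// Trace under a graph context.
TF_ExecutionContext* graph_ctx =
    TF_NewExecutionContext(TF_NewGraphContextOptions(), s);
// ... build ops with TF_NewAbstractOp(graph_ctx) and run them via
// TF_ExecuteOperation(..., graph_ctx, s), collecting placeholder_t/output_t ...
TF_AbstractFunction* func = TF_ExecutionContextToFunction(
    graph_ctx, "double", 1, placeholder_t, 1, output_t, s);

// Call eagerly.
TF_ExecutionContext* eager_ctx = TF_NewExecutionContext(
    TF_NewEagerContextOptions(TFE_NewContextOptions()), s);
TF_ExecutionContextRegisterFunction(eager_ctx, func, s);
TF_AbstractOp* call_op = TF_NewAbstractOp(eager_ctx);
TF_AbstractOpSetOpType(call_op, "double", s);  // the function name is the op type
TF_ExecuteOperation(call_op, 1, &input, outputs, eager_ctx, s);
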
TEST(UnifedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContextOptions* options = TF_NewEagerContextOptions(opts);
TF_ExecutionContext* ctx = TF_NewExecutionContext(options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);

TF_AbstractFunction* func = TF_ExecutionContextToFunction(
ctx, nullptr, 0, nullptr, 0, nullptr, status.get());
ASSERT_EQ(nullptr, func);
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));

TF_DeleteGraphContext(graph_context);
TF_DeleteGraph(g);
TFE_DeleteContext(eager_ctx);
TF_DeleteExecutionContext(ctx);
TF_DeleteExecutionContextOptions(options);
}

TEST(UnifedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContextOptions* options = TF_NewGraphContextOptions();
TF_ExecutionContext* graph_ctx =
TF_NewExecutionContext(options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

// Add a placeholder to the graph.
auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

// This should fail.
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
ASSERT_EQ(TF_FAILED_PRECONDITION, TF_GetCode(status.get()));

TF_DeleteAbstractOp(placeholder_op);
TF_DeleteExecutionContext(graph_ctx);
TF_DeleteExecutionContextOptions(options);
}

TEST(UnifedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContextOptions* options = TF_NewGraphContextOptions();
TF_ExecutionContext* graph_ctx =
TF_NewExecutionContext(options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

// Add a placeholder to the graph.
auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

// This should fail.
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
ASSERT_EQ(TF_FAILED_PRECONDITION, TF_GetCode(status.get()));

TF_DeleteAbstractOp(placeholder_op);
TF_DeleteExecutionContext(graph_ctx);
TF_DeleteExecutionContextOptions(options);
}

TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) {
// Build an Eager context.
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContextOptions* options = TF_NewEagerContextOptions(opts);
TF_ExecutionContext* ctx = TF_NewExecutionContext(options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);

ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

// Build an Eager operation.
auto* op = TF_NewAbstractOp(ctx);
TF_AbstractOpSetOpType(op, "Add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

// Build an abstract input tensor.
TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx);
TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f);
TF_AbstractTensor* at = TF_NewAbstractTensor();
TF_AbstractTensorSetEagerTensor(at, t, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {at, at};
TF_OutputList* o = TF_NewOutputList();
TF_OutputListSetNumOutputs(o, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

// Build a Graph context.
TF_ExecutionContextOptions* graph_options = TF_NewGraphContextOptions();
TF_ExecutionContext* graph_ctx =
TF_NewExecutionContext(graph_options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

// Execute eager op using graph context.
TF_ExecuteOperation(op, 2, inputs, o, graph_ctx, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));

// Clean up operation and inputs.
TF_DeleteAbstractOp(op);
TF_DeleteAbstractTensor(at);

TF_DeleteOutputList(o);
TF_DeleteExecutionContext(ctx);
TF_DeleteExecutionContextOptions(options);
TF_DeleteExecutionContext(graph_ctx);
TF_DeleteExecutionContextOptions(graph_options);
}

TEST(UnifedCAPI, TestExecutingGraphOpInEagerModeRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContextOptions* options = TF_NewGraphContextOptions();
TF_ExecutionContext* graph_ctx =
TF_NewExecutionContext(options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

// Add a placeholder to the graph.
auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetAttrType(placeholder_op, "dtype", TF_FLOAT, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

// Build inputs and outputs.
TF_OutputList* placeholder_outputs = TF_NewOutputList();

// Execute.
TF_ExecuteOperation(placeholder_op, 0, nullptr, placeholder_outputs,
graph_ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(1, TF_OutputListNumOutputs(placeholder_outputs));
TF_AbstractTensor* placeholder_t = TF_OutputListGet(placeholder_outputs, 0);

// Delete placeholder op.
TF_DeleteAbstractOp(placeholder_op);

// Build an abstract operation.
auto* add_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(add_op, "Add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(add_op, "my_add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {placeholder_t, placeholder_t};
TF_OutputList* add_outputs = TF_NewOutputList();

// Build eager context.
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContextOptions* eager_ctx_options =
TF_NewEagerContextOptions(opts);
TF_ExecutionContext* eager_execution_ctx =
TF_NewExecutionContext(eager_ctx_options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);

// Execute.
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, eager_execution_ctx,
status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));

// Clean up operation and inputs.
TF_DeleteAbstractTensor(placeholder_t);
TF_DeleteAbstractOp(add_op);
TF_DeleteOutputList(add_outputs);
TF_DeleteOutputList(placeholder_outputs);
TF_DeleteExecutionContext(graph_ctx);
TF_DeleteExecutionContext(eager_execution_ctx);
TF_DeleteExecutionContextOptions(eager_ctx_options);
TF_DeleteExecutionContextOptions(options);
}

} // namespace

@@ -16,134 +16,16 @@ limitations under the License.
// A simple logging device to test custom device registration.
#include <memory>

#include "absl/strings/match.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/custom_device_testutil.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/test.h"

namespace {

struct LoggingDevice {
tensorflow::string device_name;
tensorflow::string underlying_device;
// Set to true whenever a TensorHandle is copied onto the device
bool* arrived_flag;
// Set to true whenever an operation is executed
bool* executed_flag;
};

struct LoggedTensor {
TFE_TensorHandle* tensor;
LoggedTensor() = delete;
explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
};

void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
delete reinterpret_cast<LoggedTensor*>(data);
}

TFE_TensorHandle* MakeLoggedTensorHandle(
TFE_Context* context, const tensorflow::string& logging_device_name,
std::unique_ptr<LoggedTensor> t, TF_Status* status) {
std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
if (TF_GetCode(status) != TF_OK) return nullptr;
for (int i = 0; i < shape.size(); ++i) {
shape[i] = TFE_TensorHandleDim(t->tensor, i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
auto dtype = TFE_TensorHandleDataType(t->tensor);
return TFE_NewTensorHandleFromDeviceMemory(
context, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
}

TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status, void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
tensor, context, dev->underlying_device.c_str(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
auto dst = std::make_unique<LoggedTensor>(t);
*(dev->arrived_flag) = true;
return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst),
status);
}

TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
const char* target_device_name,
TF_Status* status,
void* device_info) {
TF_SetStatus(status, TF_INTERNAL,
"Trying to copy a tensor out of a logging device.");
return nullptr;
}

void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
TFE_TensorHandle** inputs, const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s,
void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_Op* op(TFE_NewOp(context, operation_name, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddAttrs(op, attributes);
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
for (int j = 0; j < num_inputs; ++j) {
TFE_TensorHandle* input = inputs[j];
const char* input_device = TFE_TensorHandleDeviceName(input, s);
if (TF_GetCode(s) != TF_OK) return;
if (dev->device_name == input_device) {
LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
TFE_TensorHandleDevicePointer(input, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddInput(op, t->tensor, s);
} else {
TFE_OpAddInput(op, input, s);
}
if (TF_GetCode(s) != TF_OK) return;
}
std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
TFE_Execute(op, op_outputs.data(), num_outputs, s);
TFE_DeleteOp(op);
if (TF_GetCode(s) != TF_OK) return;
std::vector<TFE_TensorHandle*> unwrapped_outputs;
for (auto* handle : op_outputs) {
unwrapped_outputs.push_back(handle);
}
for (int i = 0; i < *num_outputs; ++i) {
auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
outputs[i] = MakeLoggedTensorHandle(context, dev->device_name,
std::move(logged_tensor), s);
}
*(dev->executed_flag) = true;
}

void DeleteLoggingDevice(void* device_info) {
delete reinterpret_cast<LoggingDevice*>(device_info);
}

void RegisterLoggingDevice(TFE_Context* context, const char* name,
bool* arrived_flag, bool* executed_flag,
TF_Status* status) {
TFE_CustomDevice custom_device;
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
custom_device.delete_device = &DeleteLoggingDevice;
custom_device.execute = &LoggingDeviceExecute;
LoggingDevice* device = new LoggingDevice;
device->arrived_flag = arrived_flag;
device->executed_flag = executed_flag;
device->device_name = name;
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
TFE_RegisterCustomDevice(context, custom_device, name, device, status);
}

TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
@@ -276,9 +158,7 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
tensorflow::string(
TFE_TensorHandleBackingDeviceName(var_value, status.get())));
TFE_TensorHandle* var_value_unpacked =
reinterpret_cast<LoggedTensor*>(
TFE_TensorHandleDevicePointer(var_value, status.get()))
->tensor;
UnpackTensorHandle(var_value, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> resolved_value(
TFE_TensorHandleResolve(var_value_unpacked, status.get()),
@@ -296,7 +176,7 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}

TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
TEST(CUSTOM_DEVICE, AccessVariableOnCustomDevice) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
@@ -346,16 +226,21 @@ TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {

// Read the variable's value.
op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpAddInput(op.get(), var_handle, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
executed = false;
num_retvals = 1;
TFE_TensorHandle* var_value = nullptr;
TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
EXPECT_FALSE(TF_GetCode(status.get()) == TF_OK)
<< "Execution should fail because the variable is being used on the "
"wrong device.";
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
ASSERT_EQ(
tensorflow::string(name),
tensorflow::string(TFE_TensorHandleDeviceName(var_value, status.get())));
TFE_DeleteTensorHandle(var_value);

// Free the backing buffer for the variable.
op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
TFE_OpAddInput(op.get(), var_handle, status.get());
@@ -366,6 +251,79 @@ TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}

TEST(CUSTOM_DEVICE, InputBasedPlacement) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

const char* custom0 = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
const char* custom1 = "/job:localhost/replica:0/task:0/device:CUSTOM:1";
bool arrived = false;
bool executed = false;
RegisterLoggingDevice(context.get(), custom0, &arrived, &executed,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
RegisterLoggingDevice(context.get(), custom1, &arrived, &executed,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcpu(
TestMatrixTensorHandle(context.get()), TFE_DeleteTensorHandle);
ASSERT_FALSE(arrived);
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcustom0(
TFE_TensorHandleCopyToDevice(hcpu.get(), context.get(), custom0,
status.get()),
TFE_DeleteTensorHandle);
ASSERT_TRUE(arrived);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
arrived = false;
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcustom1(
TFE_TensorHandleCopyToDevice(hcpu.get(), context.get(), custom1,
status.get()),
TFE_DeleteTensorHandle);
ASSERT_TRUE(arrived);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

// Base case: two CPU inputs executes fine.
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> matmul(
MatMulOp(context.get(), hcpu.get(), hcpu.get()), TFE_DeleteOp);
TFE_TensorHandle* retval;
int num_retvals = 1;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_DeleteTensorHandle(retval);

// Custom device: inputs in same custom device works.
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcustom0.get()));
num_retvals = 1;
executed = false;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
TFE_DeleteTensorHandle(retval);

// Custom device: inputs in different custom devices fails.
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcustom1.get()));
num_retvals = 1;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
ASSERT_NE(TF_OK, TF_GetCode(status.get()));
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom1));

// Custom device: mix of custom/physical fails.
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get()));
num_retvals = 1;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
ASSERT_NE(TF_OK, TF_GetCode(status.get()));
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
ASSERT_TRUE(
absl::StrContains(TF_Message(status.get()), "[]"));  // kVariantDeviceNull
}

TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
@@ -394,5 +352,3 @@ TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
<< TF_Message(status.get());
}

} // namespace

172
tensorflow/c/eager/custom_device_testutil.cc
Normal file
@@ -0,0 +1,172 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// A simple logging device to test custom device registration.
#include <memory>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/test.h"

namespace {

struct LoggingDevice {
tensorflow::string device_name;
tensorflow::string underlying_device;
// Set to true whenever a TensorHandle is copied onto the device
bool* arrived_flag;
// Set to true whenever an operation is executed
bool* executed_flag;
};

struct LoggedTensor {
TFE_TensorHandle* tensor;
LoggedTensor() = delete;
explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
};

void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
delete reinterpret_cast<LoggedTensor*>(data);
}

TFE_TensorHandle* MakeLoggedTensorHandle(
TFE_Context* context, const tensorflow::string& logging_device_name,
std::unique_ptr<LoggedTensor> t, TF_Status* status) {
std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
if (TF_GetCode(status) != TF_OK) return nullptr;
for (int i = 0; i < shape.size(); ++i) {
shape[i] = TFE_TensorHandleDim(t->tensor, i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
auto dtype = TFE_TensorHandleDataType(t->tensor);
return TFE_NewTensorHandleFromDeviceMemory(
context, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
}

TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status, void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
tensor, context, dev->underlying_device.c_str(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
auto dst = std::make_unique<LoggedTensor>(t);
*(dev->arrived_flag) = true;
return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst),
status);
}

TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
const char* target_device_name,
TF_Status* status,
void* device_info) {
TF_SetStatus(status, TF_INTERNAL,
"Trying to copy a tensor out of a logging device.");
return nullptr;
}

void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
TFE_TensorHandle** inputs, const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s,
void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_Op* op(TFE_NewOp(context, operation_name, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddAttrs(op, attributes);
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
for (int j = 0; j < num_inputs; ++j) {
TFE_TensorHandle* input = inputs[j];
const char* input_device = TFE_TensorHandleDeviceName(input, s);
if (TF_GetCode(s) != TF_OK) return;
if (dev->device_name == input_device) {
LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
TFE_TensorHandleDevicePointer(input, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddInput(op, t->tensor, s);
} else {
TFE_OpAddInput(op, input, s);
}
if (TF_GetCode(s) != TF_OK) return;
}
std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
TFE_Execute(op, op_outputs.data(), num_outputs, s);
TFE_DeleteOp(op);
if (TF_GetCode(s) != TF_OK) return;
std::vector<TFE_TensorHandle*> unwrapped_outputs;
for (auto* handle : op_outputs) {
unwrapped_outputs.push_back(handle);
}
for (int i = 0; i < *num_outputs; ++i) {
auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
outputs[i] = MakeLoggedTensorHandle(context, dev->device_name,
std::move(logged_tensor), s);
}
*(dev->executed_flag) = true;
}

void DeleteLoggingDevice(void* device_info) {
delete reinterpret_cast<LoggingDevice*>(device_info);
}

} // namespace

void RegisterLoggingDevice(TFE_Context* context, const char* name,
bool* arrived_flag, bool* executed_flag,
TF_Status* status) {
TFE_CustomDevice custom_device;
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
custom_device.delete_device = &DeleteLoggingDevice;
custom_device.execute = &LoggingDeviceExecute;
LoggingDevice* device = new LoggingDevice;
device->arrived_flag = arrived_flag;
device->executed_flag = executed_flag;
device->device_name = name;
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
TFE_RegisterCustomDevice(context, custom_device, name, device, status);
}

TFE_TensorHandle* UnpackTensorHandle(TFE_TensorHandle* logged_tensor_handle,
TF_Status* status) {
return reinterpret_cast<LoggedTensor*>(
TFE_TensorHandleDevicePointer(logged_tensor_handle, status))
->tensor;
}

void AllocateLoggingDevice(const char* name, bool* arrived_flag,
bool* executed_flag, TFE_CustomDevice** device,
void** device_info) {
TFE_CustomDevice* custom_device = new TFE_CustomDevice;
custom_device->copy_tensor_to_device = &CopyToLoggingDevice;
custom_device->copy_tensor_from_device = &CopyTensorFromLoggingDevice;
custom_device->delete_device = &DeleteLoggingDevice;
custom_device->execute = &LoggingDeviceExecute;
*device = custom_device;
LoggingDevice* logging_device = new LoggingDevice;
logging_device->arrived_flag = arrived_flag;
logging_device->executed_flag = executed_flag;
logging_device->device_name = name;
logging_device->underlying_device =
"/job:localhost/replica:0/task:0/device:CPU:0";
*device_info = reinterpret_cast<void*>(logging_device);
}
36
tensorflow/c/eager/custom_device_testutil.h
Normal file
@@ -0,0 +1,36 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_
#define TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_

// A simple logging device to test custom device registration.
#include <memory>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_status.h"

void RegisterLoggingDevice(TFE_Context* context, const char* name,
bool* arrived_flag, bool* executed_flag,
TF_Status* status);
void AllocateLoggingDevice(const char* name, bool* arrived_flag,
bool* executed_flag, TFE_CustomDevice** device,
void** device_info);
TFE_TensorHandle* UnpackTensorHandle(TFE_TensorHandle* logged_tensor_handle,
TF_Status* status);

#endif // TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_
@@ -42,7 +42,28 @@ class AbstractOperationInterface {
virtual Status Reset(const char* op, const char* raw_device_name) = 0;

virtual const string& Name() const = 0;

// Returns the operation's device name.
//
// The value returned may be different from the one set by SetDeviceName, but
// it will be compatible with it: the name will be updated by device placement
// logic to refer to the specific device chosen.
//
// Example: If one calls `op->SetDeviceName("/device:GPU")`, the value
// returned by DeviceName should be "/device:GPU:*" until a particular GPU is
// chosen for the operation by the device placement logic in the
// executor. After that, the value returned by DeviceName will be a full
// device name such as "/job:localhost/replica:0/task:0/device:GPU:1".
virtual const string& DeviceName() const = 0;

// Sets the operation device name.
//
// The given `name` must be parseable by DeviceNameUtils::ParseFullName, and
// the result will be used as a constraint for device placement. See the
// documentation for DeviceName for more details.
//
// The value will override the previous value - that is, no "merging" of
// existing and given constraints will be performed.
virtual Status SetDeviceName(const char* name) = 0;

virtual Status AddInput(AbstractTensorHandleInterface* input) = 0;

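The DeviceName/SetDeviceName contract documented above is easiest to read as a sequence. A sketch against this interface (`op` is any AbstractOperationInterface implementation; the concrete resolved name is machine-dependent):

op->SetDeviceName("/device:GPU");    // a placement constraint, not a final device
op->DeviceName();                    // "/device:GPU:*" before placement runs
// ... the executor's device placement logic picks a concrete device ...
op->DeviceName();                    // e.g. "/job:localhost/replica:0/task:0/device:GPU:1"
op->SetDeviceName("/device:CPU:0");  // overrides the old constraint; no merging
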
38
tensorflow/c/eager/parallel_device/BUILD
Normal file
@@ -0,0 +1,38 @@
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
)

package(
licenses = ["notice"], # Apache 2.0
)

cc_library(
name = "parallel_device",
srcs = ["parallel_device.cc"],
hdrs = ["parallel_device.h"],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:variant",
],
)

tf_cc_test(
name = "parallel_device_test",
srcs = ["parallel_device_test.cc"],
deps = [
":parallel_device",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
597
tensorflow/c/eager/parallel_device/parallel_device.cc
Normal file
@@ -0,0 +1,597 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/parallel_device/parallel_device.h"

#include <memory>

#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"

namespace tensorflow {
namespace eager {
namespace {

// Functor for making unique_ptrs slightly more ergonomic. Using
// decltype(delete_fn) in the unique_ptr's second template argument requires
// passing a function pointer to delete_fn when constructing the unique_ptr.
class TensorHandleDeleter {
public:
void operator()(TFE_TensorHandle* to_delete) const {
TFE_DeleteTensorHandle(to_delete);
}
};

using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;

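The comment above contrasts two unique_ptr deleter idioms; the difference in a short sketch (standard C++, nothing TF-specific, assuming `handle` is a TFE_TensorHandle*):

// decltype form: the deleter is a function-pointer value that must be passed in.
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> h1(
    handle, &TFE_DeleteTensorHandle);
// Functor form: the stateless deleter type is default-constructed for us.
TensorHandlePtr h2(handle);
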
class OpDeleter {
public:
void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }
};

using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;

class ExecutorDeleter {
public:
void operator()(TFE_Executor* to_delete) const {
TFE_DeleteExecutor(to_delete);
}
};

using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;

class ParallelTensor;

using MaybeParallelTensorOwned =
absl::variant<std::unique_ptr<ParallelTensor>, TensorHandlePtr>;
using MaybeParallelTensorUnowned =
absl::variant<ParallelTensor*, TFE_TensorHandle*>;

// Creates a vector of `count` new executors (threads).
std::vector<ExecutorPtr> MakeExecutors(size_t count) {
std::vector<ExecutorPtr> executors;
executors.reserve(count);
for (int i = 0; i < count; ++i) {
executors.emplace_back(TFE_NewExecutor(true /* is_async */));
}
return executors;
}

// A representation of the custom device passed in and out of the TFE custom
// device APIs, providing context about the parallel device to
// ParallelDeviceExecute.
class ParallelDevice {
public:
ParallelDevice(const std::string& name,
const std::vector<std::string>& devices);

// Helper to copy a tensor handle from another device once for each component
// of the ParallelDevice.
//
// Sets a bad status and returns a nullptr if `tensor` is already on the
// ParallelDevice, or if the individual copies fail.
std::unique_ptr<ParallelTensor> CopyToParallelDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status) const;

// Takes a description of a single operation being executed on the
// ParallelDevice, and in turn runs one operation per component device with
// its corresponding inputs from the input ParallelTensors (or
// implicitly-mirrored tensors on other devices). Wraps the resulting
// per-device and per-output TFE_TensorHandles into one ParallelTensor per
// output of the original operation.
//
// `inputs` are either ParallelTensors, i.e. already on the ParallelDevice, or
// un-replicated TFE_TensorHandles on other devices. TPUReplicatedInput
// requires non-parallel tensors, and TPUReplicatedOutput requires a parallel
// tensor, but other operations will implicitly broadcast non-parallel input
// tensors across the ParallelDevice's component devices.
//
// Two special-cased operations, TPUReplicatedInput and TPUReplicatedOutput,
// pack and un-pack parallel tensors respectively. Only TPUReplicatedOutput
// causes `Execute` to return non-parallel tensors.
//
// Attributes are forwarded to executed operations unmodified.
//
// The returned optional has a value if and only if `status` evaluates to
// TF_OK.
absl::optional<std::vector<MaybeParallelTensorOwned>> Execute(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const;

// Implements the parallel case for `Execute`, where all of the outputs of the
// operation are ParallelTensors, and all inputs are either ParallelTensors or
// should be implicitly broadcast. This means the operation is not
// TPUReplicatedInput or TPUReplicatedOutput.
//
// The returned optional has a value if and only if `status` evaluates to
// TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
// if sanity checks on dtypes/metadata fail.
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ExecuteParallelOperation(TFE_Context* context,
std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name,
const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const;

const std::string& device_name() const { return device_name_; }

private:
// The name of the parallel device
// (e.g. "/job:localhost/replica:0/task:0/device:CUSTOM:0")
const std::string device_name_;
// A sequence of device names, indicating which devices replicated operations
// are forwarded to.
const std::vector<std::string> underlying_devices_;
// A sequence of TFE_Executors, one per device, for executing operations in
// parallel.
const std::vector<ExecutorPtr> executors_;
};

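For orientation, a sketch of how the two special-cased ops documented on `Execute` are meant to be driven (hypothetical names: `context`, `attributes`, `status`, non-parallel handles `cpu0_h`/`cpu1_h`, and an already-packed `parallel_tensor` are assumed; error handling elided):

ParallelDevice device("/job:localhost/replica:0/task:0/device:CUSTOM:0",
                      {"/job:localhost/replica:0/task:0/device:CPU:0",
                       "/job:localhost/replica:0/task:0/device:CPU:1"});
// Pack: one non-parallel handle per component device in, one ParallelTensor out.
std::vector<MaybeParallelTensorUnowned> pack_inputs{cpu0_h, cpu1_h};
auto packed = device.Execute(context, pack_inputs, "TPUReplicatedInput",
                             attributes, /*expected_max_outputs=*/1, status);
// Unpack: one ParallelTensor in, one non-parallel handle per component out.
std::vector<MaybeParallelTensorUnowned> unpack_inputs{parallel_tensor};
auto unpacked = device.Execute(context, unpack_inputs, "TPUReplicatedOutput",
                               attributes, /*expected_max_outputs=*/2, status);
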
// The internal representation of a TFE_TensorHandle placed on a
// ParallelDevice. Contains a tuple of tensors, one on each of the
// `underlying_devices_` of the ParallelDevice.
class ParallelTensor {
public:
// Construct a ParallelTensor from TensorHandles placed on the component
// devices of a ParallelDevice.
static std::unique_ptr<ParallelTensor> FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status);

// Helper to wrap a ParallelTensor into a TFE_TensorHandle which contains it.
static TensorHandlePtr AsTensorHandle(TFE_Context* context,
std::unique_ptr<ParallelTensor> t,
TF_Status* status);

size_t num_tensors() const { return tensors_.size(); }
TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }

private:
ParallelTensor(const ParallelDevice& device,
std::vector<TensorHandlePtr> tensors,
std::vector<int64_t> shape, const TF_DataType dtype)
: device_(device),
tensors_(std::move(tensors)),
shape_(std::move(shape)),
dtype_(dtype) {}

const ParallelDevice& device_;
const std::vector<TensorHandlePtr> tensors_;
const std::vector<int64_t> shape_;
const TF_DataType dtype_;
};

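Given this class, assembling a parallel tensor from handles that already live on the component devices looks like the following (a sketch; `device` is a ParallelDevice and `h0`/`h1` are TFE_TensorHandle*s assumed to be on its two underlying devices):

std::vector<TensorHandlePtr> components;
components.emplace_back(h0);  // TensorHandlePtr takes ownership
components.emplace_back(h1);
std::unique_ptr<ParallelTensor> packed =
    ParallelTensor::FromTensorHandles(device, std::move(components), status);
// FromTensorHandles sets a bad status if the components disagree on shape
// (TF_UNIMPLEMENTED) or dtype (TF_INTERNAL); see its definition below.
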
ParallelDevice::ParallelDevice(const std::string& name,
const std::vector<std::string>& devices)
: device_name_(name),
underlying_devices_(devices),
executors_(MakeExecutors(underlying_devices_.size())) {}

std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
const char* current_device = TFE_TensorHandleDeviceName(tensor, status);
if (device_name_ == current_device) {
std::string message(absl::StrCat(
"Tried to copy a TensorHandle to its existing device: ", device_name_));
TF_SetStatus(status, TF_INTERNAL, message.c_str());
return nullptr;
}
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (const std::string& underlying_device_name : underlying_devices_) {
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
tensor, context, underlying_device_name.c_str(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
components.emplace_back(t);
}
return ParallelTensor::FromTensorHandles(*this, std::move(components),
status);
}

absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const {
absl::optional<std::vector<MaybeParallelTensorOwned>> result;
// TODO(allenl): We should remove "TPU" from these op names at the very least,
// or consider other ways of packing/unpacking parallel tensors.
if (operation_name == std::string("TPUReplicatedInput")) {
// Special-cased operation for packing per-device tensors into one parallel
// tensor.
if (inputs.size() != underlying_devices_.size()) {
std::string message(absl::StrCat(
"The parallel device ", device_name_, " expected ",
underlying_devices_.size(), " inputs to TPUReplicatedInput, but got ",
inputs.size()));
TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
return result;
}
std::vector<TensorHandlePtr> components;
components.reserve(inputs.size());
for (int i = 0; i < inputs.size(); ++i) {
if (absl::holds_alternative<ParallelTensor*>(inputs[i])) {
std::string message(absl::StrCat(
"Expected all inputs to TPUReplicatedInput to be non-parallel "
"TensorHandles. The input ",
i,
" was a parallel tensor (already "
"placed on the parallel device)."));
TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
return result;
}
components.emplace_back(TFE_TensorHandleCopySharingTensor(
absl::get<TFE_TensorHandle*>(inputs[i]), status));
}
std::vector<MaybeParallelTensorOwned> result_content;
result_content.reserve(1);
result_content.push_back(ParallelTensor::FromTensorHandles(
*this, std::move(components), status));
if (TF_GetCode(status) != TF_OK) return result;
result.emplace(std::move(result_content));
return result;
} else if (operation_name == std::string("TPUReplicatedOutput")) {
// Special-cased operation for un-packing one parallel tensor into
// per-device tensors.
OpPtr op(TFE_NewOp(context, operation_name, status));
TFE_OpAddAttrs(op.get(), attributes);
int expected_outputs = TFE_OpGetOutputLength(op.get(), "outputs", status);
if (TF_GetCode(status) != TF_OK) return result;
if (expected_outputs != underlying_devices_.size()) {
std::string message(absl::StrCat(
"The parallel device ", device_name_, " expected ",
underlying_devices_.size(),
" outputs for TPUReplicatedOutput, but got ", expected_outputs));
TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
return result;
}
if (absl::holds_alternative<TFE_TensorHandle*>(inputs[0])) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"Expected the input to "
"TPUReplicatedOutput to be a parallel tensor (placed on the "
"parallel device).");
return result;
}
ParallelTensor* t = absl::get<ParallelTensor*>(inputs[0]);
std::vector<MaybeParallelTensorOwned> outputs;
outputs.reserve(t->num_tensors());
for (int i = 0; i < t->num_tensors(); ++i) {
TensorHandlePtr this_output(
TFE_TensorHandleCopySharingTensor(t->tensor(i), status));
outputs.emplace_back(std::move(this_output));
if (TF_GetCode(status) != TF_OK) return result;
}
result.emplace(std::move(outputs));
return result;
}
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
maybe_parallel_results(
ExecuteParallelOperation(context, std::move(inputs), operation_name,
attributes, expected_max_outputs, status));
if (!maybe_parallel_results.has_value()) return result;
std::vector<std::unique_ptr<ParallelTensor>> parallel_results(
std::move(maybe_parallel_results.value()));
std::vector<MaybeParallelTensorOwned> result_content;
result_content.reserve(parallel_results.size());
for (std::unique_ptr<ParallelTensor>& parallel_result : parallel_results) {
result_content.push_back(
MaybeParallelTensorOwned(std::move(parallel_result)));
}
result.emplace(std::move(result_content));
return result;
}

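For ordinary ops, the fall-through at the end of `Execute` lands in ExecuteParallelOperation, which broadcasts non-parallel inputs and splits parallel ones by component. An illustrative expansion for a two-device ParallelDevice (comments only; `parallel_t` and `plain_h` are placeholder names):

// inputs = {parallel_t, plain_h} expands per device_index as:
//   device 0: op(parallel_t->tensor(0), plain_h)
//   device 1: op(parallel_t->tensor(1), plain_h)  // plain_h reused: broadcast
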
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
|
||||
ParallelDevice::ExecuteParallelOperation(
|
||||
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
|
||||
const char* operation_name, const TFE_OpAttrs* attributes,
|
||||
int expected_max_outputs, TF_Status* status) const {
|
||||
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
|
||||
// Compute per-device per-output tensors
|
||||
std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
|
||||
per_device_output_tensors.reserve(underlying_devices_.size());
|
||||
// TODO(allenl): Add a TFE_ExecuteWithExecutor API so we don't have to keep
|
||||
// setting the thread-local executor like this.
|
||||
TFE_Executor* previous_executor(TFE_ContextGetExecutorForThread(context));
|
||||
auto reset_executor = gtl::MakeCleanup([context, previous_executor]() {
|
||||
TFE_ContextSetExecutorForThread(context, previous_executor);
|
||||
TFE_DeleteExecutor(previous_executor);
|
||||
});
|
||||
int first_op_output_count;
|
||||
for (int device_index = 0; device_index < underlying_devices_.size();
|
||||
++device_index) {
|
||||
TFE_Executor* executor = executors_[device_index].get();
|
||||
// Note that the `reset_executor` cleanup sets the thread's executor back to
|
||||
// the value before this function ran.
|
||||
TFE_ContextSetExecutorForThread(context, executor);
|
||||
OpPtr op(TFE_NewOp(context, operation_name, status));
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
TFE_OpSetDevice(op.get(), underlying_devices_[device_index].c_str(),
|
||||
status);
|
||||
TFE_OpAddAttrs(op.get(), attributes);
|
||||
for (int input_index = 0; input_index < inputs.size(); ++input_index) {
|
||||
if (absl::holds_alternative<TFE_TensorHandle*>(inputs[input_index])) {
|
||||
// Non-parallel tensors are implicitly broadcast, i.e. set as the input
|
||||
// to each parallel operation.
|
||||
//
|
||||
// TODO(allenl): There may be smarter ways to do this copy in some
|
||||
// cases, i.e. with a collective broadcast. We'll need to be careful
|
||||
// about things that are taken as inputs on the host or on their
|
||||
// existing device (for multi-device functions).
|
||||
TFE_OpAddInput(op.get(),
|
||||
absl::get<TFE_TensorHandle*>(inputs[input_index]),
|
||||
status);
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
} else {
|
||||
// Parallel tensors are divided between operations by device.
|
||||
TFE_OpAddInput(op.get(),
|
||||
absl::get<ParallelTensor*>(inputs[input_index])
|
||||
->tensor(device_index),
|
||||
status);
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
}
|
||||
}
|
||||
std::vector<TFE_TensorHandle*> op_outputs(expected_max_outputs);
|
||||
int real_num_outputs = expected_max_outputs;
|
||||
// For nested devices, the inner device sees the async executor we've
|
||||
// set. Inner parallel devices will just overwrite this with their own and
|
||||
// then set it back to ours before returning. This means parallel devices
|
||||
// which consist of several aliased parallel devices would hypothetically
|
||||
// deadlock if the outer parallel device ran one collective with a group
|
||||
// size equal to the total number of aliased physical devices. Currently
|
||||
// physical devices cannot participate in a single collective reduction
|
||||
// multiple times, so this would fail earlier.
|
||||
//
|
||||
// TODO(allenl): Keep a map from outer executor to list of inner executors
|
||||
// rather than a single list of executors so aliased nested parallel devices
|
||||
// don't re-use an executor.
|
||||
TFE_Execute(op.get(), op_outputs.data(), &real_num_outputs, status);
|
||||
if (device_index == 0) {
|
||||
first_op_output_count = real_num_outputs;
|
||||
} else {
|
||||
if (real_num_outputs != first_op_output_count) {
|
||||
TF_SetStatus(status, TF_INTERNAL,
|
||||
"Parallel ops produced different numbers of tensors.");
|
||||
return result;
|
||||
}
|
||||
}
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
std::vector<TensorHandlePtr> this_outputs;
|
||||
this_outputs.reserve(real_num_outputs);
|
||||
for (int output_num = 0; output_num < real_num_outputs; ++output_num) {
|
||||
this_outputs.emplace_back(op_outputs[output_num]);
|
||||
}
|
||||
per_device_output_tensors.push_back(std::move(this_outputs));
|
||||
}
|
||||
// For each output of the original operation, pack the per-device
|
||||
// TensorHandles we've computed into a single parallel TensorHandle.
|
||||
std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;
|
||||
per_device_outputs.reserve(first_op_output_count);
|
||||
for (int i = 0; i < first_op_output_count; ++i) {
|
||||
std::vector<TensorHandlePtr> components;
|
||||
components.reserve(underlying_devices_.size());
|
||||
for (int j = 0; j < underlying_devices_.size(); ++j) {
|
||||
components.push_back(std::move(per_device_output_tensors[j][i]));
|
||||
}
|
||||
per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
|
||||
*this, std::move(components), status));
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
}
|
||||
result.emplace(std::move(per_device_outputs));
|
||||
return result;
|
||||
}
|
||||
|
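// An illustrative call sketch (assumes a ParallelDevice `device`, a live
// `context`, `attributes`, a ParallelTensor* `parallel_tensor_ptr`, and a
// plain host handle `cpu_handle`; not a call that appears verbatim in this
// change). One input is split per device, the other is broadcast:
//
//   std::vector<MaybeParallelTensorUnowned> inputs;
//   inputs.emplace_back(parallel_tensor_ptr);  // split per device
//   inputs.emplace_back(cpu_handle);           // implicitly broadcast
//   auto outputs = device.ExecuteParallelOperation(
//       context, std::move(inputs), "Mul", attributes,
//       /*expected_max_outputs=*/1, status);
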
std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
    const ParallelDevice& parallel_device,
    std::vector<TensorHandlePtr> components, TF_Status* status) {
  TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
  std::vector<int64_t> shape(
      TFE_TensorHandleNumDims(components[0].get(), status));
  if (TF_GetCode(status) != TF_OK) return nullptr;
  for (int i = 0; i < shape.size(); ++i) {
    shape[i] = TFE_TensorHandleDim(components[0].get(), i, status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
  }

  // Verify that the TensorHandle's shape and dtype match all of the component
  // shapes and dtypes.
  for (TensorHandlePtr& component : components) {
    for (int i = 0; i < shape.size(); ++i) {
      int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
      if (TF_GetCode(status) != TF_OK) return nullptr;
      if (tensor_dim != shape[i]) {
        // TODO(allenl): Allow shapes to differ.
        TF_SetStatus(status, TF_UNIMPLEMENTED,
                     "Components of a ParallelTensor must currently all have "
                     "the same shape");
        return nullptr;
      }
      if (TFE_TensorHandleDataType(component.get()) != dtype) {
        TF_SetStatus(status, TF_INTERNAL,
                     "Components of a ParallelTensor must all have "
                     "the same dtype");
        return nullptr;
      }
    }
  }

  return std::unique_ptr<ParallelTensor>(new ParallelTensor(
      parallel_device, std::move(components), std::move(shape), dtype));
}

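// Concretely (see TestDifferentShapes in the test file), packing a
// two-element and a three-element vector through this validation sets
// TF_UNIMPLEMENTED rather than building a ragged ParallelTensor:
//
//   components[0]: float32 [2],  components[1]: float32 [3]
//   -> FromTensorHandles(...) == nullptr, status == TF_UNIMPLEMENTED
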
// Used as an argument to TFE_NewTensorHandleFromDeviceMemory, indicating how
// ParallelTensors wrapped in TFE_TensorHandles should be cleaned up once their
// reference counts drop to zero.
void ParallelTensorDeallocator(void* data, size_t len, void* arg) {
  delete reinterpret_cast<ParallelTensor*>(data);
}

TensorHandlePtr ParallelTensor::AsTensorHandle(
    TFE_Context* context, std::unique_ptr<ParallelTensor> t,
    TF_Status* status) {
  // The resulting TensorHandle owns an opaque pointer to "device memory",
  // which for a ParallelDevice is really a ParallelTensor. When the
  // TensorHandle is deleted, it will call ParallelTensorDeallocator to free
  // the struct.
  ParallelTensor* t_released = t.release();
  return TensorHandlePtr(TFE_NewTensorHandleFromDeviceMemory(
      context, t_released->device_.device_name().c_str(), t_released->dtype_,
      t_released->shape_.data(), t_released->shape_.size(), t_released, 1,
      &ParallelTensorDeallocator, nullptr, status));
}

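// Sketch of the resulting ownership round-trip (mirrors what
// ParallelDeviceExecute does below; `context`, `t`, and `status` assumed
// valid): the "device memory" pointer stored in the handle is the
// ParallelTensor itself, so it can be recovered without a copy:
//
//   TensorHandlePtr wrapped =
//       ParallelTensor::AsTensorHandle(context, std::move(t), status);
//   auto* unwrapped = reinterpret_cast<ParallelTensor*>(
//       TFE_TensorHandleDevicePointer(wrapped.get(), status));
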
// For TFE_CustomDevice::copy_tensor_to_device in the parallel device
// registration.
//
// Replicates a single TFE_TensorHandle, producing a TFE_TensorHandle
// containing a ParallelTensor with one copy of `tensor` for each device in
// the ParallelDevice.
//
// Since this function is used to satisfy the TFE_CustomDevice C API,
// device_info is passed in using a C-style generic. It must always be a
// ParallelDevice.
TFE_TensorHandle* CopyToParallelDevice(TFE_Context* context,
                                       TFE_TensorHandle* tensor,
                                       TF_Status* status, void* device_info) {
  ParallelDevice* dev = reinterpret_cast<ParallelDevice*>(device_info);
  std::unique_ptr<ParallelTensor> parallel_tensor(
      dev->CopyToParallelDevice(context, tensor, status));
  if (TF_GetCode(status) != TF_OK) return nullptr;
  return ParallelTensor::AsTensorHandle(context, std::move(parallel_tensor),
                                        status)
      .release();
}

// For TFE_CustomDevice::copy_tensor_from_device in the parallel device
// registration.
//
// Currently this is an error, and un-packing ParallelTensors must be performed
// explicitly by running a TPUReplicatedOutput operation on the parallel
// device.
//
// TODO(allenl): There are some use-cases that are only supported by copying to
// host at the moment (e.g. debug print on a tensor, .numpy(), etc.). We either
// need to return something here or address these use-cases one by one.
TFE_TensorHandle* CopyTensorFromParallelDevice(TFE_Context* context,
                                               TFE_TensorHandle* tensor,
                                               const char* target_device_name,
                                               TF_Status* status,
                                               void* device_info) {
  TF_SetStatus(status, TF_INTERNAL,
               "Trying to copy a tensor out of a parallel device.");
  return nullptr;
}

// For TFE_CustomDevice::execute in the parallel device registration.
//
// Since this function is used to satisfy the TFE_CustomDevice C API,
// device_info is passed in using a C-style generic. It must always be a
// ParallelDevice.
void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
                           TFE_TensorHandle** inputs,
                           const char* operation_name,
                           const TFE_OpAttrs* attributes, int* num_outputs,
                           TFE_TensorHandle** outputs, TF_Status* status,
                           void* device_info) {
  ParallelDevice* dev = reinterpret_cast<ParallelDevice*>(device_info);
  std::vector<MaybeParallelTensorUnowned> typed_inputs;
  typed_inputs.reserve(num_inputs);
  for (int i = 0; i < num_inputs; ++i) {
    const char* tensor_handle_device =
        TFE_TensorHandleDeviceName(inputs[i], status);
    if (TF_GetCode(status) != TF_OK) return;
    if (dev->device_name() == tensor_handle_device) {
      // We assume that any tensors already placed on this device are
      // ParallelTensors.
      typed_inputs.emplace_back(reinterpret_cast<ParallelTensor*>(
          TFE_TensorHandleDevicePointer(inputs[i], status)));
      if (TF_GetCode(status) != TF_OK) return;
    } else {
      typed_inputs.emplace_back(inputs[i]);
    }
  }

  absl::optional<std::vector<MaybeParallelTensorOwned>> maybe_typed_outputs(
      dev->Execute(context, std::move(typed_inputs), operation_name,
                   attributes, *num_outputs, status));
  if (TF_GetCode(status) != TF_OK) return;
  if (!maybe_typed_outputs.has_value()) {
    TF_SetStatus(status, TF_INTERNAL, "OK status but no value was returned.");
    return;
  }

  std::vector<MaybeParallelTensorOwned> typed_outputs(
      std::move(maybe_typed_outputs.value()));

  if (typed_outputs.size() > *num_outputs) {
    TF_SetStatus(status, TF_INTERNAL,
                 "The allocated output buffer was too small.");
    return;
  }

  for (int i = 0; i < typed_outputs.size(); ++i) {
    MaybeParallelTensorOwned typed_output(std::move(typed_outputs[i]));
    if (absl::holds_alternative<TensorHandlePtr>(typed_output)) {
      outputs[i] = absl::get<TensorHandlePtr>(typed_output).release();
    } else {
      outputs[i] = ParallelTensor::AsTensorHandle(
                       context,
                       std::move(absl::get<std::unique_ptr<ParallelTensor>>(
                           typed_output)),
                       status)
                       .release();
      if (TF_GetCode(status) != TF_OK) return;
    }
  }
  *num_outputs = typed_outputs.size();
}

// For TFE_CustomDevice::delete_device in the parallel device registration.
//
// Since this function is used to satisfy the TFE_CustomDevice C API,
// device_info is passed in using a C-style generic. It must always be a
// ParallelDevice.
void DeleteParallelDevice(void* device_info) {
  delete reinterpret_cast<ParallelDevice*>(device_info);
}

}  // namespace

void RegisterParallelDevice(TFE_Context* context, const char* device_name,
                            const char** underlying_devices,
                            int num_underlying_devices, TF_Status* status) {
  TFE_CustomDevice custom_device;
  custom_device.copy_tensor_to_device = &CopyToParallelDevice;
  custom_device.copy_tensor_from_device = &CopyTensorFromParallelDevice;
  custom_device.delete_device = &DeleteParallelDevice;
  custom_device.execute = &ParallelDeviceExecute;
  std::vector<std::string> underlying_devices_vector;
  underlying_devices_vector.reserve(num_underlying_devices);
  for (int device_index = 0; device_index < num_underlying_devices;
       ++device_index) {
    underlying_devices_vector.push_back(underlying_devices[device_index]);
  }
  ParallelDevice* d =
      new ParallelDevice(device_name, underlying_devices_vector);
  TFE_RegisterCustomDevice(context, custom_device, device_name, d, status);
}

}  // namespace eager
}  // namespace tensorflow
62
tensorflow/c/eager/parallel_device/parallel_device.h
Normal file
@ -0,0 +1,62 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_

#include "tensorflow/c/eager/c_api.h"

namespace tensorflow {
namespace eager {

// Register a parallel device named `device_name` which forwards operations to
// `underlying_devices`, maintaining "parallel tensors" with components placed
// on each underlying device.
//
// For example if `device_name` is
//   "/job:localhost/replica:0/task:0/device:CUSTOM:0"
// and `underlying_devices` is
//   {"/job:localhost/replica:0/task:0/device:GPU:0",
//    "/job:localhost/replica:0/task:0/device:GPU:1"}
// then executing an operation on CUSTOM:0 will execute it on GPU:0 and GPU:1.
//
// Implicit copies onto `device_name` are allowed, replicating the value once
// per device in `underlying_devices`. Implicit copies off of the device throw
// an error.
//
// All component tensors must have the same dtype. Currently they must also
// have the same shape, although this requirement may be relaxed in the future.
//
// `device_name` must not name an existing physical or custom device (see
// the documentation for TFE_RegisterCustomDevice for more information).
//
// Tensors may be copied on or off the device explicitly using
// TPUReplicatedInput and TPUReplicatedOutput respectively. For example, with
// two component devices, running `x = TPUReplicatedInput(inputs=[a, b])` on
// the parallel device creates a parallel tensor `x` with `a` on the first of
// `underlying_devices` and `b` on the second. Running `a_unpacked, b_unpacked
// = TPUReplicatedOutput(input=x, num_replicas=2)` un-packs the parallel
// tensor into its components.
//
// `context` owns the parallel device. `underlying_devices` must stay valid
// while the parallel device is in use.
void RegisterParallelDevice(TFE_Context* context, const char* device_name,
                            const char** underlying_devices,
                            int num_underlying_devices, TF_Status* status);

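// A minimal registration sketch (device names as used in the unit tests;
// `context` and `status` are assumed to already exist):
//
//   const char* underlying[] = {
//       "/job:localhost/replica:0/task:0/device:CPU:0",
//       "/job:localhost/replica:0/task:0/device:CPU:1"};
//   RegisterParallelDevice(
//       context, "/job:localhost/replica:0/task:0/device:CUSTOM:0",
//       underlying, 2, status);
//   if (TF_GetCode(status) != TF_OK) { /* handle the error */ }
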
}  // namespace eager
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_

917
tensorflow/c/eager/parallel_device/parallel_device_test.cc
Normal file
@ -0,0 +1,917 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/parallel_device/parallel_device.h"

#include <array>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/core/platform/test.h"

// NOTE(allenl): These tests currently go through TFE_Execute and so are
// integration testing rather than purely testing the parallel device. They
// correspond fairly well to the implementation, but testing the C++ directly
// is another option.

// Functor for making unique_ptr to TFE_TensorHandle slightly more
// ergonomic. Using decltype(TFE_DeleteTensorHandle) in the unique_ptr's second
// template argument requires passing a function pointer to
// TFE_DeleteTensorHandle when constructing the unique_ptr.
class TensorHandleDeleter {
 public:
  void operator()(TFE_TensorHandle* to_delete) {
    TFE_DeleteTensorHandle(to_delete);
  }
};

using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;

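// For contrast, the decltype form the comment above refers to (and which this
// file uses for the other C API types) would read:
//
//   std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
//       handle(raw_handle, TFE_DeleteTensorHandle);
//
// The functor-based alias avoids repeating the deleter at every call site.
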
// A helper for performing common operations on variables. A much more
// restricted stand-in for tf.Variable in Python.
class Variable {
 public:
  // Construct a Variable from a resource-dtype TFE_TensorHandle and an
  // indication of the dtype of the variable's value.
  //
  // Note that creating this resource-dtype handle can fail, so `Create` is a
  // separate static method which returns a status.
  Variable(TFE_TensorHandle* handle, TF_DataType type)
      : handle_(handle), type_(type) {}

  // Helper for constructing a resource handle and wrapping it in a `Variable`
  // object.
  static Variable* Create(TFE_Context* context, TF_DataType type,
                          const int64_t* dims, const int num_dims,
                          const char* device, TF_Status* status);
  // Dereferences the backing buffer for the variable. Note that since this can
  // fail (it runs operations), it must be called explicitly and the resulting
  // `status` checked.
  void Destroy(TFE_Context* context, TF_Status* status);

  // Reads from the variable.
  TensorHandlePtr Read(TFE_Context* context, TF_Status* status);
  // Assigns a new value to the variable.
  void Assign(TFE_Context* context, TFE_TensorHandle* value,
              TF_Status* status);
  // Adds `value` to the existing value of the variable.
  void AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
                 TF_Status* status);

 private:
  // Helper for running any single-argument assignment ops (Assign, AssignAdd,
  // AssignSub, ...).
  void GeneralAssignment(const char* op_name, TFE_Context* context,
                         TFE_TensorHandle* value, TF_Status* status);

  // A handle for the resource-dtype tensor pointing to the variable's
  // buffer.
  TFE_TensorHandle* handle_;
  // The dtype of the variable's buffer (input dtype for assignments, output
  // dtype of read operations).
  TF_DataType type_;
};

Variable* Variable::Create(TFE_Context* context, TF_DataType type,
                           const int64_t* dims, const int num_dims,
                           const char* device, TF_Status* status) {
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
      TFE_NewOp(context, "VarHandleOp", status), TFE_DeleteOp);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetAttrType(op.get(), "dtype", type);
  TFE_OpSetAttrShape(op.get(), "shape", dims, num_dims, status);
  TFE_OpSetAttrString(op.get(), "container", "", 0);
  // Use the special GUID for no buffer sharing
  //
  // TODO(allenl): Should we provide a better API for this? AFAIK this is the
  // only reasonable way to make variables with no aliasing using the eager C
  // API.
  std::string no_sharing = "cd2c89b7-88b7-44c8-ad83-06c2a9158347";
  TFE_OpSetAttrString(op.get(), "shared_name", no_sharing.c_str(),
                      no_sharing.length());
  TFE_OpSetDevice(op.get(), device, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_TensorHandle* var_handle = nullptr;
  int num_retvals = 1;
  TFE_Execute(op.get(), &var_handle, &num_retvals, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  return new Variable(var_handle, type);
}

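// Typical lifecycle, as exercised in BasicTestsForTwoDevices below
// (`context`, `device_name`, and `status` assumed valid; error checks
// elided for brevity):
//
//   Variable* v = Variable::Create(context, TF_FLOAT, /*dims=*/nullptr,
//                                  /*num_dims=*/0, device_name, status);
//   TensorHandlePtr value = FloatTensorHandle(20., status);
//   v->Assign(context, value.get(), status);
//   TensorHandlePtr read = v->Read(context, status);
//   v->Destroy(context, status);
//   delete v;
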
void Variable::Destroy(TFE_Context* context, TF_Status* status) {
  // Free the backing buffer for the variable.
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
      TFE_NewOp(context, "DestroyResourceOp", status), &TFE_DeleteOp);
  if (TF_GetCode(status) != TF_OK) return;
  TFE_OpAddInput(op.get(), handle_, status);
  if (TF_GetCode(status) != TF_OK) return;
  const char* device = TFE_TensorHandleDeviceName(handle_, status);
  if (TF_GetCode(status) != TF_OK) return;
  TFE_OpSetDevice(op.get(), device, status);
  if (TF_GetCode(status) != TF_OK) return;
  int num_retvals = 0;
  TFE_Execute(op.get(), nullptr, &num_retvals, status);
  if (TF_GetCode(status) != TF_OK) return;
  // Delete the variable handle itself.
  TFE_DeleteTensorHandle(handle_);
}

TensorHandlePtr Variable::Read(TFE_Context* context, TF_Status* status) {
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
      TFE_NewOp(context, "ReadVariableOp", status), &TFE_DeleteOp);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpAddInput(op.get(), handle_, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  const char* device = TFE_TensorHandleDeviceName(handle_, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetDevice(op.get(), device, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetAttrType(op.get(), "dtype", type_);
  int num_retvals = 1;
  TFE_TensorHandle* var_value = nullptr;
  TFE_Execute(op.get(), &var_value, &num_retvals, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  return TensorHandlePtr(var_value);
}

void Variable::GeneralAssignment(const char* op_name, TFE_Context* context,
                                 TFE_TensorHandle* value, TF_Status* status) {
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
      TFE_NewOp(context, op_name, status), &TFE_DeleteOp);
  if (TF_GetCode(status) != TF_OK) return;
  TFE_OpSetAttrType(op.get(), "dtype", type_);
  TFE_OpAddInput(op.get(), handle_, status);
  if (TF_GetCode(status) != TF_OK) return;
  TFE_OpAddInput(op.get(), value, status);
  if (TF_GetCode(status) != TF_OK) return;
  const char* device = TFE_TensorHandleDeviceName(handle_, status);
  if (TF_GetCode(status) != TF_OK) return;
  TFE_OpSetDevice(op.get(), device, status);

  int num_retvals = 0;
  TFE_Execute(op.get(), nullptr, &num_retvals, status);
  if (TF_GetCode(status) != TF_OK) return;
}

void Variable::AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
                         TF_Status* status) {
  GeneralAssignment("AssignAddVariableOp", context, value, status);
}

void Variable::Assign(TFE_Context* context, TFE_TensorHandle* value,
                      TF_Status* status) {
  GeneralAssignment("AssignVariableOp", context, value, status);
}

// Passed to `TF_NewTensor` to indicate how an array of floats should be
// deleted.
static void FloatDeallocator(void* data, size_t, void* arg) {
  delete[] static_cast<float*>(data);
}

// Creates a TFE_TensorHandle with value `v`.
TensorHandlePtr FloatTensorHandle(float v, TF_Status* status) {
  const int num_bytes = sizeof(float);
  float* values = new float[1];
  values[0] = v;
  std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
      TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes, &FloatDeallocator,
                   nullptr),
      TF_DeleteTensor);
  return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
}

// Creates a rank-one TFE_TensorHandle with value `v`.
TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
                                        TF_Status* status) {
  const int num_bytes = v.size() * sizeof(float);
  float* values = new float[v.size()];
  memcpy(values, v.data(), num_bytes);
  int64_t dims = v.size();
  std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
      TF_NewTensor(TF_FLOAT, &dims, 1 /* num_dims */, values, num_bytes,
                   &FloatDeallocator, nullptr),
      TF_DeleteTensor);
  return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
}

// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
template <std::size_t num_replicas>
void ExtractPerDeviceValues(
    TFE_Context* context, TFE_TensorHandle* input,
    std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status) {
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
      TFE_NewOp(context, "TPUReplicatedOutput", status), TFE_DeleteOp);
  if (TF_GetCode(status) != TF_OK) return;
  TFE_OpSetAttrInt(op.get(), "num_replicas", num_replicas);
  TFE_OpAddInput(op.get(), input, status);
  if (TF_GetCode(status) != TF_OK) return;
  const char* device = TFE_TensorHandleDeviceName(input, status);
  if (TF_GetCode(status) != TF_OK) return;
  TFE_OpSetDevice(op.get(), device, status);
  if (TF_GetCode(status) != TF_OK) return;

  TFE_TensorHandle* result_handles[num_replicas];
  int num_retvals = num_replicas;
  TFE_Execute(op.get(), result_handles, &num_retvals, status);
  if (TF_GetCode(status) != TF_OK) return;
  for (int i = 0; i < num_replicas; ++i) {
    (*components)[i].reset(result_handles[i]);
  }
}

// Helper to pack `num_replicas` TFE_TensorHandles into one parallel handle.
template <std::size_t num_replicas>
TensorHandlePtr CreatePerDeviceValues(
    TFE_Context* context,
    const std::array<TFE_TensorHandle*, num_replicas>& components,
    const char* device, TF_Status* status) {
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
      TFE_NewOp(context, "TPUReplicatedInput", status), TFE_DeleteOp);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetAttrInt(op.get(), "N", num_replicas);
  for (int i = 0; i < num_replicas; ++i) {
    TFE_OpAddInput(op.get(), components[i], status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
  }
  TFE_OpSetDevice(op.get(), device, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  TFE_TensorHandle* result_handle;
  int num_retvals = 1;
  TFE_Execute(op.get(), &result_handle, &num_retvals, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  return TensorHandlePtr(result_handle);
}

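// A short usage sketch for the two helpers above (this mirrors the tests
// below; `context`, `device_name`, `status`, and the two raw handles are
// assumed to exist):
//
//   std::array<TFE_TensorHandle*, 2> components{handle_a, handle_b};
//   TensorHandlePtr packed =
//       CreatePerDeviceValues(context, components, device_name, status);
//   std::array<TensorHandlePtr, 2> unpacked;
//   ExtractPerDeviceValues(context, packed.get(), &unpacked, status);
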
TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
                         TFE_TensorHandle* second, TF_Status* status) {
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
      TFE_NewOp(context, "Mul", status), TFE_DeleteOp);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpAddInput(op.get(), first, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpAddInput(op.get(), second, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  const char* first_device = TFE_TensorHandleDeviceName(first, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetDevice(op.get(), first_device, status);

  TFE_TensorHandle* result_handle;
  int num_retvals = 1;
  TFE_Execute(op.get(), &result_handle, &num_retvals, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  return TensorHandlePtr(result_handle);
}

// Assert that `handle` is equal to `expected_value`.
void AssertScalarFloatEq(TFE_TensorHandle* handle, float expected_value) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
      TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ASSERT_EQ(expected_value,
            *static_cast<float*>(TF_TensorData(value_zero.get())));
}

// Create and modify a variable placed on a parallel device which composes
// `first_device` and `second_device`.
void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
                             const char* second_device) {
  // Register the custom device
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 2> underlying_devices{first_device, second_device};
  tensorflow::eager::RegisterParallelDevice(
      context, device_name, underlying_devices.data(),
      underlying_devices.size(), status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Create a variable handle (uninitialized to start) placed on the parallel
  // device.
  std::function<void(Variable*)> variable_deleter = [&](Variable* to_delete) {
    to_delete->Destroy(context, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
    delete to_delete;
  };
  std::unique_ptr<Variable, decltype(variable_deleter)> variable(
      Variable::Create(context, TF_FLOAT, /* Scalar */ {}, 0, device_name,
                       status.get()),
      variable_deleter);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Assign an initial value to the variable, implicitly mirroring it to each
  // component device.
  {
    TensorHandlePtr initial_value = FloatTensorHandle(20., status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    variable->Assign(context, initial_value.get(), status.get());
  }

  // Read from the variable and verify that we have a parallel tensor.
  {
    TensorHandlePtr read = variable->Read(context, status.get());
    std::array<TensorHandlePtr, 2> components;
    ExtractPerDeviceValues(context, read.get(), &components, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    AssertScalarFloatEq(components[0].get(), 20.);
    AssertScalarFloatEq(components[1].get(), 20.);

    std::string first_device =
        TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
    ASSERT_EQ(underlying_devices[0], first_device);
    std::string second_device =
        TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
    ASSERT_EQ(underlying_devices[1], second_device);
  }

  // Add a parallel tensor with different values on each device to the
  // variable.
  {
    TensorHandlePtr value_one(FloatTensorHandle(3., status.get()));
    TensorHandlePtr value_two(FloatTensorHandle(-2., status.get()));
    std::array<TFE_TensorHandle*, 2> components{value_one.get(),
                                                value_two.get()};
    TensorHandlePtr combined_value =
        CreatePerDeviceValues(context, components, device_name, status.get());
    variable->AssignAdd(context, combined_value.get(), status.get());
  }

  // Read the variable and verify that each component has the right modified
  // value.
  {
    TensorHandlePtr read = variable->Read(context, status.get());
    std::array<TensorHandlePtr, 2> components;
    ExtractPerDeviceValues(context, read.get(), &components, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    AssertScalarFloatEq(components[0].get(), 23.);
    AssertScalarFloatEq(components[1].get(), 18.);

    std::string first_device =
        TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
    ASSERT_EQ(underlying_devices[0], first_device);
    std::string second_device =
        TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
    ASSERT_EQ(underlying_devices[1], second_device);
  }
}

TEST(PARALLEL_DEVICE, TestBasicCPU) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*xla*/ false,
          /* gpu_memory_allow_growth */ true, /* num_cpu_devices */ 2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  BasicTestsForTwoDevices(context.get(),
                          "/job:localhost/replica:0/task:0/device:CPU:0",
                          "/job:localhost/replica:0/task:0/device:CPU:1");
}

TEST(PARALLEL_DEVICE, TestBasicCPUAliased) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  BasicTestsForTwoDevices(context.get(),
                          "/job:localhost/replica:0/task:0/device:CPU:0",
                          "/job:localhost/replica:0/task:0/device:CPU:0");
}

TEST(PARALLEL_DEVICE, TestBasicTPUAliased) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Skip the test if no TPU is available.
  std::unique_ptr<TF_DeviceList, decltype(&TF_DeleteDeviceList)> devices(
      TFE_ContextListDevices(context.get(), status.get()),
      TF_DeleteDeviceList);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  bool has_tpu = false;
  for (int device_index = 0; device_index < TF_DeviceListCount(devices.get());
       ++device_index) {
    std::string device_type =
        TF_DeviceListType(devices.get(), device_index, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
    if (device_type == "TPU") {
      has_tpu = true;
      break;
    }
  }
  if (has_tpu) {
    BasicTestsForTwoDevices(context.get(),
                            "/job:localhost/replica:0/task:0/device:TPU:0",
                            "/job:localhost/replica:0/task:0/device:TPU:0");
  }
}

TEST(PARALLEL_DEVICE, TestExplicitCopies) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*xla*/ false,
          /* gpu_memory_allow_growth */ true, /* num_cpu_devices */ 2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::vector<const char*> underlying_devices;
  const char* first_device_name =
      "/job:localhost/replica:0/task:0/device:CPU:0";
  underlying_devices.push_back(first_device_name);
  const char* second_device_name =
      "/job:localhost/replica:0/task:0/device:CPU:1";
  underlying_devices.push_back(second_device_name);
  tensorflow::eager::RegisterParallelDevice(
      context.get(), device_name, underlying_devices.data(),
      underlying_devices.size(), status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  TensorHandlePtr cpu_value(FloatTensorHandle(3., status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Copying on to a parallel device is OK.
  TensorHandlePtr device_value(TFE_TensorHandleCopyToDevice(
      cpu_value.get(), context.get(), device_name, status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  const char* backing_device =
      TFE_TensorHandleBackingDeviceName(device_value.get(), status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ASSERT_EQ(std::string(device_name), backing_device);

  // Un-pack the parallel tensor to verify that the copy was successful.
  {
    std::array<TensorHandlePtr, 2> components;
    ExtractPerDeviceValues(context.get(), device_value.get(), &components,
                           status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    // The value of the original tensor is replicated on each device.
    AssertScalarFloatEq(components[0].get(), 3.);
    AssertScalarFloatEq(components[1].get(), 3.);

    // Verify that the mirrors are placed on the component devices.
    std::string first_device =
        TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
    ASSERT_EQ(underlying_devices[0], first_device);
    std::string second_device =
        TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
    ASSERT_EQ(underlying_devices[1], second_device);
  }

  // Copies off of parallel devices must be explicit.
  TensorHandlePtr copy_back(TFE_TensorHandleCopyToDevice(
      device_value.get(), context.get(), first_device_name, status.get()));
  ASSERT_EQ(TF_GetCode(status.get()), TF_INTERNAL);
}

TEST(PARALLEL_DEVICE, TestDifferentShapes) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*xla*/ false,
          /* gpu_memory_allow_growth */ true, /* num_cpu_devices */ 2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::vector<const char*> underlying_devices;
  underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
  underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
  tensorflow::eager::RegisterParallelDevice(
      context.get(), device_name, underlying_devices.data(),
      underlying_devices.size(), status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Create two vectors with different lengths
  std::vector<float> size_two_value{1., 2.};
  std::vector<float> size_three_value{1., 2., 3.};
  TensorHandlePtr size_two(
      VectorFloatTensorHandle(size_two_value, status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TensorHandlePtr size_three(
      VectorFloatTensorHandle(size_three_value, status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Try to combine these values into a single parallel tensor.
  std::array<TFE_TensorHandle*, 2> components{size_two.get(),
                                              size_three.get()};
  TensorHandlePtr combined_value = CreatePerDeviceValues(
      context.get(), components, device_name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_UNIMPLEMENTED)
      << TF_Message(status.get());
}

TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*xla*/ false,
          /* gpu_memory_allow_growth */ true, /* num_cpu_devices */ 3),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Create a parallel device with two CPUs
  const char* first_device_name =
      "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::vector<const char*> first_underlying_devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  tensorflow::eager::RegisterParallelDevice(
      context.get(), first_device_name, first_underlying_devices.data(),
      first_underlying_devices.size(), status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Create a second parallel device with the first parallel device and one
  // additional CPU.
  const char* second_device_name =
      "/job:localhost/replica:0/task:0/device:CUSTOM:1";
  std::vector<const char*> second_underlying_devices{
      "/job:localhost/replica:0/task:0/device:CUSTOM:0",
      "/job:localhost/replica:0/task:0/device:CPU:2"};
  tensorflow::eager::RegisterParallelDevice(
      context.get(), second_device_name, second_underlying_devices.data(),
      second_underlying_devices.size(), status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Create a tensor on the first parallel device
  TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
  TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  std::array<TFE_TensorHandle*, 2> components{value_one.get(),
                                              value_two.get()};
  TensorHandlePtr first_combined_value = CreatePerDeviceValues(
      context.get(), components, first_device_name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Nest the first parallel tensor into a second
  TensorHandlePtr value_three(FloatTensorHandle(3., status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  components[0] = first_combined_value.get();
  components[1] = value_three.get();
  TensorHandlePtr second_combined_value = CreatePerDeviceValues(
      context.get(), components, second_device_name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  TensorHandlePtr three(FloatTensorHandle(3., status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TensorHandlePtr multiply_result(Multiply(context.get(),
                                           second_combined_value.get(),
                                           three.get(), status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Un-pack the parallel tensor to verify that the operation was
  // successful. The resulting structure should be:
  //   second_device{first_device{1. * 3., 2. * 3.}, 3. * 3.}.
  std::array<TensorHandlePtr, 2> second_components;
  ExtractPerDeviceValues(context.get(), multiply_result.get(),
                         &second_components, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  AssertScalarFloatEq(second_components[1].get(), 9.);

  // Verify that the mirrors are placed on the component devices.
  std::string first_device = TFE_TensorHandleBackingDeviceName(
      second_components[0].get(), status.get());
  ASSERT_EQ(second_underlying_devices[0], first_device);
  std::string second_device = TFE_TensorHandleBackingDeviceName(
      second_components[1].get(), status.get());
  ASSERT_EQ(second_underlying_devices[1], second_device);

  // Un-pack the first parallel device's tensor too
  std::array<TensorHandlePtr, 2> first_components;
  ExtractPerDeviceValues(context.get(), second_components[0].get(),
                         &first_components, status.get());
  AssertScalarFloatEq(first_components[0].get(), 3.);
  AssertScalarFloatEq(first_components[1].get(), 6.);

  first_device = TFE_TensorHandleBackingDeviceName(first_components[0].get(),
                                                   status.get());
  ASSERT_EQ(first_underlying_devices[0], first_device);
  second_device = TFE_TensorHandleBackingDeviceName(first_components[1].get(),
                                                    status.get());
  ASSERT_EQ(first_underlying_devices[1], second_device);
}

TEST(PARALLEL_DEVICE, TestInvalidPacking) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::vector<const char*> underlying_devices;
  underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
  tensorflow::eager::RegisterParallelDevice(
      context.get(), device_name, underlying_devices.data(),
      underlying_devices.size(), status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
  TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
  {
    // Try to pack two TensorHandles onto a parallel device with a single
    // component.
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
    std::array<TFE_TensorHandle*, 2> components{value_one.get(),
                                                value_two.get()};
    TensorHandlePtr combined_value = CreatePerDeviceValues(
        context.get(), components, device_name, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
        << TF_Message(status.get());
  }

  {
    // Try to extract the wrong number of components from a parallel tensor
    std::array<TFE_TensorHandle*, 1> correct_components{value_one.get()};
    TensorHandlePtr combined_value = CreatePerDeviceValues(
        context.get(), correct_components, device_name, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    std::array<TensorHandlePtr, 2> incorrect_components;
    ExtractPerDeviceValues(context.get(), combined_value.get(),
                           &incorrect_components, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
        << TF_Message(status.get());
  }

  {
    // Try to pass a ParallelTensor to TPUReplicatedInput
    std::array<TFE_TensorHandle*, 1> correct_components{value_one.get()};
    TensorHandlePtr combined_value = CreatePerDeviceValues(
        context.get(), correct_components, device_name, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    std::array<TFE_TensorHandle*, 1> incorrect_components{
        combined_value.get()};
    TensorHandlePtr recombined_value = CreatePerDeviceValues(
        context.get(), incorrect_components, device_name, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
        << TF_Message(status.get());
  }

  {
    // Try to pass a non-parallel tensor to TPUReplicatedOutput
    std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
        TFE_NewOp(context.get(), "TPUReplicatedOutput", status.get()),
        TFE_DeleteOp);
    if (TF_GetCode(status.get()) != TF_OK) return;
    TFE_OpSetAttrInt(op.get(), "num_replicas", 1);
    TFE_OpAddInput(op.get(), value_one.get(), status.get());
    if (TF_GetCode(status.get()) != TF_OK) return;
    TFE_OpSetDevice(op.get(), device_name, status.get());
    if (TF_GetCode(status.get()) != TF_OK) return;

    TFE_TensorHandle* result_handles;
    int num_retvals = 1;
    TFE_Execute(op.get(), &result_handles, &num_retvals, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
        << TF_Message(status.get());
  }
}

TensorHandlePtr CollectiveSum(TFE_Context* context, TFE_TensorHandle* input,
                              int group_size, TF_Status* status) {
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
      TFE_NewOp(context, "CollectiveReduce", status), TFE_DeleteOp);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  const char* device = TFE_TensorHandleDeviceName(input, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetDevice(op.get(), device, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetAttrType(op.get(), "T", TFE_TensorHandleDataType(input));
  TFE_OpSetAttrInt(op.get(), "group_size", group_size);
  TFE_OpSetAttrInt(op.get(), "group_key", 0);
  TFE_OpSetAttrInt(op.get(), "instance_key", 0);
  const std::string merge_op("Add");
  TFE_OpSetAttrString(op.get(), "merge_op", merge_op.c_str(),
                      merge_op.length());
  const std::string final_op("Id");
  TFE_OpSetAttrString(op.get(), "final_op", final_op.c_str(),
                      final_op.length());
  TFE_OpSetAttrIntList(op.get(), "subdiv_offsets", nullptr, 0);

  TFE_OpAddInput(op.get(), input, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  TFE_TensorHandle* result_handle;
  int num_retvals = 1;
  TFE_Execute(op.get(), &result_handle, &num_retvals, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  return TensorHandlePtr(result_handle);
}

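// Note on the attributes above: `group_size` must equal the number of
// participating components (2 in TestCollective below), and all participants
// share group_key/instance_key 0. A call sketch, assuming a packed parallel
// handle `parallel_value` as in the test:
//
//   TensorHandlePtr reduced = CollectiveSum(
//       context, parallel_value.get(), /*group_size=*/2, status);
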
TEST(PARALLEL_DEVICE, TestCollective) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*xla*/ false,
          /* gpu_memory_allow_growth */ true, /* num_cpu_devices */ 2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::vector<const char*> underlying_devices;
  underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
  underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
  tensorflow::eager::RegisterParallelDevice(
      context.get(), device_name, underlying_devices.data(),
      underlying_devices.size(), status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Create a tensor on the parallel device
  TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
  TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  std::array<TFE_TensorHandle*, 2> components{value_one.get(),
                                              value_two.get()};
  TensorHandlePtr parallel_value = CreatePerDeviceValues(
      context.get(), components, device_name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Run a collective sum, so each component should now be the same.
  TensorHandlePtr reduced(
      CollectiveSum(context.get(), parallel_value.get(), 2, status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  std::array<TensorHandlePtr, 2> result_components;
  ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
                         status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  AssertScalarFloatEq(result_components[0].get(), 3.);
  AssertScalarFloatEq(result_components[1].get(), 3.);
}

void RegisterCollectiveMulFunction(TFE_Context* context,
                                   const char* function_name, int group_size,
                                   TF_Status* status) {
  std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> body(TF_NewGraph(),
                                                            TF_DeleteGraph);
  TF_OperationDescription* placeholder_desc =
      TF_NewOperation(body.get(), "Placeholder", "Placeholder");
  TF_SetAttrType(placeholder_desc, "dtype", TF_FLOAT);
  TF_Operation* placeholder_op = TF_FinishOperation(placeholder_desc, status);
  if (TF_GetCode(status) != TF_OK) return;
  TF_Output x{placeholder_op, 0};

  TF_OperationDescription* reduce_desc =
      TF_NewOperation(body.get(), "CollectiveReduce", "CollectiveReduce");
  TF_SetAttrType(reduce_desc, "T", TF_FLOAT);
  TF_SetAttrInt(reduce_desc, "group_size", group_size);
  TF_SetAttrInt(reduce_desc, "group_key", 0);
  TF_SetAttrInt(reduce_desc, "instance_key", 0);

  const std::string merge_op("Mul");
  TF_SetAttrString(reduce_desc, "merge_op", merge_op.c_str(),
                   merge_op.length());
  const std::string final_op("Id");
  TF_SetAttrString(reduce_desc, "final_op", final_op.c_str(),
                   final_op.length());
  TF_SetAttrIntList(reduce_desc, "subdiv_offsets", nullptr, 0);
  TF_AddInput(reduce_desc, x);
  TF_Operation* reduce_op = TF_FinishOperation(reduce_desc, status);
  if (TF_GetCode(status) != TF_OK) return;
  TF_Operation* operations[]{placeholder_op, reduce_op};
  TF_Output y{reduce_op, 0};
  const char* output_name = "y";
  std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)> function(
      TF_GraphToFunction(
          /* fn_body */ body.get(), /* fn_name */ function_name,
          /* append_hash_to_fn_name */ 0, /* num_opers */ 2,
          /* opers */ operations, /* ninputs */ 1, /* inputs */ &x,
          /* noutputs */ 1, /* outputs */ &y, /* output_names */ &output_name,
          /* opts */ nullptr, /* description */ "", /* status */ status),
      TF_DeleteFunction);
  if (TF_GetCode(status) != TF_OK) return;
  TFE_ContextAddFunction(context, function.get(), status);
}

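// Once registered, the function behaves like any other op on the parallel
// device (see TestFunction below): building an op by the function's name and
// placing it on the custom device runs one copy per component, and the
// embedded CollectiveReduce multiplies across them. A sketch, assuming the
// names used in the test:
//
//   std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
//       TFE_NewOp(context, "test_reduce_mul", status), TFE_DeleteOp);
//   TFE_OpSetDevice(op.get(), device_name, status);
//   TFE_OpAddInput(op.get(), parallel_value.get(), status);
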
||||
TEST(PARALLEL_DEVICE, TestFunction) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
|
||||
TF_CreateConfig(
|
||||
/*xla*/ false,
|
||||
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
|
||||
2),
|
||||
TF_DeleteBuffer);
|
||||
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
|
||||
status.get());
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
std::vector<const char*> underlying_devices;
|
||||
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
|
||||
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
|
||||
tensorflow::eager::RegisterParallelDevice(
|
||||
context.get(), device_name, underlying_devices.data(),
|
||||
underlying_devices.size(), status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
const char* function_name = "test_reduce_mul";
|
||||
RegisterCollectiveMulFunction(context.get(), function_name, 2, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
TensorHandlePtr value_one(FloatTensorHandle(7., status.get()));
|
||||
TensorHandlePtr value_two(FloatTensorHandle(9., status.get()));
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
std::array<TFE_TensorHandle*, 2> components{value_one.get(), value_two.get()};
|
||||
TensorHandlePtr parallel_value = CreatePerDeviceValues(
|
||||
context.get(), components, device_name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context.get(), function_name, status.get()), TFE_DeleteOp);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_OpSetDevice(op.get(), device_name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_OpAddInput(op.get(), parallel_value.get(), status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
TFE_TensorHandle* raw_result_handle;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op.get(), &raw_result_handle, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TensorHandlePtr reduced(raw_result_handle);
|
||||
|
||||
std::array<TensorHandlePtr, 2> result_components;
|
||||
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
AssertScalarFloatEq(result_components[0].get(), 7. * 9.);
|
||||
AssertScalarFloatEq(result_components[1].get(), 7. * 9.);
|
||||
|
||||
std::string first_device = TFE_TensorHandleBackingDeviceName(
|
||||
result_components[0].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[0], first_device);
|
||||
std::string second_device = TFE_TensorHandleBackingDeviceName(
|
||||
result_components[1].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[1], second_device);
|
||||
}
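A quick note on the expected value in the assertions above (my reading of the test, not text from the commit):

// Illustrative arithmetic: CPU:0 feeds 7. and CPU:1 feeds 9. into a
// CollectiveReduce whose merge_op is "Mul" with group_size = 2, so every
// participant ends up holding 7. * 9. = 63., which is why both result
// components compare equal to 7. * 9.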
@ -119,6 +119,9 @@ inline Tensor& TensorFromInterface(AbstractTensorInterface* tensor) {
  return down_cast<TensorInterface*>(tensor)->Tensor();
}

Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);

TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
}  // namespace tensorflow

#endif  // TENSORFLOW_C_TF_TENSOR_INTERNAL_H_

@ -156,6 +156,7 @@ cc_library(
        ":array_grad",
        ":data_flow_grad",
        ":image_grad",
        ":manip_grad",
        ":math_grad",
        ":nn_grad",
    ],
@ -494,6 +495,32 @@ tf_cc_test(
    ],
)

cc_library(
    name = "manip_grad",
    srcs = ["gradients/manip_grad.cc"],
    deps = [
        ":cc_ops",
        ":grad_op_registry",
        ":gradients",
    ],
    alwayslink = 1,
)

tf_cc_test(
    name = "gradients_manip_grad_test",
    srcs = ["gradients/manip_grad_test.cc"],
    deps = [
        ":array_ops",
        ":cc_ops",
        ":gradient_checker",
        ":manip_grad",
        ":testutil",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
    ],
)

# Generates separate libraries for array_ops and math_ops to reduce the dependency count of targets that depend on only these
tf_gen_op_wrappers_cc(
    name = "math_ops",
40
tensorflow/cc/gradients/manip_grad.cc
Normal file
@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/manip_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"

namespace tensorflow {
namespace ops {
namespace {

Status RollGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto shift = op.input(1);
  auto axis = op.input(2);
  auto grad_op = Roll(scope, grad_inputs[0], Neg(scope, shift), axis);
  grad_outputs->push_back(grad_op);
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Roll", RollGrad);

}  // namespace
}  // namespace ops
}  // namespace tensorflow
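A small worked example of the identity RollGrad relies on (illustrative only, not part of the new file): a roll is a pure permutation, so the gradient is the inverse roll of the incoming gradient.

// With x = [x0, x1, x2, x3], Roll(x, shift = 1, axis = 0) = [x3, x0, x1, x2].
// Incoming gradients dy are aligned with the rolled output, so rolling them
// back by -shift realigns them with the input:
//   dy           = [g3, g0, g1, g2]
//   Roll(dy, -1) = [g0, g1, g2, g3]
// which is exactly Roll(scope, grad_inputs[0], Neg(scope, shift), axis) above.
// The shift and axis inputs are integers, so they receive NoGradient().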
51
tensorflow/cc/gradients/manip_grad_test.cc
Normal file
@ -0,0 +1,51 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/framework/gradient_checker.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/manip_ops.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace tensorflow {
namespace {

using ops::Placeholder;
using ops::Roll;

class ManipGradTest : public ::testing::Test {
 protected:
  ManipGradTest() : scope_(Scope::NewRootScope()) {}

  void RunTest(const Output& x, const TensorShape& x_shape, const Output& y,
               const TensorShape& y_shape) {
    TF_ASSERT_OK(scope_.status());
    float max_error;
    TF_ASSERT_OK((ComputeGradientError<float, float, float>(
        scope_, {x}, {x_shape}, {y}, {y_shape}, &max_error)));
    EXPECT_LT(max_error, 1e-4);
  }

  Scope scope_;
};

TEST_F(ManipGradTest, RollGrad) {
  TensorShape shape({5, 4, 3});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  auto y = Roll(scope_, x, {2, 1}, {0, 1});
  RunTest(x, shape, y, shape);
}

}  // namespace
}  // namespace tensorflow
@ -16,6 +16,12 @@ cc_library(
    deps = ["//tensorflow/core:test_main"],
)

filegroup(
    name = "quantize_header",
    srcs = ["quantize.h"],
    visibility = ["//visibility:public"],
)

cc_library(
    name = "tfcompile_lib",
    srcs = [
@ -27,6 +33,7 @@ cc_library(
        "codegen.h",
        "compile.h",
        "flags.h",
        "quantize.h",
    ],
    defines = if_llvm_aarch64_available(["TF_LLVM_AARCH64_AVAILABLE=1"]),
    visibility = ["//tensorflow/python:__pkg__"],
@ -37,7 +44,6 @@ cc_library(
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "//tensorflow/compiler/mlir/lite/quantization/xla:quantize",
        "//tensorflow/compiler/tf2xla",
        "//tensorflow/compiler/tf2xla:mlir_tf2xla",
        "//tensorflow/compiler/tf2xla:tf2xla_proto_cc",

@ -24,7 +24,7 @@ limitations under the License.
#include "llvm-c/Target.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"
#include "tensorflow/compiler/aot/quantize.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/client/client_library.h"
@ -46,6 +46,14 @@ limitations under the License.
namespace tensorflow {
namespace tfcompile {

static llvm::ManagedStatic<QuantizeXlaFn> quantize_xla;

bool RegisterQuantizeFn(const QuantizeXlaFn& fn) {
  if (*quantize_xla) return false;
  *quantize_xla = fn;
  return true;
}

namespace {

// Compiles the XLA computation into executable code.
@ -116,9 +124,11 @@ Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
  } else {
    return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
  }
  if (flags.experimental_quantize) {
    TF_RETURN_IF_ERROR(mlir::xla_hlo::XlaQuantize(config, &computation));

  if (flags.experimental_quantize && *quantize_xla) {
    TF_RETURN_IF_ERROR((*quantize_xla)(config, &computation));
  }

  if (!flags.out_session_module.empty()) {
    TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
                        computation.Snapshot());

@ -13,21 +13,29 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_
#ifndef TENSORFLOW_COMPILER_AOT_QUANTIZE_H_
#define TENSORFLOW_COMPILER_AOT_QUANTIZE_H_

#include <functional>
#include <iostream>
#include <ostream>

#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"

namespace mlir {
namespace xla_hlo {
namespace tensorflow {
namespace tfcompile {

// Quantizes the model in the computation.
tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config,
                               xla::XlaComputation* computation);
using QuantizeXlaFn = std::function<Status(const tf2xla::Config& config,
                                           xla::XlaComputation* computation)>;

}  // namespace xla_hlo
}  // namespace mlir
// Sets the static quantization function to `fn` if it hasn't been set.
// Returns false if the static function has already been set.
bool RegisterQuantizeFn(const QuantizeXlaFn& fn);

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_
}  // namespace tfcompile
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_AOT_QUANTIZE_H_
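A minimal sketch of how a quantization backend could hook into the header above (the function name MyXlaQuantize and the registration variable are hypothetical; only QuantizeXlaFn and RegisterQuantizeFn come from the header):

#include "tensorflow/compiler/aot/quantize.h"

namespace tensorflow {
namespace tfcompile {

// Hypothetical backend entry point matching the QuantizeXlaFn signature.
static Status MyXlaQuantize(const tf2xla::Config& config,
                            xla::XlaComputation* computation) {
  // ... rewrite `computation` in place ...
  return Status::OK();
}

// Runs at static-initialization time; RegisterQuantizeFn returns false (and
// the registration is ignored) if another backend registered first.
static const bool kRegistered = RegisterQuantizeFn(MyXlaQuantize);

}  // namespace tfcompile
}  // namespace tensorflow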
@ -321,7 +321,7 @@ xla::StatusOr<NodeDef> BuildXlaHostComputeNodeDef(
      host_compute_builder.node_name());

  // Copy all attributes.
  for (auto attr : call_node->attrs()) {
  for (const auto& attr : call_node->attrs()) {
    host_compute_builder.Attr(attr.first, attr.second);
  }

@ -346,7 +346,7 @@ xla::StatusOr<NodeDef> BuildXlaHostComputeNodeDef(
  xla_token_input_nodes.emplace_back(kXlaTokenArgNodeName);
  auto cluster_deps_it = cluster_deps.find(original_oc_name);
  if (cluster_deps_it != cluster_deps.end()) {
    for (auto dep : cluster_deps_it->second) {
    for (const auto& dep : cluster_deps_it->second) {
      xla_token_input_nodes.emplace_back(host_compute_node_name(dep));
    }
  }
@ -2359,7 +2359,7 @@ Status ExtractOutsideCompilationForFunction(
  }
  // For XlaHostCompute nodes with dependencies, add control edges between
  // them so XlaCompiler can handle them in the correct order.
  for (auto iter : host_compute_nodes) {
  for (const auto& iter : host_compute_nodes) {
    Node* host_compute_node = iter.second;
    std::vector<string> token_input_node_names;
    TF_RETURN_IF_ERROR(GetNodeAttr(host_compute_node->def(),
@ -2479,7 +2479,7 @@ Status ExtractOutsideCompilation(

  TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name));

  for (auto shape_inference_graph_name : shape_inference_graphs) {
  for (const auto& shape_inference_graph_name : shape_inference_graphs) {
    TF_RETURN_IF_ERROR(RewriteShapeInferenceGraph(
        shape_inference_graph_name, g, pivot_node, fld));
  }

@ -1161,7 +1161,7 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
  std::vector<string> vall_ops = XlaOpRegistry::GetAllRegisteredOps();
  absl::flat_hash_set<string> all_ops(vall_ops.begin(), vall_ops.end());
  // Check that the user-provided TF operations really exist.
  for (auto s : whitelist) {
  for (const auto& s : whitelist) {
    if (!all_ops.contains(string(s))) {
      return errors::InvalidArgument(
          "The operation '", s,
@ -1475,7 +1475,7 @@ void MarkForCompilationPassImpl::VLogClusteringSummary() {
          << RatioToString(auto_clustering_info.clustered_node_count(),
                           graph_->num_nodes());

  for (XlaAutoClusteringSummary::Cluster cluster :
  for (const XlaAutoClusteringSummary::Cluster& cluster :
       auto_clustering_info.clusters()) {
    absl::string_view cluster_name = cluster.name();
    int size = cluster.size();
@ -1891,6 +1891,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
      "DynamicStitch",
      "Einsum",
      "EmptyTensorList",
      "EnsureShape",
      "ExtractImagePatches",
      "Igamma",
      "IgammaGradA",

@ -493,7 +493,7 @@ std::pair<string, AttrValue> impl::AttrLiteralHelper(
    const std::pair<string, absl::Span<const string>>& string_list_attr) {
  AttrValue attr_value;
  AttrValue::ListValue* list = attr_value.mutable_list();
  for (string s : string_list_attr.second) {
  for (const string& s : string_list_attr.second) {
    list->add_s(s);
  }
  return {string_list_attr.first, attr_value};

@ -296,10 +296,10 @@ Status XlaCompilationCache::CompileSingleOp(
      arg_shapes.push_back(absl::get<TensorShape>(arg.shape));
    }
    GraphDebugInfo debug_info;
    return CompileGraphToXlaHlo(*graph, {arg_shapes.data(), arg_shapes.size()},
                                compile_options.use_tuple_arg,
                                *options.flib_def, debug_info,
                                options.shape_representation_fn, result);
    return CompileGraphToXlaHlo(
        *graph, {arg_shapes.data(), arg_shapes.size()},
        options.device_type.type_string(), compile_options.use_tuple_arg,
        *options.flib_def, debug_info, options.shape_representation_fn, result);
  };
  return CompileImpl(options, name, args, compile_op,
                     /*compile_threshold=*/absl::nullopt,

@ -69,6 +69,7 @@ absl::optional<AllocatorStats> XlaDeviceAllocator::GetStats() {
  tf_stats.bytes_reserved = se_stats->bytes_reserved;
  tf_stats.peak_bytes_reserved = se_stats->peak_bytes_reserved;
  tf_stats.bytes_reservable_limit = se_stats->bytes_reservable_limit;
  tf_stats.largest_free_block_bytes = se_stats->largest_free_block_bytes;
  return tf_stats;
}

@ -479,6 +479,15 @@ Status XlaComputationLaunchContext::PopulateOutputs(
          input_output_alias, output_num, ctx, i, shape, &output,
          definition_event, stream, use_multiple_streams_));
    } else {
      auto program_shape =
          kernel->computation->GetProgramShape().ValueOrDie();
      if (program_shape.result().IsTuple() &&
          program_shape.result().tuple_shapes(output_num).IsTuple()) {
        return errors::Unimplemented(
            "Support for TensorList or Stack crossing the XLA/TF boundary "
            "is not implemented");
      }

      se::DeviceMemoryBase buffer = output.buffer({output_num});
      Tensor output_tensor = GetOrCreateTensorForOutput(
          output_num, ctx, missing_ctx_input_prefix, input_output_alias,
@ -40,11 +40,11 @@ cc_library(
    srcs = ["tf_mlir_opt_main.cc"],
    deps = [
        ":init_mlir",
        ":passes",
        "//tensorflow/core:lib",
        "//tensorflow/core/platform:logging",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:AllPassesAndDialects",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:MlirOptLib",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Support",
@ -55,9 +55,15 @@ cc_library(
cc_library(
    name = "passes",
    visibility = [
        ":__subpackages__",
        "//tensorflow/python:__subpackages__",
    ],
    deps = [
        "@llvm-project//mlir:Affine",
        "@llvm-project//mlir:QuantOps",
        # Link jit lib to link JIT devices required to run
        # xla-legalize-tf-with-tf2xla pass.
        "//tensorflow/compiler/jit",
        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
        "//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
        "//tensorflow/compiler/mlir/lite:tensorflow_lite_legalize_tf",
@ -65,33 +71,12 @@ cc_library(
        "//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_passes",
        "//tensorflow/compiler/mlir/lite/quantization/tensorflow:tf_to_quant",
        "//tensorflow/compiler/mlir/lite/quantization/xla:hlo_xla_quantization_passes",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_test_passes",
        "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
        "//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
        "//tensorflow/compiler/mlir/xla:buffer_assignment",
        "//tensorflow/compiler/mlir/xla:hlo",
        "//tensorflow/compiler/mlir/xla:hlo_legalize_to_lhlo",
        "//tensorflow/compiler/mlir/xla:lhlo",
        "//tensorflow/compiler/mlir/xla:lhlo_copy_removal",
        "//tensorflow/compiler/mlir/xla:lhlo_fuse_linalg",
        "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine",
        "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_gpu",
        "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_parallel_loops",
        "//tensorflow/compiler/mlir/xla:xla_dialect_registration",
        "//tensorflow/compiler/mlir/xla:xla_legalize_control_flow",
        "//tensorflow/compiler/mlir/xla:xla_legalize_tf",
        "//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
        "//tensorflow/compiler/mlir/xla:xla_legalize_to_linalg",
        "//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
        "//tensorflow/compiler/mlir/xla:xla_lower",
        "//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts",
        "//tensorflow/compiler/mlir/xla:xla_test_passes",
        "@llvm-project//mlir:Affine",
        "@llvm-project//mlir:QuantOps",
    ],
)

@ -139,11 +124,14 @@ cc_library(
tf_cc_binary(
    name = "tf-opt",
    deps = [
        ":passes",
        ":tf_mlir_opt_main",
        "//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
        "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass_registration",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
        "//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
        "//tensorflow/compiler/mlir/tfjs:tensorflow_js_dialect_registration",
        "//tensorflow/compiler/mlir/xla:all_xla_passes_for_testing",
    ],
)

@ -170,7 +158,6 @@ tf_cc_binary(
        "@llvm-project//mlir:AllPassesAndDialects",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TranslateClParser",
        "@llvm-project//mlir:Translation",
    ],
)

@ -55,7 +55,7 @@ gentbl(
            "ir/tfl_structs.cc.inc",
        ),
        (
            "-gen-op-doc",
            "-gen-dialect-doc",
            "g3doc/tfl_ops.md",
        ),
    ],
@ -307,7 +307,7 @@ cc_library(
        "transforms/optimize_functional_ops.cc",
        "transforms/prepare_composite_functions_tf.cc",
        "transforms/prepare_tf.cc",
        "transforms/runtime_type_verify.cc",
        "transforms/runtime_verify.cc",
        "transforms/split_merged_operands.cc",
        "transforms/trim_functions_tf.cc",
        "transforms/while_loop_outline.cc",
@ -512,7 +512,7 @@ cc_library(
    ],
    deps = [
        ":tensorflow_lite",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/core/platform:errors",
        "//tensorflow/core/platform:status",
@ -521,6 +521,7 @@ cc_library(
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/strings",
        "@flatbuffers",
        "@llvm-project//llvm:analysis",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:AllPassesAndDialects",
        "@llvm-project//mlir:IR",
@ -537,6 +538,15 @@ tf_native_cc_binary(
    ],
)

tf_native_cc_binary(
    name = "json_to_flatbuffer",
    srcs = ["json_to_flatbuffer.cc"],
    deps = [
        "//tensorflow/lite/schema:schema_fbs",
        "@flatbuffers",
    ],
)

cc_library(
    name = "emit_error_reporter",
    srcs = [
@ -552,19 +562,16 @@ cc_library(
)

cc_library(
    name = "flatbuffer_translate_lib",
    name = "flatbuffer_export",
    srcs = [
        "flatbuffer_export.cc",
        "flatbuffer_import.cc",
        "utils/convert_type.cc",
    ],
    hdrs = [
        "flatbuffer_export.h",
        "flatbuffer_export_flags.h",
        "flatbuffer_import.h",
        "utils/convert_type.h",
    ],
    deps = [
        ":convert_type",
        ":flatbuffer_tflite_operator_lib",
        ":stateful_ops_utils",
        ":tensorflow_lite",
@ -582,14 +589,12 @@ cc_library(
        "//tensorflow/core/platform:errors",
        "//tensorflow/core/platform:logging",
        "//tensorflow/core/platform:status",
        "//tensorflow/lite:framework",
        "//tensorflow/lite:schema_fbs_version",
        "//tensorflow/lite:string_util",
        "//tensorflow/lite/delegates/flex:whitelisted_flex_ops_lib",
        "//tensorflow/lite/kernels/internal:kernel_utils",
        "//tensorflow/lite/schema:schema_fbs",
        "//tensorflow/lite/tools/versioning",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
@ -604,6 +609,78 @@ cc_library(
    ],
)

cc_library(
    name = "flatbuffer_import",
    srcs = [
        "flatbuffer_import.cc",
    ],
    hdrs = [
        "flatbuffer_import.h",
    ],
    deps = [
        ":convert_type",
        ":flatbuffer_tflite_operator_lib",
        ":tensorflow_lite",
        ":tensorflow_lite_dialect_registration",
        "//tensorflow/compiler/mlir/tensorflow:mangling_util",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/platform:errors",
        "//tensorflow/core/platform:status",
        "//tensorflow/lite:framework",
        "//tensorflow/lite/schema:schema_fbs",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:Translation",
    ],
)

cc_library(
    name = "convert_type",
    srcs = [
        "utils/convert_type.cc",
    ],
    hdrs = [
        "utils/convert_type.h",
    ],
    deps = [
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/platform:errors",
        "//tensorflow/lite/schema:schema_fbs",
        "@llvm-project//mlir:IR",
    ],
)

cc_library(
    name = "flatbuffer_translate_lib",
    hdrs = [
        "flatbuffer_export.h",
        "flatbuffer_export_flags.h",
        "flatbuffer_import.h",
        "utils/convert_type.h",
    ],
    deps = [
        ":flatbuffer_export",
        ":flatbuffer_import",
        "//tensorflow/compiler/mlir:op_or_arg_name_mapper",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/lite/schema:schema_fbs",
        "@com_google_absl//absl/strings",
        "@llvm-project//mlir:IR",
    ],
)

cc_library(
    name = "flatbuffer_translate_registeration",
    srcs = [
@ -36,7 +36,8 @@ struct PassConfig {
        form_clusters(false),
        unfold_batch_matmul(true),
        legalize_tf_while(true),
        shape_inference(true) {}
        shape_inference(true),
        runtime_verification(true) {}

  // If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
  // added, which produces TF Lite ops.
@ -65,6 +66,8 @@ struct PassConfig {
  bool legalize_tf_while;
  // Whether to do shape inference.
  bool shape_inference;
  // Whether to do TFLite runtime verification.
  bool runtime_verification;
};

}  // namespace TFL
@ -246,6 +246,50 @@ static void EmitGetBuiltinOpCode(const std::vector<Record *> &defs,
        "}\n";
}

// Emits functions that return the min/max operand numbers for a given tflite
// op name.
//
// Signature:
// llvm::MinMax mlir::OperandNumbersMinMax(llvm::StringRef op_name) {
//   if (op_name == "tfl.op_name") {
//     return {min, max};
//   }
//   ...
//   return {0, 0};
// }
static void EmitOperandNumbers(const RecordKeeper &record_keeper,
                               const std::vector<Record *> &defs,
                               raw_ostream *ostream) {
  raw_ostream &os = *ostream;
  const auto attr_type = record_keeper.getClass("Attr");
  const auto optional_tensor = record_keeper.getClass("TFL_TensorOfOrNone");
  os << "llvm::MinMax mlir::OperandNumbersMinMax(llvm::StringRef op_name) {\n";
  for (const auto *def : defs) {
    auto op_name = def->getValueAsString("opName");
    int tail_optional_tensor = 0, tensor_number_max = 0;
    auto *arg_values = def->getValueAsDag("arguments");
    for (int i = 0, e = arg_values->getNumArgs(); i < e; ++i) {
      auto arg = arg_values->getArg(i);
      auto *arg_def = dyn_cast<DefInit>(arg);
      if (!arg_def) continue;
      if (!arg_def->getDef()->isSubClassOf(attr_type)) {
        tensor_number_max++;
        if (arg_def->getDef()->isSubClassOf(optional_tensor)) {
          tail_optional_tensor++;
        } else {
          tail_optional_tensor = 0;
        }
      }
    }
    const int tensor_number_min = tensor_number_max - tail_optional_tensor;

    os << formatv("  if (op_name == \"tfl.{0}\") {{\n", op_name)
       << "    return {" << tensor_number_min << ", " << tensor_number_max
       << "};\n  }\n";
  }
  os << "  return {0, 0};\n}\n";
}
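For a concrete picture of the output, the emitted function would look roughly like this (hypothetical op names and counts, for illustration only):

// llvm::MinMax mlir::OperandNumbersMinMax(llvm::StringRef op_name) {
//   if (op_name == "tfl.some_op") {   // two required operands, no optionals
//     return {2, 2};
//   }
//   if (op_name == "tfl.other_op") {  // trailing TFL_TensorOfOrNone operand,
//     return {2, 3};                  // so min < max
//   }
//   ...
//   return {0, 0};                    // unknown op name
// }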
// Emits a builder function that returns the packed FlatBuffer object given
// a general mlir::Operation.
//
@ -374,6 +418,8 @@ static bool OperatorWritersMain(raw_ostream &os, RecordKeeper &records) {
  EmitBuildOperator(defs, &os);
  os << "\n\n";
  EmitBuiltinOptionsToAttributes(records, defs, &os);
  os << "\n\n";
  EmitOperandNumbers(records, defs, &os);

  return false;
}
@ -441,7 +487,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {

  mlir::tblgen::FmtContext verify_ctx;
  os << "::mlir::LogicalResult " << op.getCppClassName()
     << "::VerifyTflRuntimeTypes(::mlir::Operation *op, bool "
     << "::VerifyTflRuntimeConstraints(::mlir::Operation *op, bool "
        "failure_on_operand_type_mismatch) {\n";
  os << "  auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
  verify_ctx.withOp("top");
@ -450,7 +496,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
    auto &value = op.getOperand(i);
    // Skip from the first variadic operand onward for now. Otherwise the
    // getOperand index used below doesn't match.
    if (value.isVariadic()) break;
    if (value.isVariableLength()) break;
    if (!value.name.empty())
      verify_ctx.addSubst(value.name, formatv("op->getOperand({0})", i));
  }
@ -458,7 +504,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
    auto &value = op.getResult(i);
    // Skip from the first variadic result onward for now. Otherwise the
    // getResult index used below doesn't match.
    if (value.isVariadic()) break;
    if (value.isVariableLength()) break;
    if (!value.name.empty())
      verify_ctx.addSubst(value.name, formatv("op->getResult({0})", i));
  }
@ -466,6 +512,25 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
                           "operand");
  GenOperandResultVerifier(os, def->getValueAsDag("results")->getArgs(),
                           "result");

  for (auto &trait : op.getTraits()) {
    if (!trait.getDef().isSubClassOf("GenInternalOpTrait")) {
      continue;
    }
    if (trait.getDef().getValueAsString("trait") !=
        "OpTrait::TFLRuntimeOpTrait") {
      continue;
    }

    auto *val = trait.getDef().getValue("tflRuntimePredicate");
    if (!val) continue;

    mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue()));
    os << tgfmt(
        "  if (!($0)) {\n    "
        "return ::mlir::LogicalResult::Failure;\n  }\n",
        &verify_ctx, tgfmt(pred.getCondition(), &verify_ctx));
  }
  os << "  return top.verify();\n}\n";
}


@ -9,7 +9,9 @@ cc_library(
    name = "cost_estimators",
    textual_hdrs = [
        "estimator.h",
        "cpu_estimators.h",
        "gpu_estimators.h",
        "hardware.h",
        "arithmetic_count_util.h",
    ],
)

@ -0,0 +1,45 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ARITHMETIC_COUNT_UTIL_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ARITHMETIC_COUNT_UTIL_H_

// For add/mul/div/sub and other broadcastable ops.
class ArithmeticCountUtilHelper {
 public:
  static bool GetArithmeticCountForBroadcastableOp(mlir::Operation* op,
                                                   int64_t* count) {
    auto output = op->getResult(0);
    auto output_type = output.getType().dyn_cast_or_null<RankedTensorType>();
    if (!output_type || !output_type.hasStaticShape()) return false;

    *count = output_type.getNumElements();
    return true;
  }

  static bool GetInputTensorTotalSize(mlir::Operation* op, int64_t* count) {
    int64_t total_count = 0;
    for (auto input : op->getOperands()) {
      auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
      if (!input_type || !input_type.hasStaticShape()) {
        return false;
      }
      total_count += input_type.getNumElements();
    }
    *count = total_count;
    return true;
  }
};

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ARITHMETIC_COUNT_UTIL_H_
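As a concrete illustration of the two helpers (values made up for the example):

// %0 = "tfl.add"(%a, %b) : (tensor<2x3xf32>, tensor<3xf32>) -> tensor<2x3xf32>
//   GetArithmeticCountForBroadcastableOp -> true, count = 6 (output elements)
// %1 = "tfl.concatenation"(%c, %d) with inputs of 6 and 9 static elements
//   GetInputTensorTotalSize -> true, count = 15 (sum over inputs)
// Any result or input with a dynamic dimension, e.g. tensor<?x3xf32>, makes
// the corresponding helper return false, and callers fall back to a default
// cost.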
@ -0,0 +1,103 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_CPU_ESTIMATORS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_CPU_ESTIMATORS_H_

// CPU
constexpr float kCPUArithmeticUnitCost = 1.0;

// This basically assumes pure load/store. This is just fake data.
constexpr float kCPUCopyUnitCost = 0.5;
constexpr float kCPUDefaultCost = 3.0f;

// Default values.
constexpr float kCPUDefaultFixedValuedCost = 10000.0;

// tfl.add
template <>
class TFLiteCostEstimator<AddOp, hardware::CPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    int64_t count;
    if (ArithmeticCountUtilHelper::GetArithmeticCountForBroadcastableOp(op,
                                                                        &count))
      return kCPUArithmeticUnitCost * count;
    return kCPUDefaultFixedValuedCost;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.mul
template <>
class TFLiteCostEstimator<MulOp, hardware::CPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    int64_t count;
    if (ArithmeticCountUtilHelper::GetArithmeticCountForBroadcastableOp(op,
                                                                        &count))
      return kCPUArithmeticUnitCost * count;
    return kCPUDefaultFixedValuedCost;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.concatenation
template <>
class TFLiteCostEstimator<ConcatenationOp, hardware::CPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    int64_t count;
    if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count))
      return kCPUCopyUnitCost * count;
    return kCPUDefaultFixedValuedCost;
  }

  // TODO(renjieliu): We probably need to check for dynamic weights.
  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.pack
template <>
class TFLiteCostEstimator<PackOp, hardware::CPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    int64_t count;
    if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count))
      return kCPUCopyUnitCost * count;
    return kCPUDefaultFixedValuedCost;
  }

  // TODO(renjieliu): We probably need to check for dynamic weights.
  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.reshape
template <>
class TFLiteCostEstimator<ReshapeOp, hardware::CPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    int64_t count;
    if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count))
      return kCPUCopyUnitCost * count;
    return kCPUDefaultFixedValuedCost;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_CPU_ESTIMATORS_H_
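Putting the constants above together, a sketch of the resulting cost arithmetic (illustrative shapes):

// tfl.add producing tensor<32x100xf32> (3200 elements):
//   cost = kCPUArithmeticUnitCost * 3200 = 1.0 * 3200 = 3200.0
// tfl.reshape whose inputs total 3200 static elements:
//   cost = kCPUCopyUnitCost * 3200 = 0.5 * 3200 = 1600.0
// Any of these ops with a dynamic shape falls back to
// kCPUDefaultFixedValuedCost = 10000.0.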
@ -16,9 +16,19 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_

// tfl.add
// GPU
constexpr float kGPUArithmeticUnitCost = 0.2;

// The copy can be a non-consecutive copy. This is just fake data.
constexpr float kGPUCopyUnitCost = 0.2;
constexpr float kGPUDefaultCost = 1.0f;

// Default values.
constexpr float kGPUDefaultFixedValuedCost = 10000.0;

// tfl.abs
template <>
class TFLiteCostEstimator<AddOp, hardware::GPU> {
class TFLiteCostEstimator<AbsOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
@ -29,6 +39,21 @@ class TFLiteCostEstimator<AddOp, hardware::GPU> {
  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.add
template <>
class TFLiteCostEstimator<AddOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    int64_t count;
    if (ArithmeticCountUtilHelper::GetArithmeticCountForBroadcastableOp(op,
                                                                        &count))
      return kGPUArithmeticUnitCost * count;
    return kGPUDefaultFixedValuedCost;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.average_pool_2d
template <>
class TFLiteCostEstimator<AveragePool2DOp, hardware::GPU> {
@ -47,9 +72,10 @@ template <>
class TFLiteCostEstimator<ConcatenationOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
    int64_t count;
    if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count))
      return kGPUCopyUnitCost * count;
    return kGPUDefaultFixedValuedCost;
  }

  // TODO(renjieliu): We probably need to check for dynamic weights.
@ -149,6 +175,19 @@ class TFLiteCostEstimator<HardSwishOp, hardware::GPU> {
  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.log
template <>
class TFLiteCostEstimator<LogOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.logistic
template <>
class TFLiteCostEstimator<LogisticOp, hardware::GPU> {
@ -201,6 +240,33 @@ class TFLiteCostEstimator<MaximumOp, hardware::GPU> {
  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.max_unpooling_2d
template <>
class TFLiteCostEstimator<MaxUnpooling2DOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.mean
template <>
class TFLiteCostEstimator<MeanOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  // TODO(renjieliu): check for constraints.
  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.minimum
template <>
class TFLiteCostEstimator<MinimumOp, hardware::GPU> {
@ -219,9 +285,11 @@ template <>
class TFLiteCostEstimator<MulOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
    int64_t count;
    if (ArithmeticCountUtilHelper::GetArithmeticCountForBroadcastableOp(op,
                                                                        &count))
      return kGPUArithmeticUnitCost * count;
    return kGPUDefaultFixedValuedCost;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
@ -240,6 +308,32 @@ class TFLiteCostEstimator<PadOp, hardware::GPU> {
  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.pow
template <>
class TFLiteCostEstimator<PowOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.prelu
template <>
class TFLiteCostEstimator<PReluOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.relu
template <>
class TFLiteCostEstimator<ReluOp, hardware::GPU> {
@ -269,6 +363,33 @@ class TFLiteCostEstimator<Relu6Op, hardware::GPU> {
// tfl.reshape
template <>
class TFLiteCostEstimator<ReshapeOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    int64_t count;
    if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count))
      return kGPUCopyUnitCost * count;
    return kGPUDefaultFixedValuedCost;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.rsqrt
template <>
class TFLiteCostEstimator<RsqrtOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.sin
template <>
class TFLiteCostEstimator<SinOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
@ -305,6 +426,58 @@ class TFLiteCostEstimator<SoftmaxOp, hardware::GPU> {
  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.space_to_depth
template <>
class TFLiteCostEstimator<SpaceToDepthOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.sqrt
template <>
class TFLiteCostEstimator<SqrtOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.square
template <>
class TFLiteCostEstimator<SquareOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.squared_difference
template <>
class TFLiteCostEstimator<SquaredDifferenceOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.strided_slice
template <>
class TFLiteCostEstimator<StridedSliceOp, hardware::GPU> {
@ -318,6 +491,19 @@ class TFLiteCostEstimator<StridedSliceOp, hardware::GPU> {
  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.tanh
template <>
class TFLiteCostEstimator<TanhOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.transpose
template <>
class TFLiteCostEstimator<TransposeOp, hardware::GPU> {
@ -331,5 +517,18 @@ class TFLiteCostEstimator<TransposeOp, hardware::GPU> {
  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.transpose_conv
template <>
class TFLiteCostEstimator<TransposeConvOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_
@ -59,13 +59,11 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/Functional.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Translation.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/compiler/xla/statusor.h"
@ -579,6 +577,24 @@ StatusOr<Operation*> ConvertOp(
    op_state.addTypes({type});
  }

  // While the last several tensors could be optional tensors for a tfl op,
  // the number of input operands can vary. Get the min/max number of operands
  // from the tflite op name.
  // Also, since the code above special-handles the `tfl.reshape` op and adds
  // an additional input, we put this block here.
  llvm::MinMax input_min_max = mlir::OperandNumbersMinMax(op_name);
  int input_max_num = input_min_max.Max;
  int op_input_num = op_state.operands.size();
  if (input_max_num != 0 && input_max_num > op_input_num) {
    // If the number of current inputs is less than the op definition, fill in
    // with `none` values.
    llvm::SmallVector<Value, 4> none_operands(
        input_max_num - op_input_num,
        builder.create<mlir::ConstantOp>(loc, builder.getNoneType(),
                                         builder.getUnitAttr()));
    op_state.addOperands(ArrayRef<Value>(none_operands));
  }

  if (op_name == "tfl.lstm") {
    // TODO(b/147587779): add the right region if region is empty.
    op_state.addRegion();
@ -658,8 +674,8 @@ template <typename ContainerType>
mlir::NamedAttribute BuildTFEntryFunctionAttribute(
    const tflite::SubGraphT& subgraph, Builder* builder, const std::string name,
    const ContainerType indices) {
  llvm::SmallVector<std::string, 8> tensor_names = mlir::functional::map(
      [&](int i) { return subgraph.tensors.at(i)->name; }, indices);
  auto tensor_names = llvm::map_range(
      indices, [&](int i) { return subgraph.tensors.at(i)->name; });
  return builder->getNamedAttr(
      name, builder->getStringAttr(llvm::join(tensor_names, ",")));
}
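One note on this replacement (my reading, not from the commit message): llvm::map_range builds a lazy adaptor over indices instead of materializing a SmallVector of names, and llvm::join can consume that range directly. A sketch of the same pattern in isolation, assuming a tensors array with a name field:

// std::vector<int> indices = {0, 2};
// auto names = llvm::map_range(indices,
//                              [&](int i) { return tensors[i].name; });
// std::string joined = llvm::join(names, ",");  // e.g. "in0,in2"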
@ -25,7 +25,7 @@ limitations under the License.
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"

@ -26,6 +26,7 @@ limitations under the License.
#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AssumeBundleQueries.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
@ -55,6 +56,11 @@ void BuiltinOptionsToAttributes(
    // NOLINTNEXTLINE
    llvm::SmallVectorImpl<mlir::NamedAttribute> &attributes);

// While the last several tensors could be optional tensors for a tfl op, the
// number of input operands can vary. This function gets the min/max number of
// operands from the tflite op name.
llvm::MinMax OperandNumbersMinMax(llvm::StringRef op_name);

// Populates the array of mlir::NamedAttributes corresponding to the given
// custom_options.
// We use an out parameter per LLVM convention

@ -86,7 +86,7 @@ def TFL_RuntimeVerification : OpInterface<"TflRuntimeVerifyOpInterface"> {
  let methods = [
    StaticInterfaceMethod<
      [{Returns whether the op's operands/results are supported by runtime.}],
      "LogicalResult", "VerifyTflRuntimeTypes",
      "LogicalResult", "VerifyTflRuntimeConstraints",
      (ins "Operation*":$op, "bool":$failure_on_operand_type_mismatch)
    >,
  ];

@ -26,6 +26,7 @@ limitations under the License.
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
@ -45,6 +46,30 @@ namespace mlir {
#include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc"
namespace TFL {

// Returns true when the given two types have the same shape or a
// broadcastable shape within the given rank. If either of the given shapes is
// non-static, this method returns true.
bool IsBinaryOperandsHaveSameShapesOrBroadcastableShape(Type lhs, Type rhs,
                                                        int max_bcast_rank) {
  // Ignore shape checking on the non-static shapes for model compatibility.
  auto lhs_shaped_type = lhs.dyn_cast<ShapedType>();
  if (!lhs_shaped_type || !lhs_shaped_type.hasStaticShape()) return true;
  auto rhs_shaped_type = rhs.dyn_cast<ShapedType>();
  if (!rhs_shaped_type || !rhs_shaped_type.hasStaticShape()) return true;

  if (lhs_shaped_type.getShape().equals(rhs_shaped_type.getShape()))
    return true;

  SmallVector<int64_t, 4> result_shape;
  if (!OpTrait::util::getBroadcastedShape(lhs_shaped_type.getShape(),
                                          rhs_shaped_type.getShape(),
                                          result_shape)) {
    return false;
  }
  return lhs_shaped_type.getRank() <= max_bcast_rank &&
         rhs_shaped_type.getRank() <= max_bcast_rank;
}
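Illustrative behaviour of this helper, assuming max_bcast_rank = 4:

// (tensor<2x3xf32>, tensor<2x3xf32>) -> true   (identical shapes)
// (tensor<2x3xf32>, tensor<3xf32>)   -> true   (broadcastable, ranks <= 4)
// (tensor<2x3xf32>, tensor<4xf32>)   -> false  (shapes not broadcastable)
// (tensor<?x3xf32>, tensor<3xf32>)   -> true   (non-static shape is skipped)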
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorFlowLiteDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -315,7 +340,7 @@ Attribute ConstFoldUnaryOp(Type result_type, Attribute operand,
|
||||
const int num_elements = result_shape_type.getNumElements();
|
||||
new_values.reserve(num_elements);
|
||||
|
||||
for (APFloat old_value : dense_elements.getValues<APFloat>()) {
|
||||
for (const APFloat &old_value : dense_elements.getValues<APFloat>()) {
|
||||
new_values.push_back(calculate(old_value));
|
||||
}
|
||||
|
||||
@ -843,7 +868,7 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (!shape_elements) return nullptr;
|
||||
|
||||
SmallVector<int64_t, 4> shape_data;
|
||||
for (auto it : shape_elements.getValues<APInt>()) {
|
||||
for (const auto &it : shape_elements.getValues<APInt>()) {
|
||||
shape_data.push_back(it.getSExtValue());
|
||||
}
|
||||
result_type =
|
||||
@ -1266,10 +1291,65 @@ static LogicalResult Verify(SplitVOp op) {
|
||||
|
||||
static LogicalResult Verify(LSTMOp op) {
|
||||
auto operands = op.GetStatefulOperands();
|
||||
if (operands.size() == 2 && operands[0] == 18 && operands[1] == 19) {
|
||||
return success();
|
||||
if (operands.size() != 2 || operands[0] != 18 || operands[1] != 19) {
|
||||
return op.emitOpError("LSTMOp expected to have two stateful operands");
|
||||
}
|
||||
return op.emitError("LSTMOp expected to have two stateful operands");
|
||||
|
||||
const auto input_type = op.input().getType().cast<ShapedType>();
|
||||
// Since TFLite runtime generally supports dynamic shape/rank, if `input_type`
|
||||
// doesn't have static shape, we skip the shape check below.
|
||||
if (!input_type.hasStaticShape()) return success();
|
||||
  // The input should be at least a 2D tensor since it will go through a fully
  // connected layer.
  if (!input_type.hasRank() || input_type.getRank() < 2)
    return op.emitOpError(
        "the first input operand should have at least 2 dimensions.");

  const auto activation_state =
      op.input_activation_state().getType().cast<ShapedType>();
  const auto cell_state = op.input_cell_state().getType().cast<ShapedType>();
  const auto input_to_output_weights =
      op.input_to_output_weights().getType().cast<ShapedType>();
  const auto recurrent_to_output_weights =
      op.recurrent_to_output_weights().getType().cast<ShapedType>();
  if (activation_state.hasStaticShape() && cell_state.hasStaticShape() &&
      input_to_output_weights.hasStaticShape() &&
      recurrent_to_output_weights.hasStaticShape()) {
    const int n_input = input_type.getDimSize(input_type.getRank() - 1);
    const int n_cell = input_to_output_weights.getDimSize(0);
    const int n_output = recurrent_to_output_weights.getDimSize(1);
    const int output_state_size = activation_state.getNumElements();
    const int n_batch = input_type.getRank() == 2 ? input_type.getDimSize(0)
                                                  : input_type.getDimSize(1);
    const int state_size = cell_state.getNumElements();

    // Check if the dimensions of the inputs match.
    if ((output_state_size != n_batch * n_output) ||
        (state_size != n_batch * n_cell) ||
        (input_to_output_weights.getDimSize(1) != n_input) ||
        (recurrent_to_output_weights.getRank() != 2) ||
        (recurrent_to_output_weights.getDimSize(0) != n_cell) ||
        (input_to_output_weights.getRank() != 2)) {
      return op.emitOpError("input dimensions don't match.");
    }

    const bool is_layer_norm_lstm =
        !op.forget_layer_norm_coefficients().getType().isa<NoneType>();
    if (is_layer_norm_lstm) {
      const auto forget_layer_norm_coefficients =
          op.forget_layer_norm_coefficients().getType().cast<ShapedType>();
      // If this LSTM has layer normalization, the input value
      // `forget_layer_norm_coefficients` should be a 1D tensor.
      if (forget_layer_norm_coefficients.getRank() != 1 ||
          forget_layer_norm_coefficients.getDimSize(0) != n_cell)
        return op.emitOpError(
            "coefficient inputs are not 1D or don't match the dimension of "
            "input operand `input_to_output_weights`.");
    }
  }

  return success();
}
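To make the verifier's dimension bookkeeping concrete, here is one consistent assignment of static shapes (illustrative numbers only, not taken from the source):

// Assumed sizes: n_batch = 2, n_input = 3, n_cell = 4, n_output = 5,
// with a rank-2 input.
//   input:                       [2, 3]  (n_batch x n_input)
//   input_to_output_weights:     [4, 3]  (n_cell  x n_input)
//   recurrent_to_output_weights: [4, 5]  (n_cell  x n_output)
//   input_activation_state:      10 elements == n_batch * n_output
//   input_cell_state:            8 elements  == n_batch * n_cell
// Every condition in the verifier holds, so Verify(LSTMOp) returns success().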

//===----------------------------------------------------------------------===//
@ -1742,7 +1822,7 @@ static LogicalResult Verify(TransposeOp op) {

  int index = 0;
  llvm::SmallVector<int64_t, 4> axes;
  for (auto axis_int : perm.getValues<APInt>()) {
  for (const auto &axis_int : perm.getValues<APInt>()) {
    const int64_t axis = axis_int.getSExtValue();
    if (axis < 0 || (input_type.hasRank() && axis >= input_type.getRank())) {
      return op.emitOpError(

@ -28,7 +28,6 @@ limitations under the License.
#include "mlir/Interfaces/DerivedAttributeOpInterface.h"  // from @llvm-project
#include "mlir/Interfaces/LoopLikeInterface.h"  // from @llvm-project
#include "mlir/Interfaces/SideEffects.h"  // from @llvm-project
#include "mlir/Support/Functional.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/lite/schema/schema_generated.h"
@ -54,6 +53,8 @@ class TensorFlowLiteDialect : public Dialect {
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc"
// Include all specialized estimators below this line
#include "tensorflow/compiler/mlir/lite/experimental/estimators/arithmetic_count_util.h"
|
||||
#include "tensorflow/compiler/mlir/lite/experimental/estimators/cpu_estimators.h"
|
||||
#include "tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h"
|
||||
|
||||
} // end namespace TFL
|
||||
|
||||
@ -106,6 +106,22 @@ class DerivedShapeAttr<code body> : DerivedAttr<"ArrayRef<int64_t>", body>;
|
||||
class DerivedTFLiteTypeAttr<code body> :
|
||||
DerivedAttr<"tflite::TensorType", body>;
|
||||
|
||||
// TFL Runtime op trait predicate.
|
||||
class TFL_RuntimePredOpTrait<string desc, Pred pred> :
|
||||
GenInternalOpTrait<"TFLRuntimeOpTrait"> {
|
||||
Pred tflRuntimePredicate = pred;
|
||||
string tflRuntimeDescription = desc;
|
||||
}
|
||||
|
||||
class TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<
|
||||
int i, int j, int max_bcast_rank> :
|
||||
TFL_RuntimePredOpTrait<"operand #" # i # " and operand #" # j #
|
||||
" have the same shape or broadcastable shapes within the rank " #
|
||||
max_bcast_rank,
|
||||
CPred<"TFL::IsBinaryOperandsHaveSameShapesOrBroadcastableShape("
|
||||
"$_op.getOperand(" # i # ").getType(), $_op.getOperand(" # j #
|
||||
").getType(), " # max_bcast_rank # ")">>;
|
||||
|
||||
// These additional types/type constraints here are used to decouple the ops
// from runtime support for the ops. Prefer to use these types when defining
// new TF_Ops for uniformity.
@ -344,7 +360,10 @@ class TFL_ConvOp<string mnemonic, string opSummary, int index> :
// TFL op definitions.
//===----------------------------------------------------------------------===//
def TFL_AbsOp : TFL_Op<"abs", [
    NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
    NoSideEffect,
    SameOperandsAndResultType,
    NoQuantizableResult,
    TFL_GpuTargetOp]> {
  let summary = "Absolute value operator";

  let description = [{
@ -360,10 +379,9 @@ an output element, this operation computes \\(y = |x|\\).
  let hasFolder = 1;
}

def TFL_AddOp : TFL_Op<"add", [ResultsBroadcastableShape,
                               NoSideEffect,
                               Commutative,
                               TFL_GpuTargetOp]> {
def TFL_AddOp : TFL_Op<"add", [
    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>,
    ResultsBroadcastableShape, NoSideEffect, Commutative, TFL_GpuTargetOp]> {
  let summary = "Addition operator";

  let description = [{
@ -371,11 +389,11 @@ def TFL_AddOp : TFL_Op<"add", [ResultsBroadcastableShape,
  }];

  let arguments = (
    ins AnyTensor:$lhs,
    AnyTensor:$rhs,
    ins TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$lhs,
    TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$rhs,
    TFL_AFAttr:$fused_activation_function);

  let results = (outs AnyTensor:$output);
  let results = (outs TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$output);

  let hasFolder = 1;

@ -432,7 +450,7 @@ retained with length 1.
}

def TFL_TransposeConvOp:
  TFL_Op<"transpose_conv", [NoSideEffect]> {
  TFL_Op<"transpose_conv", [NoSideEffect, TFL_GpuTargetOp]> {
  let summary = "Transpose convolution operator";

  let description = [{
@ -1527,7 +1545,10 @@ def TFL_LogisticOp: TFL_Op<"logistic", [
}

def TFL_LogOp: TFL_Op<"log", [
    NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
    NoSideEffect,
    SameOperandsAndResultType,
    NoQuantizableResult,
    TFL_GpuTargetOp]> {
  let summary = "Natural logarithm operator";

  let description = [{
@ -1637,7 +1658,7 @@ def TFL_MaxPoolingWithArgMax2DOp :
}

def TFL_MaxUnpooling2DOp :
  Op<TFL_Dialect, "max_unpooling_2d", [NoSideEffect]> {
  Op<TFL_Dialect, "max_unpooling_2d", [NoSideEffect, TFL_GpuTargetOp]> {
  let summary = "Max Unpool 2D";

  let description = [{
@ -1690,7 +1711,7 @@ def TFL_MaximumOp : TFL_Op<"maximum", [
  let hasOptions = 0;
}

def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect]> {
def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect, TFL_GpuTargetOp]> {
  let summary = "Mean operator";

  let description = [{
@ -2072,7 +2093,10 @@ def TFL_PadV2Op : TFL_Op<"padv2", [
  let hasOptions = 1;
}

def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape,
                               NoSideEffect,
                               NoQuantizableResult,
                               TFL_GpuTargetOp]> {
  let summary = "Power operator";

  let description = [{
@ -2092,7 +2116,7 @@ def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape, NoSideEffect, NoQuanti
  let builders = [TFL_BroadcastableBinaryBuilder];
}

def TFL_PReluOp : TFL_Op<"prelu", [NoSideEffect]> {
def TFL_PReluOp : TFL_Op<"prelu", [NoSideEffect, TFL_GpuTargetOp]> {
  let summary = "Parameterized Relu operator";

  let description = [{
@ -2141,6 +2165,17 @@ def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect,
  let arguments = (ins TFL_TensorOf<[F32, QUI8, I8]>:$x);

  let results = (outs TFL_TensorOf<[F32, QUI8, I8]>:$y);

  // This builder doesn't work with quantized types, so it can only be used by
  // non-quantization tablegen patterns. Currently, it is used by the
  // elementwise-move reordering pattern in optimize_patterns.td.
  let builders = [OpBuilder<
    "Builder *, OperationState &state, Value input",
    [{
      state.addOperands({input});
      state.addTypes(input.getType());
    }]>
  ];
}

def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect,
@ -2157,6 +2192,17 @@ def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect,
  let arguments = (ins TFL_TensorOf<[F32, QUI8, I8]>:$x);

  let results = (outs TFL_TensorOf<[F32, QUI8, I8]>:$y);

  // This builder doesn't work with quantized types, so it can only be used by
  // non-quantization tablegen patterns. Currently, it is used by the
  // elementwise-move reordering pattern in optimize_patterns.td.
  let builders = [OpBuilder<
    "Builder *, OperationState &state, Value input",
    [{
      state.addOperands({input});
      state.addTypes(input.getType());
    }]>
  ];
}

def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [NoSideEffect,
@ -2172,6 +2218,17 @@ def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [NoSideEffect,
  let arguments = (ins TFL_TensorOf<[F32, QUI8, I8]>:$x);

  let results = (outs TFL_TensorOf<[F32, QUI8, I8]>:$y);

  // This builder doesn't work with quantized types, so it can only be used by
  // non-quantization tablegen patterns. Currently, it is used by the
  // elementwise-move reordering pattern in optimize_patterns.td.
  let builders = [OpBuilder<
    "Builder *, OperationState &state, Value input",
    [{
      state.addOperands({input});
      state.addTypes(input.getType());
    }]>
  ];
}

def TFL_ReshapeOp: TFL_Op<"reshape", [
@ -2223,7 +2280,10 @@ slice `i`, with the first `seq_lengths[i]` slices along dimension
  let hasOptions = 1;
}

def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect,
                                  SameOperandsAndResultType,
                                  NoQuantizableResult,
                                  TFL_GpuTargetOp]> {
  let summary = "Reciprocal of square root operator";

  let description = [{
@ -2371,7 +2431,10 @@ def TFL_SelectV2Op : TFL_Op<"select_v2", [NoSideEffect]> {
}

def TFL_SinOp: TFL_Op<"sin", [
    NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
    NoSideEffect,
    SameOperandsAndResultType,
    NoQuantizableResult,
    TFL_GpuTargetOp]> {
  let summary = "Sine operator";

  let description = [{
@ -2413,7 +2476,10 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [
}

def TFL_SqrtOp: TFL_Op<"sqrt", [
    NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
    NoSideEffect,
    SameOperandsAndResultType,
    NoQuantizableResult,
    TFL_GpuTargetOp]> {
  let summary = "Square root operator";

  let description = [{
@ -2428,7 +2494,10 @@ def TFL_SqrtOp: TFL_Op<"sqrt", [
}

def TFL_SquareOp: TFL_Op<"square", [
    NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
    NoSideEffect,
    SameOperandsAndResultType,
    NoQuantizableResult,
    TFL_GpuTargetOp]> {
  let summary = "Square operator";

  let description = [{
@ -2472,7 +2541,10 @@ def TFL_SubOp : TFL_Op<"sub", [ResultsBroadcastableShape, NoSideEffect]> {
// TODO(jpienaar): Expand the kernel implementation to support all types besides
// I32 and F32.
def TFL_SquaredDifferenceOp : TFL_Op<"squared_difference", [
    ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
    ResultsBroadcastableShape,
    NoSideEffect,
    NoQuantizableResult,
    TFL_GpuTargetOp]> {
  let summary = "Squared difference operator";

  let description = [{
@ -2499,7 +2571,8 @@ def TFL_TanhOp: TFL_Op<"tanh", [
    // zero_point = central_value
    // scale = 1. / (central_value - min_value)
    FixedResultScale<Int8UniformQuantizedType<0, 78125, -7>>,
    FixedResultScale<UInt8UniformQuantizedType<128, 78125, -7>>]> {
    FixedResultScale<UInt8UniformQuantizedType<128, 78125, -7>>,
    TFL_GpuTargetOp]> {
  let summary = "Hyperbolic tangent operator";

  let description = [{
@ -2509,6 +2582,17 @@ def TFL_TanhOp: TFL_Op<"tanh", [
  let arguments = (ins TFL_TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$x);

  let results = (outs TFL_TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$y);

  // This builder doesn't work with quantized types, so it can only be used by
  // non-quantization tablegen patterns. Currently, it is used by the
  // elementwise-move reordering pattern in optimize_patterns.td.
  let builders = [OpBuilder<
    "Builder *, OperationState &state, Value input",
    [{
      state.addOperands({input});
      state.addTypes(input.getType());
    }]>
  ];
}

def TFL_TileOp: TFL_Op<"tile", [NoSideEffect, SameOperandsAndResultsScale,
@ -2694,7 +2778,8 @@ def TFL_SpaceToDepthOp: TFL_Op<"space_to_depth", [
    NoSideEffect,
    SameOperandsAndResultsScale,
    PredOpTrait<"input and output must have same element type",
                TCresVTEtIsSameAsOp<0, 0>>
                TCresVTEtIsSameAsOp<0, 0>>,
    TFL_GpuTargetOp
  ]> {
  let summary = "SpaceToDepth operator";

@ -2957,14 +3042,13 @@ def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [
  }];

  let arguments = (ins
    // TODO: add uint8 support when ready.
    TFL_TensorOf<[F32, I32, I64]>:$input,
    TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8]>:$input,
    TFL_TensorOf<[I32, I64]>:$pad,
    TFL_MirrorPaddingAttr:$mode
  );

  let results = (outs
    TFL_TensorOf<[F32, I32, I64]>:$output
    TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8]>:$output
  );

  let hasOptions = 1;

63 tensorflow/compiler/mlir/lite/json_to_flatbuffer.cc Normal file
@ -0,0 +1,63 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>

#include <cstddef>
#include <cstdio>
#include <iostream>
#include <string>

#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "flatbuffers/idl.h"  // from @flatbuffers
#include "flatbuffers/util.h"  // from @flatbuffers
#include "tensorflow/lite/schema/schema_generated.h"

int main(int argc, char** argv) {
  // load FlatBuffer schema (.fbs) and JSON from disk
  if (argc < 3) {
    std::cerr << "Missing input arguments. Usage:\n"
              << argv[0] << " <schema file> <json file>\n\n";
    return 1;
  }
  const char* schema_path = argv[1];
  const char* json_path = argv[2];
  std::string schema;
  std::string json;

  const bool status =
      flatbuffers::LoadFile(schema_path, /*binary=*/false, &schema) &&
      flatbuffers::LoadFile(json_path, /*binary=*/false, &json);
  if (!status) {
    std::cerr << "couldn't load files!\n";
    return 1;
  }

  // parse schema first, so we can use it to parse the data after
  flatbuffers::Parser parser;
  const bool schema_parse_result =
      parser.Parse(schema.c_str()) && parser.Parse(json.c_str());
  if (!schema_parse_result) {
    std::cerr << "Parse error.\n";
    return 1;
  }
  const size_t length = parser.builder_.GetSize();
  const size_t n =
      std::fwrite(parser.builder_.GetBufferPointer(), 1, length, stdout);
  if (n != length) {
std::cerr << "print to stdout filed.\n";
|
||||
    return 1;
  }
  return 0;
}
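Assuming the target is built as `json_to_flatbuffer`, it would presumably be invoked as `json_to_flatbuffer tensorflow/lite/schema/schema.fbs model.json > model.tflite` (hypothetical file names), with the binary FlatBuffer written to stdout.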
@ -88,7 +88,6 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
  bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
  pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
  pass_config.lower_tensor_list_ops = true;
  pass_config.shape_inference = false;

  return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module),
                                                 pass_config, result);

@ -16,9 +16,12 @@ limitations under the License.

#include <utility>

#include "llvm/ADT/StringSet.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Module.h"  // from @llvm-project
#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/FileUtilities.h"  // from @llvm-project
#include "mlir/Transforms/ViewOpGraph.h"  // from @llvm-project
@ -41,6 +44,77 @@ limitations under the License.

namespace tensorflow {

Status HandleInputOutputArraysWithModule(const toco::ModelFlags& model_flags,
                                         mlir::OwningModuleRef* module) {
  mlir::FuncOp entry_function = nullptr;
  for (auto func : module->get().getOps<mlir::FuncOp>()) {
    if (auto tf_attrs =
            func.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function")) {
      // TODO(jaesung): There could be multiple entry functions. Handle such
      // cases if the need arises.
      if (entry_function != nullptr) {
        return errors::InvalidArgument(
            "There should be only one tf.entry_function");
      }
      entry_function = func;
    }
  }
  if (entry_function == nullptr) {
    return errors::InvalidArgument("no tf.entry_function found");
  }

  // Get the list of input Op names from the function attribute.
  mlir::DictionaryAttr tf_attrs =
      entry_function.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
  llvm::SmallVector<llvm::StringRef, 4> function_input_names;
  function_input_names.reserve(model_flags.input_arrays().size());
  auto input_attr = tf_attrs.get("inputs");
  if (!input_attr) {
    return errors::InvalidArgument("no inputs attribute found");
  }
  auto input_names = input_attr.cast<mlir::StringAttr>().getValue();
  input_names.split(function_input_names, ",");
  if (function_input_names.size() != model_flags.input_arrays().size()) {
    return errors::InvalidArgument(
        "input array size mismatch: got ", function_input_names.size(),
        ", expected: ", model_flags.input_arrays().size());
  }
  llvm::StringSet<> function_input_names_set;
  function_input_names_set.insert(function_input_names.begin(),
                                  function_input_names.end());
  for (const auto& input_array : model_flags.input_arrays()) {
    if (function_input_names_set.count(input_array.name()) == 0) {
      return errors::InvalidArgument("input array name (", input_array.name(),
                                     ") does not exist in the given graph");
    }
  }

  // Get the list of output Op names from the function attribute.
  llvm::SmallVector<llvm::StringRef, 4> function_output_names;
  function_output_names.reserve(model_flags.output_arrays().size());
  auto output_attr = tf_attrs.get("outputs");
  if (!output_attr) {
    return errors::InvalidArgument("no outputs attribute found");
  }
  auto output_names = output_attr.cast<mlir::StringAttr>().getValue();
  output_names.split(function_output_names, ",");
  if (function_output_names.size() != model_flags.output_arrays().size()) {
    return errors::InvalidArgument(
        "output array size mismatch: got ", function_output_names.size(),
        ", expected: ", model_flags.output_arrays().size());
  }
  llvm::StringSet<> function_output_names_set;
  function_output_names_set.insert(function_output_names.begin(),
                                   function_output_names.end());
  for (const auto& output_array : model_flags.output_arrays()) {
    if (function_output_names_set.count(output_array) == 0) {
      return errors::InvalidArgument("output array name (", output_array,
                                     ") does not exist in the given graph");
    }
  }
  return Status::OK();
}
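The validation above is a split-then-membership check: the comma-separated `tf.entry_function` attribute is split into names, and each user-supplied array must appear in the resulting set. A standard-library-only sketch of the same pattern (hypothetical helpers, shown for illustration, not the function's actual code):

#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>

// Splits a comma-separated attribute value such as "input0,input1" into names.
std::vector<std::string> SplitNames(const std::string& csv) {
  std::vector<std::string> names;
  std::stringstream ss(csv);
  std::string item;
  while (std::getline(ss, item, ',')) names.push_back(item);
  return names;
}

// Returns true if every requested array name appears in the function's list.
bool AllNamesPresent(const std::vector<std::string>& requested,
                     const std::vector<std::string>& function_names) {
  std::unordered_set<std::string> set(function_names.begin(),
                                      function_names.end());
  for (const auto& name : requested)
    if (set.count(name) == 0) return false;
  return true;
}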

Status ConvertSavedModelToTFLiteFlatBuffer(
    const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
    string* result) {
@ -77,11 +151,15 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
      model_flags.saved_model_version(), tags,
      exported_names, &context));

  if (!model_flags.input_arrays().empty() ||
      !model_flags.output_arrays().empty()) {
    TF_RETURN_IF_ERROR(HandleInputOutputArraysWithModule(model_flags, &module));
  }

  mlir::TFL::PassConfig pass_config(quant_specs);
  bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
  pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
  pass_config.lower_tensor_list_ops = true;
  pass_config.shape_inference = true;

  auto status = internal::ConvertMLIRToTFLiteFlatBuffer(
      toco_flags, std::move(module), pass_config, result);

@ -285,7 +285,7 @@ Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
  if (pass_config.legalize_tf_while) {
    pm.addPass(mlir::TFL::CreateWhileOutlinePass());
  }
  pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
  pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());

  auto status = ConvertTFExecutorToTFLOrFlatbuffer(
      module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,

@ -1,4 +1,4 @@
load("//tensorflow:tensorflow.bzl", "tf_native_cc_binary")
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_native_cc_binary")
load(
    "//tensorflow/core/platform:build_config.bzl",
    "tf_proto_library",
@ -14,7 +14,10 @@ package(
package_group(
    name = "friends",
    includes = ["//third_party/mlir:subpackages"],
    packages = ["//tensorflow/compiler/mlir/..."],
    packages = [
        "//learning/brain/experimental/mlir/quantization/...",
        "//tensorflow/compiler/mlir/...",
    ],
)

exports_files([
@ -112,11 +115,22 @@ tf_native_cc_binary(
    ],
)

cc_library(
    name = "numerical_utils",
    srcs = ["numerical_utils.cc"],
    hdrs = ["numerical_utils.h"],
    deps = [
        "@com_google_absl//absl/types:optional",
    ],
)

cc_library(
    name = "device_target",
    srcs = ["device_target.cc"],
    hdrs = ["device_target.h"],
    deps = [
        ":numerical_utils",
        "@com_google_absl//absl/types:optional",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:QuantOps",
@ -139,3 +153,13 @@ cc_library(
        "@llvm-project//mlir:Support",
    ],
)

tf_cc_test(
    name = "numerical_utils_test",
    srcs = ["numerical_utils_test.cc"],
    deps = [
        ":numerical_utils",
        "@com_google_absl//absl/types:optional",
        "@com_google_googletest//:gtest_main",
    ],
)

@ -15,12 +15,18 @@ limitations under the License.

#include "tensorflow/compiler/mlir/lite/quantization/device_target.h"

#include <algorithm>

#include "absl/types/optional.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/numerical_utils.h"

namespace mlir {
namespace quant {
@ -39,7 +45,7 @@ DeviceTarget::DeviceTarget(MLIRContext* ctx) : ctx_(ctx) {
  assert(qi8n_ == qi8n_);
}

Optional<KernelSpec> DeviceTarget::Get(QuantizeRegionOp op) const {
Optional<KernelSpec> DeviceTarget::GetKernelSpec(QuantizeRegionOp op) const {
  auto kernel_specs_it = specs_.find(op.logical_kernel());
  if (kernel_specs_it == specs_.end()) return llvm::None;

@ -50,9 +56,15 @@ Optional<KernelSpec> DeviceTarget::Get(QuantizeRegionOp op) const {
  return kernel_specs_it->getValue().Find(signature);
}

ScaleDecomposeFn DeviceTarget::GetDecomposeFn(QuantizeRegionOp op) const {
  auto kernel_specs_it = specs_.find(op.logical_kernel());
  if (kernel_specs_it == specs_.end()) return ScaleDecomposeFn(nullptr);
  return kernel_specs_it->second.GetDecomposeFn();
}

LogicalResult DeviceTarget::RegisterKernel(
    llvm::StringRef kernel, const KernelSpecs::Signature& signature,
    const ScaleFn& fn) {
    const ScaleFn& fn, const ScaleDecomposeFn& dfn) {
  return specs_[kernel].Add(signature, {ScaleConstraintType::CustomScale, fn});
}

@ -78,5 +90,49 @@ void DeviceTarget::AppendToSignature(ArrayAttr specs_attr,
  }
}

LogicalResult DeviceTarget::DecomposeMultiplyAccumulateScale(
    Operation* op, quant::QuantizedMultipliers* input_multipliers,
    quant::QuantizedMultipliers* output_multipliers,
    quant::QuantizedRanges* output_ranges) {
  auto rop = llvm::dyn_cast<quant::QuantizeRegionOp>(op);
  if (!rop) return failure();

  llvm::SmallVector<Type, 4> input_specs, out_specs;
  for (auto spec : rop.input_specs()) {
    input_specs.push_back(spec.cast<TypeAttr>().getValue());
  }
  for (auto spec : rop.output_specs()) {
    out_specs.push_back(spec.cast<TypeAttr>().getValue());
  }

  auto in_spec = input_specs[0].dyn_cast<quant::UniformQuantizedType>();
  // TODO(fengliuai): handle the per-axis QuantizedType.
  auto w_spec = input_specs[1].dyn_cast<quant::UniformQuantizedType>();
  auto b_spec = input_specs[2].dyn_cast<quant::UniformQuantizedType>();
  auto o_spec = out_specs[0].dyn_cast<quant::UniformQuantizedType>();
  if (!in_spec || !w_spec || !b_spec || !o_spec) return failure();

  // The bias scale must match the product of the input and weight scales;
  // otherwise the decomposition below is invalid.
  double scale_product = in_spec.getScale() * w_spec.getScale();
  if (fabs(scale_product - b_spec.getScale()) >= 1e-6) return failure();

  // input multipliers
  input_multipliers->append(3, kUnitQuantizedMultiplier);

  // output multipliers
  double real_multiplier = o_spec.getScale() / scale_product;
  output_multipliers->push_back(quant::QuantizeMultiplier(real_multiplier));

  // output ranges
  auto min = rop.getAttrOfType<FloatAttr>("min");
  auto max = rop.getAttrOfType<FloatAttr>("max");
  output_ranges->push_back(quant::CalculateQuantizedRange(
      o_spec.getScale(), o_spec.getZeroPoint(),
      (min ? absl::optional<double>(min.getValueAsDouble()) : absl::nullopt),
      (max ? absl::optional<double>(max.getValueAsDouble()) : absl::nullopt),
      o_spec.getStorageTypeMin(), o_spec.getStorageTypeMax()));

  return success();
}
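For a feel of the numbers, assume illustrative scales of 0.5 for the input, 0.25 for the weights, and 0.125 for both the bias and the output (values not from the source):

// scale_product   = 0.5 * 0.25    = 0.125  (matches the bias scale)
// real_multiplier = 0.125 / 0.125 = 1.0
// QuantizeMultiplier(1.0) decomposes this to {1 << 30, 1}, since
// (1 << 30) * 2^(-31 + 1) == 1.0.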

}  // namespace quant
}  // namespace mlir

@ -17,13 +17,13 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_DEVICE_TARGET_H_

#include <functional>
#include <ostream>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
@ -33,6 +33,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/numerical_utils.h"

namespace mlir {
namespace quant {
@ -40,9 +41,17 @@ namespace quant {
class QuantizeContext;

using AdjacentOperations = llvm::SmallVectorImpl<Operation*>;
using QuantizedMultipliers = llvm::SmallVector<QuantizedMultiplier, 4>;
using QuantizedRanges = llvm::SmallVector<QuantizedRange, 4>;
using ScaleFn = std::function<LogicalResult(QuantizeContext*, Operation*,
                                            AdjacentOperations*, bool*)>;

using ScaleDecomposeFn =
    std::function<LogicalResult(Operation*, QuantizedMultipliers*,
                                QuantizedMultipliers*, QuantizedRanges*)>;

static const QuantizedMultiplier kUnitQuantizedMultiplier{1, 0};

enum class ScaleConstraintType {
  OutputInputSameScale,
  OutputInputFreeScale,
@ -73,12 +82,25 @@ class KernelSpecs {
    }
  }

  ScaleDecomposeFn GetDecomposeFn() const { return decompose_fn_; }

  // Adds the kernel signature with the kernel specification.
  LogicalResult Add(const Signature& signature, const KernelSpec& spec) {
    if (all_signatures_.insert({signature, spec}).second) return success();
    return failure();
  }

  KernelSpecs& WithSignature(const KernelSpecs::Signature& signature,
                             const ScaleFn& fn) {
    Add(signature, {ScaleConstraintType::CustomScale, fn});
    return *this;
  }

  KernelSpecs& WithImpl(const ScaleDecomposeFn& dfn) {
    decompose_fn_ = dfn;
    return *this;
  }
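Together these methods form a small fluent builder over KernelSpecs. Registration presumably chains in the style of the following sketch (hypothetical kernel name, signature, and scale function, shown only to illustrate the API shape):

// Inside a DeviceTarget subclass (sketch, not from the source):
// RegisterKernel("generic.matmul_add")
//     .WithSignature({qi8_, qi8_, any_, qi8_}, custom_scale_fn)
//     .WithImpl(DecomposeMultiplyAccumulateScale);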

 private:
  // The signature is pattern match based.
  struct SignatureInfo : public llvm::DenseMapInfo<Signature> {
@ -101,6 +123,10 @@ class KernelSpecs {
  // Maps the signature to the kernel spec. Note that the matching is
  // pattern match based.
  llvm::DenseMap<Signature, KernelSpec, SignatureInfo> all_signatures_;

  // A method to compute the effective multipliers. This is independent of the
  // bit widths of the ports, so all the signatures share the same function.
  ScaleDecomposeFn decompose_fn_;
};

class DeviceTarget {
@ -108,19 +134,26 @@ class DeviceTarget {
  explicit DeviceTarget(MLIRContext* ctx);

  // Retrieves the kernel spec for the quant region op.
  Optional<KernelSpec> Get(quant::QuantizeRegionOp op) const;
  Optional<KernelSpec> GetKernelSpec(quant::QuantizeRegionOp op) const;

  // Retrieves the scale decomposition function for the quant region op.
  ScaleDecomposeFn GetDecomposeFn(quant::QuantizeRegionOp op) const;

 protected:
  // Adds the kernel spec with the custom scale function for the kernel.
  LogicalResult RegisterKernel(llvm::StringRef kernel,
                               const KernelSpecs::Signature& signature,
                               const ScaleFn& fn);
                               const ScaleFn& fn, const ScaleDecomposeFn& dfn);

  // Adds the kernel spec with the scale constraint type for the kernel.
  LogicalResult RegisterKernel(llvm::StringRef kernel,
                               const KernelSpecs::Signature& signature,
                               const ScaleConstraintType constraint);

  // Adds the kernel with the name. Returns an existing one if it has been
  // added before.
  KernelSpecs& RegisterKernel(llvm::StringRef kernel) { return specs_[kernel]; }

  // Converts the specification to a signature:
  // - UniformQuantizedType -> AnyQuantizedType
  // - AnyQuantizedType (int) -> AnyQuantizedType
@ -128,6 +161,13 @@ class DeviceTarget {
  void AppendToSignature(ArrayAttr specs_attr,
                         KernelSpecs::Signature* signature) const;

// For "mulmat->add" type of kernels, convert the scales of all the ports to
|
||||
// multipliers.
|
||||
  static LogicalResult DecomposeMultiplyAccumulateScale(
      Operation* op, quant::QuantizedMultipliers* input_multipliers,
      quant::QuantizedMultipliers* output_multipliers,
      quant::QuantizedRanges* output_ranges);

  // A set of parameters is required to build the signatures.
  FloatType f32_;
  IntegerType i8_;

@ -33,7 +33,6 @@ limitations under the License.
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/Functional.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_info.pb.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h"
@ -55,7 +54,8 @@ namespace quant {
using QuantParamsEntry = QuantizationInfo::QuantParams;

namespace {
class ImportQuantStatsPass : public FunctionPass<ImportQuantStatsPass> {
class ImportQuantStatsPass
    : public PassWrapper<ImportQuantStatsPass, FunctionPass> {
 public:
  explicit ImportQuantStatsPass(OperationToName op_to_name)
      : op_to_name_(op_to_name) {}
@ -193,7 +193,7 @@ void ImportQuantStatsPass::runOnFunction() {
}

// Creates an instance of the default quant parameters pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
std::unique_ptr<OperationPass<FuncOp>> CreateImportQuantStatsPass(
    OperationToName op_to_name, const std::string &stats_str) {
  auto pass = absl::make_unique<ImportQuantStatsPass>(op_to_name);
  if (pass->ParseQuantStats(stats_str)) return nullptr;
@ -203,7 +203,7 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
// Creates an instance pass to import quantization stats to the operations in
// the function. A custom method to get the name from the op is used because
// different dialect ops might have different ways to assign the name.
std::unique_ptr<OpPassBase<FuncOp>>
std::unique_ptr<OperationPass<FuncOp>>
CreateImportQuantStatsPassForTFControlDialect(const std::string &stats_str) {
  auto get_name_func = [](Operation *op) {
    Location loc = op->getLoc();

@ -0,0 +1,82 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/quantization/numerical_utils.h"

#include <assert.h>

#include <algorithm>
#include <cmath>
#include <limits>

#include "absl/types/optional.h"

namespace mlir {
namespace quant {

// This method is adapted from TFLite:
// ["tensorflow/lite/kernels/internal/quantization_util.cc"]
QuantizedMultiplier QuantizeMultiplier(double double_multiplier) {
  if (double_multiplier < 1e-6) {
    return {0, 0};
  }

  int32_t shift;
  const double q = frexp(double_multiplier, &shift);
  auto q_fixed = static_cast<int64_t>(round(q * (1ll << 31)));
  assert(q_fixed <= (1ll << 31));
  if (q_fixed == (1ll << 31)) {
    q_fixed /= 2;
    ++shift;
  }
  assert(q_fixed <= std::numeric_limits<int32_t>::max());
  // A shift amount smaller than -31 would cause all bits to be shifted out
  // and thus all results would be zero. We implement that instead with
  // q_fixed==0, so as to avoid hitting issues with right-shift
  // operations with shift amounts greater than 31. Note that this happens
  // roughly when abs(double_multiplier) < 2^-31 and the present handling means
  // that we're effectively flushing tiny double_multiplier's to zero.
  // We could conceivably handle values in the range (roughly) [32, 63]
  // as 'denormals' i.e. (shift==0, q_fixed < 2^30). In that point of view
  // the present handling is just doing 'flush denormals to zero'. We could
  // reconsider and actually generate nonzero denormals if a need arises.
  if (shift < -31) {
    shift = 0;
    q_fixed = 0;
  }
  return {static_cast<int32_t>(q_fixed), shift};
}

QuantizedRange CalculateQuantizedRange(double scale, int32_t zero_point,
                                       absl::optional<double> rmin,
                                       absl::optional<double> rmax,
                                       int32_t qmin, int32_t qmax) {
  auto quantize = [scale, zero_point](float f) {
    return zero_point + static_cast<int32_t>(std::round(f / scale));
  };

  if (rmin.has_value() && rmax.has_value()) {
    return {std::max(qmin, quantize(rmin.value())),
            std::min(qmax, quantize(rmax.value()))};
  } else if (rmin.has_value()) {
    return {std::max(qmin, quantize(rmin.value())), qmax};
  } else if (rmax.has_value()) {
    return {qmin, std::min(qmax, quantize(rmax.value()))};
  } else {
    return {qmin, qmax};
  }
}
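A quick worked example under assumed values: with scale 0.1, zero point 0, and storage limits qmin = -128, qmax = 127, a real range of [-1.0, 1.0] maps as follows (illustrative numbers only):

// quantize(-1.0) = 0 + round(-1.0 / 0.1) = -10
// quantize( 1.0) = 0 + round( 1.0 / 0.1) =  10
// CalculateQuantizedRange(0.1, 0, -1.0, 1.0, -128, 127)
//   == {std::max(-128, -10), std::min(127, 10)} == {-10, 10}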

}  // namespace quant
}  // namespace mlir
45 tensorflow/compiler/mlir/lite/quantization/numerical_utils.h Normal file
@ -0,0 +1,45 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_NUMERICAL_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_NUMERICAL_UTILS_H_

#include <cstdint>
#include <utility>

#include "absl/types/optional.h"

namespace mlir {
namespace quant {

using QuantizedMultiplier = std::pair<int32_t, int32_t>;
using QuantizedRange = std::pair<int32_t, int32_t>;

// Decomposes a double-precision multiplier into an integer multiplier and an
// exponent: double_multiplier = int_multiplier * 2^(-31 + exponent), where
// int_multiplier is in the range [2^30, 2^31).
QuantizedMultiplier QuantizeMultiplier(double double_multiplier);
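The unit test later in this commit checks this contract by recomposing the pair. A minimal sketch of that round trip (mirroring the test's ComposeScale helper; names here are illustrative):

#include <cmath>
#include <cstdint>
#include <utility>

// Recomposes a decomposed multiplier: int_multiplier * 2^(-31 + exponent).
double Recompose(std::pair<int32_t, int32_t> m) {
  return m.first * std::exp2(-31 + m.second);
}
// E.g. QuantizeMultiplier(0.5) yields {1 << 30, 0}, and
// Recompose({1 << 30, 0}) == 0.5.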
// Calculates the effective quantized value range for the given scale and zero
// point. The range is the intersection of [rmin, rmax] (after quantization)
// and [qmin, qmax].
QuantizedRange CalculateQuantizedRange(double scale, int32_t zero_point,
                                       absl::optional<double> rmin,
                                       absl::optional<double> rmax,
                                       int32_t qmin, int32_t qmax);

}  // namespace quant
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_NUMERICAL_UTILS_H_
@ -0,0 +1,114 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/quantization/numerical_utils.h"

#include <cmath>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/types/optional.h"

namespace mlir {
namespace quant {

namespace {

double ComposeScale(const QuantizedMultiplier& input) {
  return input.first * exp2(-31 + input.second);
}

TEST(NumericalUtils, QuantizeMultiplier) {
  // Decompose multiplier larger than 1.
  ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e6)), 1.0e6);
  ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e3)), 1.0e3);
  ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(10.)), 10.);
  ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(5.)), 5.);
  ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(2.)), 2.);

  // Decompose multiplier between 1.0 and 1e-6.
  ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(0.0)), 0.0);
  ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0)), 1.0);
  ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-1)), 1.0e-1);
  ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-2)), 1.0e-2);
  ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-3)), 1.0e-3);
  ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-4)), 1.0e-4);
  ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-5)), 1.0e-5);
  ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-6)), 1.0e-6);

  // When scale is smaller than 1.0e-6, it is decomposed to {0, 0}.
  ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-7)), 0.0);
  ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-8)), 0.0);
}

TEST(NumericalUtils, ActivationRange) {
  // zero point = 0
  auto a =
      CalculateQuantizedRange(1e-6, 0, absl::nullopt, absl::nullopt, -128, 127);
  ASSERT_EQ(a.first, -128);
  ASSERT_EQ(a.second, 127);

  auto b = CalculateQuantizedRange(1e-6, 0, 0.0, absl::nullopt, -128, 127);
  ASSERT_EQ(b.first, 0);
  ASSERT_EQ(b.second, 127);

  auto c = CalculateQuantizedRange(1e-6, 0, -1.0, 1.0, -128, 127);
  ASSERT_EQ(c.first, -128);
  ASSERT_EQ(c.second, 127);

  auto d = CalculateQuantizedRange(1e-6, 0, 0.0, 6.0, -128, 127);
  ASSERT_EQ(d.first, 0);
  ASSERT_EQ(d.second, 127);

  // zero point = 100
  auto e = CalculateQuantizedRange(1e-6, 100, absl::nullopt, absl::nullopt,
                                   -128, 127);
  ASSERT_EQ(e.first, -128);
  ASSERT_EQ(e.second, 127);

  auto f = CalculateQuantizedRange(1e-6, 100, 0.0, absl::nullopt, -128, 127);
  ASSERT_EQ(f.first, 100);
  ASSERT_EQ(f.second, 127);

  auto g = CalculateQuantizedRange(1e-6, 100, -1.0, 1.0, -128, 127);
  ASSERT_EQ(g.first, -128);
  ASSERT_EQ(g.second, 127);

  auto h = CalculateQuantizedRange(1e-6, 100, 0.0, 6.0, -128, 127);
  ASSERT_EQ(h.first, 100);
  ASSERT_EQ(h.second, 127);

  // zero point = -100
  auto i = CalculateQuantizedRange(1e-6, -100, absl::nullopt, absl::nullopt,
                                   -128, 127);
  ASSERT_EQ(i.first, -128);
  ASSERT_EQ(i.second, 127);

  auto j = CalculateQuantizedRange(1e-6, -100, 0.0, absl::nullopt, -128, 127);
  ASSERT_EQ(j.first, -100);
  ASSERT_EQ(j.second, 127);

  auto k = CalculateQuantizedRange(1e-6, -100, -1.0, 1.0, -128, 127);
  ASSERT_EQ(k.first, -128);
  ASSERT_EQ(k.second, 127);

  auto l = CalculateQuantizedRange(1e-6, -100, 0.0, 6.0, -128, 127);
  ASSERT_EQ(l.first, -100);
  ASSERT_EQ(l.second, 127);
}

}  // namespace
}  // namespace quant
}  // namespace mlir
@ -67,7 +67,7 @@ std::vector<quant::QuantizeRegionOp> QuantizeContext::GetAllOps() {
LogicalResult QuantizeContext::Handle(
    quant::QuantizeRegionOp op, llvm::SmallVectorImpl<Operation *> *new_items,
    bool *changed) {
  auto spec = target_spec_.Get(op);
  auto spec = target_spec_.GetKernelSpec(op);
  if (!spec.hasValue()) {
    op.emitWarning(
"Couldn't find kernel from the registeration for quantization.");
|
||||

@ -27,13 +27,13 @@ using OperationToName = std::function<llvm::StringRef(Operation* op)>;
// Creates an instance pass to import quantization stats to the operations in
// the function. A custom method to get the name from the op is used because
// different dialect ops might have different ways to assign the name.
std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
std::unique_ptr<OperationPass<FuncOp>> CreateImportQuantStatsPass(
    OperationToName op_to_name, const std::string& stats_str);

// Creates an instance pass to import quantization stats to the operations in
// the function. A custom method to get the name from the op is used because
// different dialect ops might have different ways to assign the name.
std::unique_ptr<OpPassBase<FuncOp>>
std::unique_ptr<OperationPass<FuncOp>>
CreateImportQuantStatsPassForTFControlDialect(const std::string& stats_str);

}  // namespace quant

@ -79,7 +79,7 @@ TypeAttr RescaleQuantizedType(Type input, Attribute factor) {
  SmallVector<double, 4> new_scales;
  new_scales.reserve(scales.size());
  auto scales_iter = scales.begin();
  for (auto f : factor_values) {
  for (const auto& f : factor_values) {
    new_scales.push_back(*(scales_iter++) *
                         std::fabs(FloatAttr::getValueAsDouble(f)));
  }

@ -25,7 +25,7 @@ namespace mlir {
namespace TF {

// Legalize the tf ops to the quant ops, so the quantization passes can work.
std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFToQuantPass();
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFToQuantPass();

}  // namespace TF
}  // namespace mlir

@ -27,7 +27,7 @@ namespace TF {
namespace {

// Legalize TF quantization emulation ops to those in the Quant ops dialect.
struct LegalizeTFToQuant : public FunctionPass<LegalizeTFToQuant> {
struct LegalizeTFToQuant : public PassWrapper<LegalizeTFToQuant, FunctionPass> {
  explicit LegalizeTFToQuant() = default;
  LegalizeTFToQuant(const LegalizeTFToQuant &) {}

@ -146,12 +146,12 @@ void LegalizeTFToQuant::runOnFunction() {
  auto func = getFunction();
  auto *ctx = func.getContext();
  patterns.insert<PreparePerTensorFakeQuant, PreparePerChannelFakeQuant>(ctx);
  applyPatternsGreedily(func, patterns);
  applyPatternsAndFoldGreedily(func, patterns);
}
}  // namespace

// Creates an instance of the TensorFlow dialect to QuantOps dialect pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFToQuantPass() {
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFToQuantPass() {
  return std::make_unique<LegalizeTFToQuant>();
}


@ -1,112 +0,0 @@
load(
    "//third_party/mlir:tblgen.bzl",
    "gentbl",
)

package(
    default_visibility = [
        ":friends",
    ],
    licenses = ["notice"],  # Apache 2.0
)

package_group(
    name = "friends",
    includes = ["//third_party/mlir:subpackages"],
    packages = [
        "//tensorflow/compiler/aot/...",
        "//tensorflow/compiler/mlir/...",
        "//tensorflow/compiler/mlir/lite/...",
    ],
)

cc_library(
    name = "hlo_xla_quantization_passes",
    srcs = [
        "cpu_kernel_fusion.cc",
        "generated_cpu_kernel_fusion.inc",
        "materialize.cc",
        "op_quant_spec.inc",
        "propagate.cc",
    ],
    hdrs = [
        "passes.h",
    ],
    deps = [
        ":cpu_device_target",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_config",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_context",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
        "//tensorflow/compiler/mlir/xla:hlo",
        "//tensorflow/compiler/xla/client/lib:quantize",
        "@com_google_absl//absl/memory",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TransformUtils",
    ],
    alwayslink = 1,
)

cc_library(
    name = "cpu_device_target",
    srcs = [
        "cpu_device_target.cc",
    ],
    hdrs = [
        "cpu_device_target.h",
    ],
    deps = [
        "//tensorflow/compiler/mlir/lite/quantization:device_target",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_context",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:Support",
    ],
)

cc_library(
    name = "quantize",
    srcs = [
        "quantize.cc",
    ],
    hdrs = [
        "quantize.h",
    ],
    deps = [
        "//tensorflow/compiler/mlir/xla:hlo",
        "//tensorflow/compiler/mlir/xla:hlo_to_mlir_hlo",
        "//tensorflow/compiler/tf2xla",
        "//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
        "//tensorflow/compiler/tf2xla:tf2xla_util",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
        "//tensorflow/compiler/xla/client:xla_computation",
        "//tensorflow/core/platform:status",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Transforms",
    ],
)

gentbl(
    name = "cpu_kernel_fusion_inc_gen",
    tbl_outs = [
        (
            "-gen-rewriters",
            "generated_cpu_kernel_fusion.inc",
        ),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "cpu_kernel_fusion.td",
    td_srcs = [
        "@llvm-project//mlir:StdOpsTdFiles",
        "//tensorflow/compiler/mlir/xla:hlo_ops_td_files",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
    ],
)
@ -1,67 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.h"

#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_context.h"

namespace mlir {
namespace xla_hlo {

namespace ph = std::placeholders;

CpuDeviceTarget::CpuDeviceTarget(MLIRContext* ctx) : DeviceTarget(ctx) {
  RegisterKernel("generic.concat", {qi8_, qi8_, qi8_},
                 quant::ScaleConstraintType::OutputInputSameScale);

  // TODO(fengliuai): All the combinations need to be listed. We need to
  // improve this.
RegisterKernel("generic.reshape", {qi8_, any_},
|
||||
quant::ScaleConstraintType::OutputInputSameScale);
|
||||
RegisterKernel("generic.reshape", {any_, qi8_},
|
||||
quant::ScaleConstraintType::OutputInputSameScale);
|
||||
|
||||
RegisterKernel("generic.mul", {qi8_, qi8_, qi8_},
|
||||
quant::ScaleConstraintType::OutputInputFreeScale);
|
||||
RegisterKernel("generic.mul_add", {qi8_, qi8n_, any_, qi8_},
|
||||
std::bind(&CpuDeviceTarget::HandleMultiplyAccumulateScale,
|
||||
this, ph::_1, ph::_2, ph::_3, ph::_4));
|
||||
RegisterKernel("generic.matmul_add", {qi8_, qi8n_, any_, qi8_},
|
||||
std::bind(&CpuDeviceTarget::HandleMultiplyAccumulateScale,
|
||||
this, ph::_1, ph::_2, ph::_3, ph::_4));
|
||||
}
|
||||
|
||||
LogicalResult CpuDeviceTarget::HandleMultiplyAccumulateScale(
|
||||
quant::QuantizeContext* ctx, Operation* op,
|
||||
quant::AdjacentOperations* new_items, bool* changed) {
|
||||
auto bias_params = ctx->GetOperandParams(op, 2);
|
||||
if (!EmptyParams(bias_params)) {
|
||||
return success();
|
||||
}
|
||||
std::vector<quant::QuantParams> op_types{ctx->GetOperandParams(op, 0),
|
||||
ctx->GetOperandParams(op, 1)};
|
||||
auto bias_scale = GetUniformQuantizedTypeForBias(op_types);
|
||||
if (bias_scale && ctx->SetOperandParams(op, 2, bias_scale)) {
|
||||
*changed = true;
|
||||
new_items->push_back(op->getOperand(2).getDefiningOp());
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace xla_hlo
|
||||
} // namespace mlir
|
||||
@ -1,40 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_CPU_DEVICE_TARGET_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_CPU_DEVICE_TARGET_H_

#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/device_target.h"

namespace mlir {
namespace xla_hlo {

// Target specs for CPU kernels.
class CpuDeviceTarget : public quant::DeviceTarget {
 public:
  explicit CpuDeviceTarget(MLIRContext* ctx);

 private:
  LogicalResult HandleMultiplyAccumulateScale(
      quant::QuantizeContext* ctx, Operation* op,
      quant::AdjacentOperations* new_items, bool* changed);
};

}  // namespace xla_hlo
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_CPU_DEVICE_TARGET_H_

@ -1,252 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <initializer_list>
#include <iterator>
#include <numeric>
#include <string>

#include "absl/memory/memory.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/xla/client/lib/quantize.h"

#define DEBUG_TYPE "quant-kernel-fusion"

constexpr int kFakeQuantOperandsNum = 5;
constexpr int kFakeQuantPerChannelOperandsNum = 6;
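
// Editor's note: the operand layout assumed for the fake-quant custom call
// is (input, min, max, num_bits, narrow_range[, quant_dim]), matching the
// matchPattern calls in GetQuantSpec below.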

namespace mlir {
namespace xla_hlo {

namespace {

TypeAttr GetQuantSpec(Operation* op) {
  auto fake_quant = llvm::dyn_cast_or_null<CustomCallOp>(op);
  if (!fake_quant || fake_quant.getNumOperands() < kFakeQuantOperandsNum ||
      fake_quant.getNumOperands() > kFakeQuantPerChannelOperandsNum ||
      fake_quant.call_target_name() != "fake_quant_with_min_max_vars")
    return {};

  DenseFPElementsAttr min, max;
  DenseIntElementsAttr bit_width, narrow_range, quant_dim;
  if (!matchPattern(fake_quant.getOperand(1), m_Constant(&min)) ||
      !matchPattern(fake_quant.getOperand(2), m_Constant(&max)) ||
      !matchPattern(fake_quant.getOperand(3), m_Constant(&bit_width)) ||
      !matchPattern(fake_quant.getOperand(4), m_Constant(&narrow_range)))
    return {};

  auto bit_width_val = (*bit_width.attr_value_begin()).cast<IntegerAttr>();
  auto narrow_range_val = (*narrow_range.int_value_begin()).getSExtValue();
  int quant_dim_val = -1;
  if (fake_quant.getNumOperands() == kFakeQuantPerChannelOperandsNum &&
      matchPattern(fake_quant.getOperand(kFakeQuantPerChannelOperandsNum - 1),
                   m_Constant(&quant_dim))) {
    quant_dim_val = (*quant_dim.int_value_begin()).getSExtValue();
  }

  OpBuilder builder(op);
  Type input_type =
      fake_quant.getOperand(0).getType().cast<ShapedType>().getElementType();
  return quant::GetQuantizedTypeAttr(
      builder, input_type, min, max, quant_dim_val, bit_width_val,
      builder.getBoolAttr(narrow_range_val), /*is_signed=*/true);
}

// Collects input values that come from outside of 'ops'.
void CollectInputs(llvm::ArrayRef<Operation*> ops,
                   llvm::SmallVectorImpl<Value>* inputs,
                   llvm::SmallVectorImpl<Attribute>* input_specs) {
  for (Operation* op : ops) {
    for (Value operand : op->getOperands()) {
      if (std::find(inputs->begin(), inputs->end(), operand) != inputs->end()) {
        continue;
      }
      if (Operation* def_op = operand.getDefiningOp()) {
        if (std::find(ops.begin(), ops.end(), def_op) == ops.end()) {
          inputs->push_back(operand);
        }
      } else {  // argument value
        inputs->push_back(operand);
      }
    }
  }

  for (Value input : *inputs) {
    ShapedType input_type = input.getType().cast<ShapedType>();
    if (TypeAttr spec = GetQuantSpec(input.getDefiningOp())) {
      input_specs->push_back(spec);
    } else {
      input_specs->push_back(TypeAttr::get(input_type.getElementType()));
    }
  }
}

// Collects values that are produced by 'ops' and have uses outside of 'ops'.
// TODO(fengliuai): if it is a single user and QDQ, write that to the specs.
void CollectRets(llvm::ArrayRef<Operation*> ops,
                 llvm::SmallVectorImpl<Value>* rets,
                 llvm::SmallVectorImpl<Type>* ret_types,
                 llvm::SmallVectorImpl<Attribute>* ret_specs) {
  for (Operation* op : ops) {
    for (Value result : op->getResults()) {
      for (Operation* user : result.getUsers()) {
        // If there is any user outside of 'ops'.
        if (std::find(ops.begin(), ops.end(), user) == ops.end()) {
          ShapedType ret_type = result.getType().cast<ShapedType>();
          rets->push_back(result);
          ret_types->push_back(ret_type);
          if (TypeAttr spec = GetQuantSpec(user)) {
            ret_specs->push_back(spec);
          } else {
            ret_specs->push_back(TypeAttr::get(ret_type.getElementType()));
          }
          break;
        }
      }
    }
  }
}

llvm::SmallVector<Value, 0> fuseOps(PatternRewriter* rewriter,
                                    const std::initializer_list<Value>& results,
                                    StringRef kernel) {
  // Collect all the operations to be fused.
  llvm::SmallVector<Operation*, 4> fused;
  llvm::SmallVector<Location, 4> locs;
  fused.reserve(results.size());
  locs.reserve(results.size());
  for (auto value : results) {
    Operation* op = value.getDefiningOp();
    fused.push_back(op);
    locs.push_back(op->getLoc());
  }

  // Collect inputs from outside to 'ops'.
  llvm::SmallVector<Value, 4> inputs;
  llvm::SmallVector<Attribute, 4> input_specs;
  CollectInputs(fused, &inputs, &input_specs);

  // Collect outputs from 'ops' to outside.
  llvm::SmallVector<Value, 4> rets;
  llvm::SmallVector<Type, 4> ret_types;
  llvm::SmallVector<Attribute, 4> ret_specs;
  CollectRets(fused, &rets, &ret_types, &ret_specs);

  // Create the region op with the return.
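  // Editor's note: the created op has the textual form shown in the
  // cpu_kernel_fusion.mlir test further down:
  //   %region = "quant.region"(%in0, ...) ( {
  //     ^bb0(...): cloned ops ... "quant.return"(...)
  //   }) {input_specs = [...], logical_kernel = "<kernel>",
  //       output_specs = [...]}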
  auto region = rewriter->create<quant::QuantizeRegionOp>(
      rewriter->getFusedLoc(locs), ret_types, inputs,
      rewriter->getArrayAttr(input_specs), rewriter->getArrayAttr(ret_specs),
      kernel);
  auto* body = new Block();
  region.body().push_back(body);

  OpBuilder builder = OpBuilder::atBlockEnd(body);
  BlockAndValueMapping mapping;

  // Make block arguments and add them to the block value mapping.
  for (Value input : inputs) {
    mapping.map(input, body->addArgument(input.getType()));
  }

  // Clone the operations in 'ops' into the region.
  for (Operation* op : fused) {
    builder.clone(*op, mapping);
  }

  llvm::SmallVector<Value, 4> new_rets;
  new_rets.reserve(rets.size());
  for (auto ret : llvm::enumerate(rets)) {
    Value new_ret = mapping.lookupOrNull(ret.value());
    assert(new_ret && "couldn't find return value.");
    new_rets.push_back(new_ret);
    ret.value().replaceAllUsesWith(region.getResult(ret.index()));
  }
  builder.create<quant::ReturnOp>(builder.getUnknownLoc(), new_rets);

  LLVM_DEBUG({
    assert(region.verify().Success && "failed to create quant region.");
    llvm::dbgs() << "\ncreated region: ";
    region.print(llvm::dbgs());
    llvm::dbgs() << "\n\n\n";
  });

  SmallVector<Value, 0> new_values(fused.back()->getNumResults());
  return new_values;
}

struct CpuKernelFusionPass : public FunctionPass<CpuKernelFusionPass> {
  explicit CpuKernelFusionPass() = default;
  CpuKernelFusionPass(const CpuKernelFusionPass&) {}

  void runOnFunction() override;

 private:
  LogicalResult fuseCpuKernels(Operation* op);
};

#include "tensorflow/compiler/mlir/lite/quantization/xla/generated_cpu_kernel_fusion.inc"

LogicalResult CpuKernelFusionPass::fuseCpuKernels(Operation* op) {
  MLIRContext* ctx = op->getContext();
  OwningRewritePatternList patterns;
  populateWithGenerated(ctx, &patterns);

  ConversionTarget target(*ctx);
  target.addLegalDialect<quant::QuantizationDialect>();
  target.addLegalOp<CallOp, ModuleOp, FuncOp, ModuleTerminatorOp,
                    ::mlir::ReturnOp>();
  return applyPartialConversion(op, target, patterns);
}

void CpuKernelFusionPass::runOnFunction() {
  if (failed(fuseCpuKernels(getFunction()))) signalPassFailure();
}

}  // namespace

// Creates an instance of the xla_hlo cpu kernel fusion pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateCpuKernelFusionPass() {
  return std::make_unique<CpuKernelFusionPass>();
}

static PassRegistration<CpuKernelFusionPass> pass(
    "xla-hlo-cpu-fusion", "Fuse xla hlo ops into cpu kernels");

}  // namespace xla_hlo
}  // namespace mlir

@ -1,174 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This transformation pass quantizes constants and rewrites the quantization
// ops into xla_hlo primitive ops.
#include <cstdint>
#include <iterator>
#include <numeric>
#include <string>

#include "absl/memory/memory.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/xla/client/lib/quantize.h"

//===----------------------------------------------------------------------===//
// The pass to materialize the quantization results by xla primitive ops.
//
namespace mlir {
namespace xla_hlo {

namespace {

// This pattern matches the "constant->qcast->dcast" pattern and replaces it
// with "quantized constant->xla_hlo.dequantize". If it only matches the
// "non-constant->qcast->dcast" pattern, it removes both the qcast and the
// dcast. We match the pattern as a whole to bypass the type checks of the
// normal xla_hlo ops.
// TODO(fengliuai): make this pass work for bf16 input.
class RewriteDequantize : public OpRewritePattern<quant::DequantizeCastOp> {
 public:
  explicit RewriteDequantize(int64_t size, MLIRContext *context)
      : OpRewritePattern<quant::DequantizeCastOp>(context), size_(size) {}

  LogicalResult matchAndRewrite(quant::DequantizeCastOp op,
                                PatternRewriter &rewriter) const override {
    // quant.dcast
    // xla_hlo dequantize only takes min/max, so let's recover them from
    // the quantization parameters.
    Value dcast = op.arg();
    auto type = quant::QuantizedType::getQuantizedElementType(dcast.getType());
    if (!type || !type.isa<quant::UniformQuantizedType>()) {
      return failure();
    }
    auto qtype = type.cast<quant::UniformQuantizedType>();
    double scale = qtype.getScale();
    int64_t zero_point = qtype.getZeroPoint();
    float min = scale * (qtype.getStorageTypeMin() - zero_point);
    float max = scale * (qtype.getStorageTypeMax() - zero_point);
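    // Editor's note: e.g. for !quant.uniform<u8:f32, 0.0078431372549:128>
    // this gives min = 0.0078431 * (0 - 128) ≈ -1.00392 and
    // max = 0.0078431 * (255 - 128) ≈ 0.99608, the exact values checked in
    // the materialize test below.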

    // quant.qcast
    auto qcast =
        llvm::dyn_cast_or_null<quant::QuantizeCastOp>(dcast.getDefiningOp());
    if (!qcast) return failure();

    // constant
    DenseFPElementsAttr attr;
    // If it isn't a floating-point constant or the size is too small, let's
    // remove the quantization. Also the last dimension size should be a
    // multiple of 4, so the shape isn't broken during packing and unpacking.
    if (!matchPattern(qcast.arg(), m_Constant(&attr)) ||
        attr.getNumElements() <= size_ ||
        attr.getType().getDimSize(attr.getType().getRank() - 1) % 4 != 0) {
      op.getResult().replaceAllUsesWith(qcast.arg());
      return success();
    }
    // TODO(fengliuai): implement transpose if it has high dimension.

    // Create the quantized result.
    auto quantized_result =
        quant::Quantize(attr, qtype).dyn_cast_or_null<DenseIntElementsAttr>();
    if (!quantized_result) {
      return failure();
    }

    // Pack the uint8 bits to uint32. The shape is changed from
    // [n0, n1, ..., nk] to [n0, n1, ..., nk / 4].
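    // Editor's note: four uint8 lanes per uint32; e.g. the 2x4 weight in the
    // materialize test below packs into tensor<2x1xi32>
    // (dense<[[21004416], [-1056997248]]>). Per the comment below, padding
    // may grow the data before packing.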
    std::vector<uint8_t> raw_data;
    for (auto d : quantized_result.getValues<uint8_t>()) {
      raw_data.push_back(d);
    }
    // The packing might increase the data size by padding.
    auto packed_data = xla::PackToUint32<uint8_t>(raw_data);
    auto packed_shape = attr.getType().getShape().vec();
    int lower_dims = std::accumulate(
        packed_shape.begin(),
        std::next(packed_shape.begin(), packed_shape.size() - 1), 1,
        std::multiplies<int>());
    packed_shape[packed_shape.size() - 1] = packed_data.size() / lower_dims;
    auto packed_type =
        RankedTensorType::get(packed_shape, rewriter.getIntegerType(32));

    auto packed_quantized_result =
        DenseElementsAttr::get<uint32_t>(packed_type, packed_data);
    auto quantized_constant =
        rewriter.create<ConstantOp>(qcast.getLoc(), packed_quantized_result);

    // Create the xla dequantize op with bf16 output.
    auto dequantized_type = RankedTensorType::get(attr.getType().getShape(),
                                                  rewriter.getBF16Type());
    auto dequantize = rewriter.create<DequantizeOp>(
        qcast.getLoc(), dequantized_type, quantized_constant,
        rewriter.getF32FloatAttr(min), rewriter.getF32FloatAttr(max),
        rewriter.getStringAttr("MIN_COMBINED"), rewriter.getBoolAttr(false),
        rewriter.getBoolAttr(false));

    // Convert the bf16 output back to f32.
    rewriter.replaceOpWithNewOp<ConvertOp>(op, op.getResult().getType(),
                                           dequantize);
    return success();
  }

 private:
  int64_t size_;
};

// Materialize the quantization results by hlo primitive ops.
struct MaterializeToXlaPass : public FunctionPass<MaterializeToXlaPass> {
  explicit MaterializeToXlaPass() = default;
  MaterializeToXlaPass(const MaterializeToXlaPass &) {}

  void runOnFunction() override;
};

void MaterializeToXlaPass::runOnFunction() {
  FuncOp func = getFunction();
  MLIRContext *ctx = &getContext();

  OwningRewritePatternList patterns;
  // TODO(fengliuai): make the size 6 configurable.
  patterns.insert<RewriteDequantize>(6, ctx);

  applyPatternsGreedily(func, patterns);
}

}  // namespace

// Creates an instance of the xla_hlo dialect quantization materialization
// pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateMaterializeToXlaPass() {
  return std::make_unique<MaterializeToXlaPass>();
}

static PassRegistration<MaterializeToXlaPass> pass(
    "xla-hlo-materialize-quant",
    "Materialize the quantization results by xla primitive ops");

}  // namespace xla_hlo
}  // namespace mlir

@ -1,7 +0,0 @@
// TODO(fengliuai): automatically generate this file
// TODO(fengliuai): add all the xla_hlo ops

static std::unique_ptr<quant::OpQuantSpec> GetOpQuantSpec(mlir::Operation *op) {
  auto spec = absl::make_unique<quant::OpQuantSpec>();
  return spec;
}

@ -1,107 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This transformation pass applies quantization propagation on the xla_hlo
// dialect.
#include <iterator>
#include <string>

#include "absl/memory/memory.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_context.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.h"

// NOLINTNEXTLINE
static llvm::cl::opt<bool> disable_per_channel(
    "xla-disable-per-channel", llvm::cl::value_desc("bool"),
    llvm::cl::desc("Whether to disable per-channel quantized weights."),
    llvm::cl::init(false));

//===----------------------------------------------------------------------===//
// The quantization propagation Pass.
//
namespace mlir {
namespace xla_hlo {

namespace {

// Applies quantization propagation on the input function. During propagation,
// two constraints are respected:
// - the quantization types (params) already assigned to ops in the function
// - the quantization specs for the ops
// The propagation result should assign quantization types to all the tensors
// while respecting both constraints.
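// Editor's note: concretely this is a work-list fixed point: every
// quant.region op is visited, its kernel's scale-constraint handler may
// rewrite operand/result specs, regions adjacent to a change are re-queued,
// and Finalize() materializes the inferred types once the list drains (see
// runOnFunction below).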
struct PropagateQuantPass : public FunctionPass<PropagateQuantPass> {
  explicit PropagateQuantPass() = default;
  PropagateQuantPass(const PropagateQuantPass &) {}

  void runOnFunction() override;
};

#include "tensorflow/compiler/mlir/lite/quantization/xla/op_quant_spec.inc"

void PropagateQuantPass::runOnFunction() {
  FuncOp func = getFunction();
  // TODO(fengliuai): deprecate this old code generation path.
  // XLA only supports uint8/uint16 quantization for now.
  ApplyQuantizationParamsPropagation(func, /*is_signed=*/false,
                                     disable_per_channel, GetOpQuantSpec);

  CpuDeviceTarget spec(&getContext());
  quant::QuantizeContext ctx(func, spec);

  std::vector<quant::QuantizeRegionOp> work_list = ctx.GetAllOps();
  bool changed = false;
  while (!work_list.empty()) {
    quant::QuantizeRegionOp op = work_list.back();
    work_list.pop_back();

    llvm::SmallVector<Operation *, 4> new_items;
    if (failed(ctx.Handle(op, &new_items, &changed))) {
      // The IR is still valid, thus we shouldn't fail.
      signalPassFailure();
    }
    for (auto item : new_items) {
      if (auto reg = llvm::dyn_cast_or_null<quant::QuantizeRegionOp>(item))
        work_list.push_back(reg);
    }
  }

  if (!changed) return;

  if (failed(ctx.Finalize())) {
    signalPassFailure();
  }
}

}  // namespace

// Creates an instance of the xla_hlo dialect quantization propagation pass.
std::unique_ptr<OpPassBase<FuncOp>> CreatePropagateQuantPass() {
  return std::make_unique<PropagateQuantPass>();
}

static PassRegistration<PropagateQuantPass> pass(
    "xla-hlo-propagate-quant", "Propagate quantization information");

}  // namespace xla_hlo
}  // namespace mlir

@ -1,74 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"

#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Function.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Module.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"

namespace mlir {
namespace xla_hlo {

static void RegisterDialects() {
  static bool init_once = []() {
    mlir::registerDialect<mlir::xla_hlo::XlaHloDialect>();
    mlir::registerDialect<mlir::StandardOpsDialect>();
    return true;
  }();
  (void)init_once;
}

// Quantizes the model in the computation.
tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config,
                               xla::XlaComputation* computation) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> snapshot,
                      computation->Snapshot());

  RegisterDialects();
  MLIRContext context;
  OwningModuleRef module = ModuleOp::create(UnknownLoc::get(&context));
  auto status = xla::ConvertHloToMlirHlo(
      module.get(), snapshot->mutable_hlo()->mutable_hlo_module());
  if (!status.ok()) {
    LOG(ERROR) << "Hlo module import failed: " << status;
    return status;
  }

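  // Editor's note: only generic cleanup passes are staged below
  // (canonicalize, inline, symbol-DCE, per-function CSE); wiring the
  // quantization passes into this pipeline was apparently still pending,
  // per the TODO in the fadd_quant test further down.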
  PassManager pm(&context);
  pm.addPass(createCanonicalizerPass());
  pm.addPass(createInlinerPass());
  pm.addPass(createSymbolDCEPass());
  pm.addNestedPass<FuncOp>(createCSEPass());

  mlir::StatusScopedDiagnosticHandler diag_handler(&context);
  LogicalResult result = pm.run(module.get());
  (void)result;

  module->dump();

  return tensorflow::Status::OK();
}

}  // namespace xla_hlo
}  // namespace mlir

@ -1,46 +0,0 @@
// RUN: tf-opt -xla-hlo-cpu-fusion %s | FileCheck %s

// CHECK-LABEL: @mul_add_source
func @mul_add_source(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
  %0 = "xla_hlo.multiply"(%arg0, %arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  %1 = "xla_hlo.add"(%0, %arg2) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  return %1 : tensor<4xf32>

// CHECK: %[[region:.*]] = "quant.region"(%arg0, %arg1, %arg2) ( {
// CHECK: ^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors
// CHECK: %[[mul:.*]] = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
// CHECK: %[[add:.*]] = xla_hlo.add %[[mul]], %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
// CHECK: "quant.return"(%[[add]]) : (tensor<4xf32>) -> ()
// CHECK: }) {input_specs = [f32, f32, f32], logical_kernel = "generic.mul_add", output_specs = [f32]} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK: return %[[region]] : tensor<4xf32>
}

// CHECK-LABEL: @mul_add_annotated
func @mul_add_annotated(%arg0: tensor<2x4xf32>, %arg1: tensor<2x4xf32>, %arg2: tensor<2x4xf32>) -> (tensor<2x4xf32>) {
  %cst = constant dense<0.0> : tensor<f32>
  %cst_0 = constant dense<255.0> : tensor<f32>
  %cst_1 = constant dense<8> : tensor<i32>
  %cst_2 = constant dense<false> : tensor<i1>
  %qin = "xla_hlo.custom_call"(%arg0, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars",
    has_side_effect = false, name = "custom-call.1"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
  %qw = "xla_hlo.custom_call"(%arg1, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars",
    has_side_effect = false, name = "custom-call.2"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
  %0 = "xla_hlo.multiply"(%qin, %qw) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
  %1 = "xla_hlo.add"(%0, %arg2) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
  %r = "xla_hlo.custom_call"(%1, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars",
    has_side_effect = false, name = "custom-call.3"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
  return %r : tensor<2x4xf32>

// CHECK: %[[region:.*]] = "quant.region"
// CHECK: ^bb0(%arg3: tensor<2x4xf32>, %arg4: tensor<2x4xf32>, %arg5: tensor<2x4xf32>): // no predecessors
// CHECK: %[[mul:.*]] = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<2x4xf32>
// CHECK: %[[add:.*]] = xla_hlo.add %[[mul]], %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<2x4xf32>
// CHECK: "quant.return"(%[[add]]) : (tensor<2x4xf32>) -> ()
// CHECK: }) {input_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>, !quant.uniform<i8:f32, 1.000000e+00:-128>, f32],
// CHECK-SAME: logical_kernel = "generic.mul_add", output_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>]} :
// CHECK-SAME: (tensor<2x4xf32>, tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
// CHECK: %[[r:.*]] = "xla_hlo.custom_call"(%[[region]]
// CHECK: return %[[r]] : tensor<2x4xf32>
}

@ -1,15 +0,0 @@
# RUN: not tfcompile --graph=%s.pbtxt --config=%s.config.pbtxt --experimental_quantize --cpp_class="::test::fadd_quant" 2>&1 | FileCheck %s -dump-input-on-failure

# TODO(fengliuai): update this file with the progress of the implementation
// CHECK: func @main
// CHECK: %cst = constant dense<0.000000e+00> : tensor<f32>
// CHECK: %cst_0 = constant dense<1.270000e+02> : tensor<f32>
// CHECK: %cst_1 = constant dense<8> : tensor<i32>
// CHECK: %cst_2 = constant dense<false> : tensor<i1>
// CHECK: %0 = "xla_hlo.custom_call"(%arg0, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.9"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
// CHECK: %1 = "xla_hlo.custom_call"(%arg1, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.14"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
// CHECK: %2 = xla_hlo.add %0, %1 {name = "add.15"} : tensor<2x4xf32>
// CHECK: %3 = "xla_hlo.custom_call"(%2, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.20"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
// CHECK: %4 = "xla_hlo.tuple"(%3) {name = "tuple.22"} : (tensor<2x4xf32>) -> tuple<tensor<2x4xf32>>
// CHECK: return %4 : tuple<tensor<2x4xf32>>
// CHECK: }
@ -1,26 +0,0 @@
feed {
  id { node_name: "input0" }
  shape {
    dim { size: 2 }
    dim { size: 4 }
  }
}
feed {
  id { node_name: "input1" }
  shape {
    dim { size: 2 }
    dim { size: 4 }
  }
}

fetch {
  id { node_name: "Add/FakeQuantWithMinMaxVars" }
  shape {
    dim { size: 2 }
    dim { size: 4 }
  }
}

conversion_options {
  custom_fake_quant_op_calls: true
}
@ -1,218 +0,0 @@
node: {
  name: "Add/FakeQuantWithMinMaxVars"
  op: "FakeQuantWithMinMaxVars"
  input: "Add"
  input: "Add/FakeQuantWithMinMaxVars/min"
  input: "Add/FakeQuantWithMinMaxVars/max"
  attr: {
    key: "num_bits"
    value: {
      i: 8
    }
  }
  attr: {
    key: "narrow_range"
    value: {
      b: false
    }
  }
}
node: {
  name: "Add/FakeQuantWithMinMaxVars/min"
  op: "Const"
  attr: {
    key: "value"
    value: {
      tensor: {
        dtype: DT_FLOAT
        tensor_shape: {
        }
        float_val: 0.0
      }
    }
  }
  attr: {
    key: "dtype"
    value: {
      type: DT_FLOAT
    }
  }
}
node: {
  name: "Add/FakeQuantWithMinMaxVars/max"
  op: "Const"
  attr: {
    key: "value"
    value: {
      tensor: {
        dtype: DT_FLOAT
        tensor_shape: {
        }
        float_val: 127.0
      }
    }
  }
  attr: {
    key: "dtype"
    value: {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Add"
  op: "Add"
  input: "input0/FakeQuantWithMinMaxVars"
  input: "input1/FakeQuantWithMinMaxVars"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node: {
  name: "input0/FakeQuantWithMinMaxVars"
  op: "FakeQuantWithMinMaxVars"
  input: "input0"
  input: "input0/FakeQuantWithMinMaxVars/min"
  input: "input0/FakeQuantWithMinMaxVars/max"
  attr: {
    key: "num_bits"
    value: {
      i: 8
    }
  }
  attr: {
    key: "narrow_range"
    value: {
      b: false
    }
  }
}
node: {
  name: "input0/FakeQuantWithMinMaxVars/min"
  op: "Const"
  attr: {
    key: "value"
    value: {
      tensor: {
        dtype: DT_FLOAT
        tensor_shape: {
        }
        float_val: 0.0
      }
    }
  }
  attr: {
    key: "dtype"
    value: {
      type: DT_FLOAT
    }
  }
}
node: {
  name: "input0/FakeQuantWithMinMaxVars/max"
  op: "Const"
  attr: {
    key: "value"
    value: {
      tensor: {
        dtype: DT_FLOAT
        tensor_shape: {
        }
        float_val: 127.0
      }
    }
  }
  attr: {
    key: "dtype"
    value: {
      type: DT_FLOAT
    }
  }
}
node {
  name: "input0"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
}
node: {
  name: "input1/FakeQuantWithMinMaxVars"
  op: "FakeQuantWithMinMaxVars"
  input: "input1"
  input: "input1/FakeQuantWithMinMaxVars/min"
  input: "input1/FakeQuantWithMinMaxVars/max"
  attr: {
    key: "num_bits"
    value: {
      i: 8
    }
  }
  attr: {
    key: "narrow_range"
    value: {
      b: false
    }
  }
}
node: {
  name: "input1/FakeQuantWithMinMaxVars/min"
  op: "Const"
  attr: {
    key: "value"
    value: {
      tensor: {
        dtype: DT_FLOAT
        tensor_shape: {
        }
        float_val: 0.0
      }
    }
  }
  attr: {
    key: "dtype"
    value: {
      type: DT_FLOAT
    }
  }
}
node: {
  name: "input1/FakeQuantWithMinMaxVars/max"
  op: "Const"
  attr: {
    key: "value"
    value: {
      tensor: {
        dtype: DT_FLOAT
        tensor_shape: {
        }
        float_val: 127.0
      }
    }
  }
  attr: {
    key: "dtype"
    value: {
      type: DT_FLOAT
    }
  }
}
node {
  name: "input1"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
}
versions {
  producer: 27
}
@ -1,54 +0,0 @@
// RUN: tf-opt -xla-hlo-materialize-quant %s | FileCheck %s

// CHECK-LABEL: func @quantize_rewrite
func @quantize_rewrite(%arg0: tensor<2x4xf32>) -> tensor<2x4xf32> {
// CHECK: %[[qcst:.*]] = constant dense<{{\[\[}}21004416], [-1056997248]]> : tensor<2x1xi32>
// CHECK-NEXT: %[[dq:.*]] = "xla_hlo.dequantize"(%[[qcst]]) {is_16bits = false, max_range = 0.996078431 : f32, min_range = -1.00392163 : f32,
// CHECK-SAME: mode = "MIN_COMBINED", transpose_output = false} : (tensor<2x1xi32>) -> tensor<2x4xbf16>
// CHECK-NEXT: %[[cast:.*]] = "xla_hlo.convert"(%[[dq]]) : (tensor<2x4xbf16>) -> tensor<2x4xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %[[cast]] : tensor<2x4xf32>
// CHECK-NEXT: return %[[mul]] : tensor<2x4xf32>

  %w = constant dense<[[-1.0, -0.5, 0.0, 0.0], [0.5, 1.0, 0.0, 0.0]]> : tensor<2x4xf32>
  %q = "quant.qcast"(%w) : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
  %dq = "quant.dcast"(%q) : (tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x4xf32>
  %mul = xla_hlo.multiply %arg0, %dq : tensor<2x4xf32>
  return %mul: tensor<2x4xf32>
}

// CHECK-LABEL: func @quantize_small
func @quantize_small(%arg0: tensor<1x4xf32>) -> tensor<1x4xf32> {
// CHECK: %[[w:.*]] = constant dense<1.000000e+00> : tensor<1x4xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %[[w]] : tensor<1x4xf32>
// CHECK-NEXT: return %[[mul]] : tensor<1x4xf32>

  %w = constant dense<1.0> : tensor<1x4xf32>
  %q = "quant.qcast"(%w) : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
  %dq = "quant.dcast"(%q) : (tensor<1x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<1x4xf32>
  %mul = xla_hlo.multiply %arg0, %dq : tensor<1x4xf32>
  return %mul: tensor<1x4xf32>
}

// CHECK-LABEL: func @quantize_non_cst
func @quantize_non_cst(%arg0: tensor<2x4xf32>) -> tensor<2x4xf32> {
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %arg0 : tensor<2x4xf32>
// CHECK-NEXT: return %[[mul]] : tensor<2x4xf32>

  %q = "quant.qcast"(%arg0) : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
  %dq = "quant.dcast"(%q) : (tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x4xf32>
  %mul = xla_hlo.multiply %arg0, %dq : tensor<2x4xf32>
  return %mul: tensor<2x4xf32>
}

// CHECK-LABEL: func @quantize_non_4x
func @quantize_non_4x(%arg0: tensor<2x5xf32>) -> tensor<2x5xf32> {
// CHECK: %[[w:.*]] = constant dense<1.000000e+00> : tensor<2x5xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %[[w]] : tensor<2x5xf32>
// CHECK-NEXT: return %[[mul]] : tensor<2x5xf32>

  %w = constant dense<1.0> : tensor<2x5xf32>
  %q = "quant.qcast"(%w) : (tensor<2x5xf32>) -> tensor<2x5x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
  %dq = "quant.dcast"(%q) : (tensor<2x5x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x5xf32>
  %mul = xla_hlo.multiply %arg0, %dq : tensor<2x5xf32>
  return %mul: tensor<2x5xf32>
}
@ -1,69 +0,0 @@
// RUN: tf-opt -xla-hlo-propagate-quant %s | FileCheck %s --dump-input-on-failure

// -----

// CHECK-LABEL: @mul_add_source_no_params
func @mul_add_source_no_params(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
  %region = "quant.region"(%arg0, %arg1, %arg2) ( {
  ^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors
    %mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
    %add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
    "quant.return"(%add) : (tensor<4xf32>) -> ()
  }) {input_specs = [f32, f32, f32], logical_kernel = "generic.mul_add", output_specs = [f32]} :
       (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  return %region : tensor<4xf32>

// CHECK: input_specs = [f32, f32, f32]
// CHECK-SAME: output_specs = [f32]
}

// -----

// CHECK-LABEL: @mul_add_annotated_no_narrow_range
func @mul_add_annotated_no_narrow_range(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
  %region = "quant.region"(%arg0, %arg1, %arg2) ( {
  ^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors
    %mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
    %add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
    "quant.return"(%add) : (tensor<4xf32>) -> ()
  }) {input_specs = [!quant.uniform<i8:f32, 1.0:-128>, !quant.uniform<i8:f32, 1.0:-128>, f32],
      logical_kernel = "generic.mul_add", output_specs = [!quant.uniform<i8:f32, 1.0:-128>]} :
       (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  return %region : tensor<4xf32>

// CHECK: input_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>, !quant.uniform<i8:f32, 1.000000e+00:-128>, f32]
// CHECK-SAME: output_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>]
}

// -----

// CHECK-LABEL: @mul_add_annotated
func @mul_add_annotated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
  %region = "quant.region"(%arg0, %arg1, %arg2) ( {
  ^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors
    %mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
    %add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
    "quant.return"(%add) : (tensor<4xf32>) -> ()
  }) {input_specs = [!quant.uniform<i8:f32, 1.0:-128>, !quant.uniform<i8<-127:127>:f32, 1.0:-128>, f32],
      logical_kernel = "generic.mul_add", output_specs = [!quant.uniform<i8:f32, 1.0:-128>]} :
       (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  return %region : tensor<4xf32>

// CHECK: input_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>, !quant.uniform<i8<-127:127>:f32, 1.000000e+00:-128>, !quant.uniform<i32:f32, 1.000000e+00>]
// CHECK-SAME: output_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>]
}
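
// Editor's note: in @mul_add_annotated above, the bias operand enters with
// spec f32 and is inferred as !quant.uniform<i32:f32, 1.000000e+00>; with
// unit input scales this matches the s_in * s_w accumulator convention
// handled by CpuDeviceTarget::HandleMultiplyAccumulateScale in the deleted
// source above.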

// -----

// CHECK-LABEL: @same_scale_1_1
func @same_scale_1_1(%arg0: tensor<1x7x7x64xf32>) -> (tensor<1x3136xf32>) {
  %region = "quant.region"(%arg0) ( {
  ^bb0(%arg1: tensor<1x7x7x64xf32>): // no predecessors
    %r = "xla_hlo.reshape"(%arg1) : (tensor<1x7x7x64xf32>) -> (tensor<1x3136xf32>)
    "quant.return"(%r) : (tensor<1x3136xf32>) -> ()
  }) {input_specs = [!quant.uniform<i8:f32, 1.0>], logical_kernel = "generic.reshape", output_specs = [f32]} : (tensor<1x7x7x64xf32>) -> tensor<1x3136xf32>
  return %region : tensor<1x3136xf32>

// CHECK: input_specs = [!quant.uniform<i8:f32, 1.000000e+00>]
// CHECK-SAME: output_specs = [!quant.uniform<i8:f32, 1.000000e+00>]
}
@ -1,25 +0,0 @@
// RUN: tf-opt -xla-hlo-propagate-quant %s | FileCheck %s

// CHECK-LABEL: func @mul
func @mul(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[w:.*]] = constant dense<{{\[\[}}-1.000000e+00, -5.000000e-01], [5.000000e-01, 1.000000e+00]]> : tensor<2x2xf32>
// CHECK-NEXT: %[[q:.*]] = "quant.qcast"(%[[w]]) : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
// CHECK-NEXT: %[[dq:.*]] = "quant.dcast"(%[[q]]) : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x2xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %[[dq]] : tensor<2x2xf32>
// CHECK-NEXT: return %[[mul]] : tensor<2x2xf32>
  %w = constant dense<[[-1.0, -0.5], [0.5, 1.0]]> : tensor<2x2xf32>
  %mul = xla_hlo.multiply %arg0, %w : tensor<2x2xf32>
  return %mul: tensor<2x2xf32>
}

// CHECK-LABEL: func @add
func @add(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[b:.*]] = constant dense<1.000000e+00> : tensor<2xf32>
// CHECK-NEXT: %[[q:.*]] = "quant.qcast"(%[[b]]) : (tensor<2xf32>) -> tensor<2x!quant.uniform<u8:f32, 0.0039215686274509803>>
// CHECK-NEXT: %[[dq:.*]] = "quant.dcast"(%[[q]]) : (tensor<2x!quant.uniform<u8:f32, 0.0039215686274509803>>) -> tensor<2xf32>
// CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg0, %[[dq]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x2xf32>, tensor<2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: return %[[add]] : tensor<2x2xf32>
  %b = constant dense<1.0> : tensor<2xf32>
  %add = "xla_hlo.add"(%arg0, %b) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x2xf32>, tensor<2xf32>) -> tensor<2x2xf32>
  return %add: tensor<2x2xf32>
}
@ -39,7 +39,7 @@ versions {

# CHECK-LABEL: func @main
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<4xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<4xi32>) -> tensor<*xi32>
# CHECK-SAME: control_outputs = ""
# CHECK-SAME inputs = "input0,input1"
# CHECK-SAME: inputs = "input0,input1"
# CHECK-SAME: outputs = "output"
# CHECK-NEXT: %[[OP:[a-z0-9]+]] = "tf.BannaPotatoSaladWithColeslaw"(%[[ARG_0]], %[[ARG_1]]) {T = i32, device = ""} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
# CHECK-NEXT: return %[[OP]] : tensor<*xi32>

@ -12,6 +12,7 @@ glob_lit_tests(
    test_file_exts = [
        "mlir",
        "cc",
        "json",
    ],
)

@ -22,8 +23,10 @@ filegroup(
    data = [
        ":importer_test_legacy_reshape",
        ":importer_test_min_max",
        ":test_schema.fbs",
        "//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
        "//tensorflow/compiler/mlir/lite:flatbuffer_translate",
        "//tensorflow/compiler/mlir/lite:json_to_flatbuffer",
        "@llvm-project//llvm:FileCheck",
    ],
)

@ -0,0 +1,83 @@
// RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --dump-input-on-failure %s

// CHECK: %cst = constant unit
// CHECK: %[[RES0:.*]] = "tfl.conv_2d"(%arg0, %arg1, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 0 : i32, stride_w = 0 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, none) -> tensor<256x32x32x16xf32>
// CHECK: return %[[RES0]] : tensor<256x32x32x16xf32>

{
  version: 3,
  operator_codes: [
    {
      builtin_code: "CONV_2D",
    }
  ],
  subgraphs: [
    {
      tensors: [
        {
          shape: [
            256,
            32,
            32,
            3
          ],
          name: "arg0",
          quantization: {
          }
        },
        {
          shape: [
            16,
            3,
            3,
            3
          ],
          name: "arg1",
          quantization: {
          }
        },
        {
          shape: [
            0
          ],
          name: "cst"
        },
        {
          shape: [
            256,
            32,
            32,
            16
          ],
          name: "output",
          quantization: {
          }
        },
      ],
      inputs: [
        0,
        1
      ],
      outputs: [
        3
      ],
      operators: [
        {
          inputs: [
            0,
            1,
            -1
          ],
          outputs: [
            3
          ],
          builtin_options_type: "Conv2DOptions",
          builtin_options: {
          }
        }
      ],
      name: "main"
    }
  ],
  description: "MLIR Converted."
}

@ -1,15 +1,15 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// Ensure the lstm op roundtrips exactly.

func @main(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32>, %arg3: tensor<4 x f32>, %arg4: tensor<4 x f32>, %arg5: tensor<4 x f32>, %arg6: tensor<4 x f32>, %arg7: tensor<4 x f32>, %arg8: tensor<4 x f32>, %arg9: tensor<4 x f32>, %arg10: tensor<4 x f32>, %arg11: tensor<4 x f32>, %arg12: tensor<4 x f32>, %arg13: tensor<4 x f32>, %arg14: tensor<4 x f32>, %arg15: tensor<4 x f32>, %arg16: tensor<4 x f32>, %arg17: tensor<4 x f32>, %arg18: tensor<4 x f32>, %arg19: tensor<4 x f32>, %arg20: tensor<4 x f32>, %arg21: tensor<4 x f32>) -> tensor<4 x f32> {
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %24 : tensor<4xf32>
func @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4x4xf32>, %arg10: tensor<4x4xf32>, %arg11: tensor<4x4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<1x4xf32>, %arg14: tensor<1x4xf32>, %arg15: tensor<1x4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<1x4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>) -> tensor<1x4xf32> {
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %24 : tensor<1x4xf32>
// CHECK-LABEL: main
// separate lines since there is no region for this op. third_party/tensorflow/compiler/mlir/lite/ir/tfl_ops.td: 3252
// CHECK: %[[RES0:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg22, %arg23, %arg18, %arg19, %arg20, %arg21) ( {
// CHECK: }) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK: }) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
// CHECK: return %[[RES0]]
}
@ -0,0 +1,78 @@
// RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --dump-input-on-failure %s

// This test checks that, if the flatbuffer omits the last optional input
// `bias` of the tfl.conv_2d op, the flatbuffer importer automatically adds
// a `none` value for it.
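// Editor's note: the omission shows up below as the operator listing only
// inputs: [0, 1] (vs. [0, 1, -1] in the sibling test above), which is what
// makes the importer synthesize the `%cst = constant unit` checked next.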

// CHECK: %cst = constant unit
// CHECK: %[[RES0:.*]] = "tfl.conv_2d"(%arg0, %arg1, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 0 : i32, stride_w = 0 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, none) -> tensor<256x32x32x16xf32>
// CHECK: return %[[RES0]] : tensor<256x32x32x16xf32>

{
  version: 3,
  operator_codes: [
    {
      builtin_code: "CONV_2D",
    }
  ],
  subgraphs: [
    {
      tensors: [
        {
          shape: [
            256,
            32,
            32,
            3
          ],
          name: "arg0",
          quantization: {
          }
        },
        {
          shape: [
            16,
            3,
            3,
            3
          ],
          name: "arg1",
          quantization: {
          }
        },
        {
          shape: [
            256,
            32,
            32,
            16
          ],
          name: "output",
          quantization: {
          }
        },
      ],
      inputs: [
        0,
        1
      ],
      outputs: [
        2
      ],
      operators: [
        {
          inputs: [
            0,
            1
          ],
          outputs: [
            2
          ],
          builtin_options_type: "Conv2DOptions",
          builtin_options: {
          }
        }
      ],
      name: "main"
    }
  ],
  description: "MLIR Converted."
}

1092
tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/test_schema.fbs
Normal file
File diff suppressed because it is too large
@ -1,4 +1,4 @@
// RUN: tf-opt %s -inline -mlir-disable-inline-simplify | FileCheck %s --dump-input=fail
// RUN: tf-opt %s -inline="disable-simplify" | FileCheck %s --dump-input=fail

// Inline a function that contains only tfl ops.
func @func_with_tfl_ops(%arg0 : tensor<2xi32>) -> tensor<2xi32> {
@ -1,5 +1,5 @@
// RUN: tf-opt --tfl-legalize-tf-while %s -o - | FileCheck %s --dump-input-on-failure
// RUN: tf-opt --tfl-legalize-tf-while %s -o - --tfl-legalize-tf-while --inline --mlir-disable-inline-simplify | FileCheck %s --dump-input-on-failure --check-prefix=INLINE
// RUN: tf-opt --tfl-legalize-tf-while %s -o - --tfl-legalize-tf-while --inline="disable-simplify" | FileCheck %s --dump-input-on-failure --check-prefix=INLINE
// RUN: tf-opt --tfl-legalize-tf-while %s -o - --tfl-legalize-tf-while --inline | FileCheck %s --dump-input-on-failure --check-prefix=CANON

func @while_main(%arg0: tensor<?x256x256xf32>) -> (tensor<i32>, tensor<256x256xf32>, tensor<?x256x256xf32>) attributes {tf.entry_function = {inputs = "input", outputs = "Identity,Identity_1,Identity_2"}} {
@ -9,6 +9,20 @@ func @add(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
// CHECK: return
}

// CHECK-LABEL: testAddHighDimsHaveSameShape
func @testAddHighDimsHaveSameShape(%arg0: tensor<1x2x3x4x5x6x7x8xi32>, %arg1: tensor<1x2x3x4x5x6x7x8xi32>) -> tensor<1x2x3x4x5x6x7x8xi32> {
// CHECK: tfl.add %arg0, %arg1 {fused_activation_function = "NONE"}
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x3x4x5x6x7x8xi32>, tensor<1x2x3x4x5x6x7x8xi32>) -> tensor<1x2x3x4x5x6x7x8xi32>
return %0 : tensor<1x2x3x4x5x6x7x8xi32>
}

// CHECK-LABEL: testAddTooHighBroadcastableDims
func @testAddTooHighBroadcastableDims(%arg0: tensor<1x2x3x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// expected-error @+1 {{'tfl.add' op failed to verify that operand #0 and operand #1 have the same shape or broadcastable shapes within the rank 4}}
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}

func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
%2 = "tf.LeakyRelu"(%arg0) {alpha = 0.1 : f32} : (tensor<1xf32>) -> tensor<1xf32>
return %2: tensor<1xf32>
@ -38,14 +52,14 @@ func @squeezeAndReshape(%arg0: tensor<1x1x10xf32>, %arg1: tensor<?x10xf32>) -> i
%1 = "tf.Squeeze"(%arg1) : (tensor<?x10xf32>) -> tensor<*xf32>
%2 = "tf.Const"() { value = dense<[2, 5]> : tensor<2xi32> } : () -> tensor<2xi32>
%3 = "tf.Reshape" (%0, %2) : (tensor<1x10xf32>, tensor<2xi32>) -> tensor<2x5xf32>
%4 = "some_op"(%1, %3) : (tensor<*xf32>, tensor<2x5xf32>) -> i32
%4 = "tf.some_op"(%1, %3) : (tensor<*xf32>, tensor<2x5xf32>) -> i32
return %4 : i32
// CHECK-LABEL: squeezeAndReshape
// CHECK: "tfl.squeeze"(%arg0) {squeeze_dims = [0]} : (tensor<1x1x10xf32>) -> tensor<1x10xf32>
// CHECK: %1 = "tfl.squeeze"(%arg1) {squeeze_dims = []} : (tensor<?x10xf32>) -> tensor<*xf32>
// CHECK: %cst = constant dense<[2, 5]> : tensor<2xi32>
// CHECK: %2 = "tfl.reshape"(%0, %cst) : (tensor<1x10xf32>, tensor<2xi32>) -> tensor<2x5xf32>
// CHECK: %3 = "some_op"(%1, %2) : (tensor<*xf32>, tensor<2x5xf32>) -> i32
// CHECK: %3 = "tf.some_op"(%1, %2) : (tensor<*xf32>, tensor<2x5xf32>) -> i32
// CHECK: return
}
@ -1448,7 +1462,7 @@ func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<f32>
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<f32>) -> tensor<3x3xf32>
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// CHECK return [[MUL]] : tensor<3x3xf32>
// CHECK: return [[MUL]] : tensor<3x3xf32>
}

func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> {
@ -1459,5 +1473,5 @@ func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3
// CHECK: [[CST:%.*]] = constant dense<1> : tensor<i32>
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<i32>) -> tensor<3x3xi32>
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32>
// CHECK return [[MUL]] : tensor<3x3xi32>
// CHECK: return [[MUL]] : tensor<3x3xi32>
}
@ -1,4 +1,4 @@
// RUN: tf-opt -tfl-load-recipe %s | FileCheck %s --dump-input-on-failure
// RUN: tf-opt -allow-unregistered-dialect -tfl-load-recipe %s | FileCheck %s --dump-input-on-failure

// CHECK-LABEL: testLstm
func @testLstm(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>, %arg3: tensor<*xf32>, %arg4: tensor<*xf32>, %arg5: tensor<*xf32>, %arg6: tensor<*xf32>, %arg7: tensor<*xf32>, %arg8: tensor<*xf32>, %arg9: tensor<*xf32>, %arg10: tensor<*xf32>, %arg11: tensor<*xf32>, %arg12: tensor<*xf32>, %arg13: tensor<*xf32>, %arg14: tensor<*xf32>, %arg15: tensor<*xf32>, %arg16: tensor<*xf32>, %arg17: tensor<*xf32>, %arg18: tensor<*xf32>, %arg19: tensor<*xf32>, %arg20: tensor<*xf32>, %arg21: tensor<*xf32>, %arg22: tensor<*xf32>, %arg23: tensor<*xf32>) -> tensor<*xf32> {
@ -1,6 +1,6 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s

func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -> tensor<4 x f32> {
func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
@ -9,126 +9,126 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 1, 4 ],
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "arg1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "arg2",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 4,
// CHECK-NEXT: name: "arg3",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 5,
// CHECK-NEXT: name: "arg4",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 6,
// CHECK-NEXT: name: "arg5",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 7,
// CHECK-NEXT: name: "arg6",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 8,
// CHECK-NEXT: name: "arg7",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 9,
// CHECK-NEXT: name: "arg8",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 10,
// CHECK-NEXT: name: "arg9",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 11,
// CHECK-NEXT: name: "arg10",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 12,
// CHECK-NEXT: name: "arg11",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 1, 4 ],
// CHECK-NEXT: buffer: 13,
// CHECK-NEXT: name: "arg12",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 1, 4 ],
// CHECK-NEXT: buffer: 14,
// CHECK-NEXT: name: "arg13",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 1, 4 ],
// CHECK-NEXT: buffer: 15,
// CHECK-NEXT: name: "arg14",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 1, 4 ],
// CHECK-NEXT: buffer: 16,
// CHECK-NEXT: name: "arg15",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 17,
// CHECK-NEXT: name: "arg16",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 1, 4 ],
// CHECK-NEXT: buffer: 18,
// CHECK-NEXT: name: "arg17",
// CHECK-NEXT: quantization: {
@ -163,21 +163,21 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 1, 4 ],
// CHECK-NEXT: name: "Const",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: },
// CHECK-NEXT: is_variable: true
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 1, 4 ],
// CHECK-NEXT: name: "Const1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: },
// CHECK-NEXT: is_variable: true
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 1, 4 ],
// CHECK-NEXT: buffer: 25,
// CHECK-NEXT: name: "tfl.lstm",
// CHECK-NEXT: quantization: {
@ -261,9 +261,9 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t
// CHECK-EMPTY:

^bb0(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32>, %arg3: tensor<4 x f32>, %arg4: tensor<4 x f32>, %arg5: tensor<4 x f32>, %arg6: tensor<4 x f32>, %arg7: tensor<4 x f32>, %arg8: tensor<4 x f32>, %arg9: tensor<4 x f32>, %arg10: tensor<4 x f32>, %arg11: tensor<4 x f32>, %arg12: tensor<4 x f32>, %arg13: tensor<4 x f32>, %arg14: tensor<4 x f32>, %arg15: tensor<4 x f32>, %arg16: tensor<4 x f32>, %arg17: tensor<4 x f32>, %arg18: tensor<4 x f32>, %arg19: tensor<4 x f32>, %arg20: tensor<4 x f32>, %arg21: tensor<4 x f32>):
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %24 : tensor<4xf32>
^bb0(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4x4xf32>, %arg10: tensor<4x4xf32>, %arg11: tensor<4x4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<1x4xf32>, %arg14: tensor<1x4xf32>, %arg15: tensor<1x4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<1x4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>):
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %24 : tensor<1x4xf32>
}
@ -675,6 +675,41 @@ func @testLstmWithInvalidNoneType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>
return %0 : tensor<?xf32>
}

// -----

// Test an invalid input dimension: the first input operand of the lstm op should be at least a 2-D tensor.
func @testLstmWithInvalidInputDimension(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32>, %arg3: tensor<4 x f32>, %arg4: tensor<4 x f32>, %arg5: tensor<4 x f32>, %arg6: tensor<4 x f32>, %arg7: tensor<4 x f32>, %arg8: tensor<4 x f32>, %arg9: tensor<4 x f32>, %arg10: tensor<4 x f32>, %arg11: tensor<4 x f32>, %arg12: tensor<4 x f32>, %arg13: tensor<4 x f32>, %arg14: tensor<4 x f32>, %arg15: tensor<4 x f32>, %arg16: tensor<4 x f32>, %arg17: tensor<4 x f32>, %arg18: tensor<4 x f32>, %arg19: tensor<4 x f32>, %arg20: tensor<4 x f32>, %arg21: tensor<4 x f32>) -> tensor<4 x f32> {
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
// expected-error @+1 {{'tfl.lstm' op the first input operand should have more than 2 dimensions.}}
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %24 : tensor<4xf32>
}
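The rank constraint exercised above can be sketched as a verifier check like the following. This is an illustrative C++ fragment, not the actual tfl.lstm verifier: `VerifyLstmInputRank` is a hypothetical name, and include paths differ across MLIR versions.

// Reject an LSTM whose first operand is a ranked tensor of rank < 2; the
// diagnostic text mirrors the expected-error in the test above.
#include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h"  // mlir::RankedTensorType (path varies by MLIR version)

static mlir::LogicalResult VerifyLstmInputRank(mlir::Operation *op) {
  auto input_type =
      op->getOperand(0).getType().dyn_cast<mlir::RankedTensorType>();
  if (input_type && input_type.getRank() < 2)
    return op->emitOpError(
        "the first input operand should have more than 2 dimensions.");
  return mlir::success();
}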

// -----

// The 'input_to_output_weights' input of the lstm op has a rank that does not match `input`.
func @testLstmWithInvalidInputsRankMatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4x2xf32>, %arg2: tensor<4x2xf32>, %arg3: tensor<4x2xf32>, %arg4: tensor<4x2xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4x4xf32>, %arg10: tensor<4x4xf32>, %arg11: tensor<4x4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<1x4xf32>, %arg14: tensor<1x4xf32>, %arg15: tensor<1x4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<1x4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>) -> tensor<1x4xf32> {
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
// expected-error @+1 {{'tfl.lstm' op inputs don't match with the dimensions.}}
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %24 : tensor<1x4xf32>
}

// -----

// Coefficient inputs of the LSTM op do not match the dimensions of the input operand `input_to_output_weights`.
func @testLstmWithInvalidInputsRankMatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4x4xf32>, %arg10: tensor<4x4xf32>, %arg11: tensor<4x4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<1x4xf32>, %arg14: tensor<1x4xf32>, %arg15: tensor<1x4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<1x4xf32>, %arg18: tensor<3xf32>, %arg19: tensor<3xf32>, %arg20: tensor<3xf32>, %arg21: tensor<3xf32>) -> tensor<1x4xf32> {
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
// expected-error @+1 {{'tfl.lstm' op coefficient inputs have more than 2 dimensions or don't match the dimension with input operand `input_to_output_weights`.}}
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<1x4xf32>
return %24 : tensor<1x4xf32>
}

// -----

// Test an invalid kernel type.
@ -439,6 +439,31 @@ func @NotReorderReshapeAddIfNotTailingDim(%arg0: tensor<40x40x1xf32>) -> tensor<
// CHECK: return %[[rs2]]
}

// CHECK-LABEL: @ReorderElementwiseValueOpAndMoveOp
func @ReorderElementwiseValueOpAndMoveOp(%arg0: tensor<40x40x1xf32>) -> tensor<40x40xf32> {
%shape = constant dense<[40, 40]> : tensor<2xi32>
%1 = "tfl.reshape"(%arg0, %shape) : (tensor<40x40x1xf32>, tensor<2xi32>) -> tensor<40x40xf32>
%2 = "tfl.relu"(%1) : (tensor<40x40xf32>) -> tensor<40x40xf32>
return %2 : tensor<40x40xf32>

// CHECK: %[[rs1:.*]] = "tfl.relu"(%arg0
// CHECK: %[[rs2:.*]] = "tfl.reshape"(%[[rs1]]
// CHECK: return %[[rs2]]
}

// CHECK-LABEL: @NotReorderElementwiseValueOpAndMoveOp
func @NotReorderElementwiseValueOpAndMoveOp(%arg0: tensor<40x40x1xf32>) -> (tensor<40x40xf32>, tensor<40x40xf32>) {
%shape = constant dense<[40, 40]> : tensor<2xi32>
%1 = "tfl.reshape"(%arg0, %shape) : (tensor<40x40x1xf32>, tensor<2xi32>) -> tensor<40x40xf32>
%2 = "tfl.relu"(%1) : (tensor<40x40xf32>) -> tensor<40x40xf32>
return %1, %2 : tensor<40x40xf32>, tensor<40x40xf32>

// CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0
// CHECK: %[[rs2:.*]] = "tfl.relu"(%[[rs1]]
// CHECK: return %[[rs1]], %[[rs2]]
}

// CHECK-LABEL: @FuseFullyConnectedRelu
func @FuseFullyConnectedRelu(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> {
%0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32>
@ -450,6 +475,28 @@ func @FuseFullyConnectedRelu(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32
// CHECK: return %[[RES]]
}

// CHECK-LABEL: @FuseFullyConnectedRelu6
func @FuseFullyConnectedRelu6(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> {
%0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32>
%1 = "tfl.relu6"(%0) : (tensor<1x128xf32>) -> tensor<1x128xf32>
return %1 : tensor<1x128xf32>

// CHECK: %[[RES:[0-9].*]] = "tfl.fully_connected"
// CHECK-SAME: fused_activation_function = "RELU6"
// CHECK: return %[[RES]]
}

// CHECK-LABEL: @FuseFullyConnectedRelu1
func @FuseFullyConnectedRelu1(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> {
%0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32>
%1 = "tfl.relu_n1_to_1"(%0) : (tensor<1x128xf32>) -> tensor<1x128xf32>
return %1 : tensor<1x128xf32>

// CHECK: %[[RES:[0-9].*]] = "tfl.fully_connected"
// CHECK-SAME: fused_activation_function = "RELU_N1_TO_1"
// CHECK: return %[[RES]]
}

// CHECK-LABEL: @HardSwishPattern
func @HardSwishPattern(%arg0: tensor<1xf32>) -> tensor<1xf32> {
%three = constant dense<3.> : tensor<f32>
@ -161,7 +161,7 @@ func @_functionalize_if_else_branch_01(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>
}

func @_functionalize_if_then_branch_01(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
%0 = "my_unknown_op.blah"() : () -> tensor<i1>
%0 = "tf.blah"() : () -> tensor<i1>
return %0 : tensor<i1>
}

@ -199,7 +199,7 @@ func @_functionalize_if_else_branch_02(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>
}

func @_functionalize_if_then_branch_02(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
%0 = "my_unknown_op.blah"() : () -> tensor<i1>
%0 = "tf.blah"() : () -> tensor<i1>
return %0 : tensor<i1>
}
@ -29,7 +29,7 @@ limitations under the License.

namespace mlir {
/// Create a pass to convert from the TFExecutor to the TF control dialect.
std::unique_ptr<OpPassBase<FuncOp>>
std::unique_ptr<OperationPass<FuncOp>>
CreateTFExecutorToControlDialectConversion();
} // namespace mlir
@ -134,6 +134,10 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
  pass_manager->addPass(
      mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());

  if (pass_config.shape_inference) {
    // Add a shape inference pass to optimize away the unnecessary casts.
    pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
  }
  // Legalize while early to allow further constant folding.
  // TODO(jpienaar): This may not actually matter as we do canonicalization
  // after the legalize below, for now it needs to be below the above passes
@ -160,11 +164,6 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
  // constant ops.
  pass_manager->addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass());

  if (pass_config.shape_inference) {
    // Add a shape inference pass to optimize away the unnecessary casts.
    pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
  }

  // The below passes only make sense if Builtin TFLite ops are enabled
  // for emission.
  if (pass_config.emit_builtin_tflite_ops) {
@ -173,7 +172,8 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
    pass_manager->addPass(
        mlir::TFL::CreatePrepareTFPass(pass_config.unfold_batch_matmul));
    pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
    pass_manager->addPass(mlir::TFL::CreateLegalizeTFPass());
    pass_manager->addPass(
        mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification));
    pass_manager->addPass(mlir::TFL::CreateOptimizePass());
    // This pass operates on TensorFlow ops but is triggered after legalization
    // so that it can target constants introduced once TensorFlow Identity ops
@ -255,7 +255,8 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
  // TFLite dialect passes.
  pm.addPass(mlir::TFL::CreatePrepareTFPass(true));
  pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  pm.addPass(mlir::TFL::CreateLegalizeTFPass());
  pm.addPass(
      mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true));
  pm.addPass(mlir::TFL::CreateOptimizePass());
  pm.addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());

@ -268,7 +269,7 @@ void CreateTFLStandardPipeline(OpPassManager& pm,

  pm.addPass(mlir::TFL::CreateWhileOutlinePass());

  pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
  pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
}
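For context, a pipeline like the one assembled above is typically driven as follows. This is a minimal sketch under stated assumptions: it reuses only the pass factories that appear in this diff, `RunStandardTflPipeline` is a hypothetical wrapper (not the converter's actual entry point), and the mlir::TFL pass declarations are assumed to come from the converter's passes header.

// Build and run a TFL-style pass pipeline on a module; pm.run fails if
// any pass signals failure.
#include "mlir/IR/Module.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"  // mlir::createCanonicalizerPass

mlir::LogicalResult RunStandardTflPipeline(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::TFL::CreatePrepareTFPass(/*unfold_batch_matmul=*/true));
  pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  pm.addPass(
      mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true));
  pm.addPass(mlir::TFL::CreateOptimizePass());
  pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
  return pm.run(module);
}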

// Registers a pass pipeline for the standard TFL passes.
Some files were not shown because too many files have changed in this diff