Merge branch 'master' into interface_16x8

Elena Zhelezina 2020-04-17 10:15:23 +01:00 committed by GitHub
commit e7b615dc2e
3146 changed files with 151883 additions and 16454 deletions


@ -73,6 +73,10 @@
# tensorflow_testing_rbe_linux: RBE options to use RBE with tensorflow-testing project on linux
# tensorflow_testing_rbe_win: RBE options to use RBE with tensorflow-testing project on windows
#
# Embedded Linux options (experimental; currently only tested with TFLite builds)
# elinux: General Embedded Linux options shared by all flavors.
# elinux_aarch64: Embedded Linux options for aarch64 (ARM64) CPU support.
# elinux_armhf: Embedded Linux options for armhf (ARMv7) CPU support.
@ -432,6 +436,14 @@ build:tensorflow_testing_rbe_linux --config=rbe_linux
common:tensorflow_testing_rbe_win --remote_instance_name=projects/tensorflow-testing/instances/windows
build:tensorflow_testing_rbe_win --config=tensorflow_testing_rbe
# TFLite build configs for generic embedded Linux
build:elinux --crosstool_top=@local_config_embedded_arm//:toolchain
build:elinux --host_crosstool_top=@bazel_tools//tools/cpp:toolchain
build:elinux_aarch64 --config=elinux
build:elinux_aarch64 --cpu=aarch64
build:elinux_armhf --config=elinux
build:elinux_armhf --cpu=armhf
# END TF REMOTE BUILD EXECUTION OPTIONS
# Default options should come above this line


@ -1 +1 @@
2.0.0
3.0.0


@ -10,32 +10,30 @@ labels: 'type:bug'
we only address code/doc bugs, performance issues, feature requests and
build/installation issues on GitHub. tag:bug_template</em>
**System information**
- Have I written custom code (as opposed to using a stock
example script provided in TensorFlow):
- OS Platform and Distribution (e.g.,
Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
the issue happens on mobile device:
- TensorFlow installed from (source or
binary): - TensorFlow version (use command below):
- Python version: - Bazel
version (if compiling from source):
- GCC/Compiler version (if compiling from
source):
- CUDA/cuDNN version: - GPU model and memory:
**System information**
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
- TensorFlow installed from (source or binary):
- TensorFlow version (use command below):
- Python version:
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version:
- GPU model and memory:
You can collect some of this information using our environment capture
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
You can also obtain the TensorFlow version with:
1. TF 1.0: `python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"`
2. TF 2.0: `python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
**Describe the current behavior**
**Describe the expected behavior**
**Standalone code to reproduce the issue**
**Standalone code to reproduce the issue**
Provide a reproducible test case that is the bare minimum necessary to generate
the problem. If possible, please share a link to Colab/Jupyter/any notebook.


@ -11,32 +11,29 @@ As per our
we only address code/doc bugs, performance issues, feature requests and
build/installation issues on GitHub. tag:performance_template</em>
**System information**
- Have I written custom code (as opposed to using a stock
example script provided in TensorFlow):
- OS Platform and Distribution (e.g.,
Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
the issue happens on mobile device:
- TensorFlow installed from (source or
binary): - TensorFlow version (use command below):
- Python version: - Bazel
version (if compiling from source):
- GCC/Compiler version (if compiling from
source):
- CUDA/cuDNN version: - GPU model and memory:
**System information**
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
- TensorFlow installed from (source or binary):
- TensorFlow version (use command below):
- Python version:
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version:
- GPU model and memory:
You can collect some of this information using our environment capture
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
You can also obtain the TensorFlow version with:
1. TF 1.0: `python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"`
2. TF 2.0: `python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
**Describe the current behavior**
**Describe the expected behavior**
**Standalone code to reproduce the issue**
**Standalone code to reproduce the issue**
Provide a reproducible test case that is the bare minimum necessary to generate
the problem. If possible, please share a link to Colab/Jupyter/any notebook.


@ -2,6 +2,10 @@
<img src="https://www.tensorflow.org/images/tf_logo_social.png">
</div>
[![Python](https://img.shields.io/pypi/pyversions/tensorflow.svg?style=plastic)](https://badge.fury.io/py/tensorflow)
[![PyPI](https://badge.fury.io/py/tensorflow.svg)](https://badge.fury.io/py/tensorflow)
**`Documentation`** |
------------------- |
[![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://www.tensorflow.org/api_docs/) |


@ -2,58 +2,42 @@ package(default_visibility = ["//visibility:public"])
filegroup(
name = "gcc",
srcs = [
"bin/arm-rpi-linux-gnueabihf-gcc",
],
srcs = glob(["bin/*-gcc"]),
)
filegroup(
name = "ar",
srcs = [
"bin/arm-rpi-linux-gnueabihf-ar",
],
srcs = glob(["bin/*-ar"]),
)
filegroup(
name = "ld",
srcs = [
"bin/arm-rpi-linux-gnueabihf-ld",
],
srcs = glob(["bin/*-ld"]),
)
filegroup(
name = "nm",
srcs = [
"bin/arm-rpi-linux-gnueabihf-nm",
],
srcs = glob(["bin/*-nm"]),
)
filegroup(
name = "objcopy",
srcs = [
"bin/arm-rpi-linux-gnueabihf-objcopy",
],
srcs = glob(["bin/*-objcopy"]),
)
filegroup(
name = "objdump",
srcs = [
"bin/arm-rpi-linux-gnueabihf-objdump",
],
srcs = glob(["bin/*-objdump"]),
)
filegroup(
name = "strip",
srcs = [
"bin/arm-rpi-linux-gnueabihf-strip",
],
srcs = glob(["bin/*-strip"]),
)
filegroup(
name = "as",
srcs = [
"bin/arm-rpi-linux-gnueabihf-as",
],
srcs = glob(["bin/*-as"]),
)
filegroup(
@ -66,6 +50,16 @@ filegroup(
]),
)
filegroup(
name = "aarch64_compiler_pieces",
srcs = glob([
"aarch64-none-linux-gnu/**",
"libexec/**",
"lib/gcc/aarch64-none-linux-gnu/**",
"include/**",
]),
)
filegroup(
name = "compiler_components",
srcs = [


@ -4,7 +4,7 @@ set -e
set -o pipefail
if [ -z "$PYTHON_BIN_PATH" ]; then
PYTHON_BIN_PATH=$(which python || which python3 || true)
PYTHON_BIN_PATH=$(which python3 || which python || true)
fi
# Set all env variables


@ -50,7 +50,7 @@ _TF_WORKSPACE_ROOT = ''
_TF_BAZELRC = ''
_TF_CURRENT_BAZEL_VERSION = None
_TF_MIN_BAZEL_VERSION = '2.0.0'
_TF_MAX_BAZEL_VERSION = '2.0.0'
_TF_MAX_BAZEL_VERSION = '3.99.0'
NCCL_LIB_PATHS = [
'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''
@ -58,8 +58,6 @@ NCCL_LIB_PATHS = [
# List of files to configure when building Bazel on Apple platforms.
APPLE_BAZEL_FILES = [
'tensorflow/lite/experimental/delegates/coreml/BUILD',
'tensorflow/lite/experimental/delegates/coreml/builders/BUILD',
'tensorflow/lite/experimental/ios/BUILD',
'tensorflow/lite/experimental/objc/BUILD',
'tensorflow/lite/experimental/swift/BUILD',


@ -214,6 +214,12 @@ config_setting(
visibility = ["//visibility:public"],
)
config_setting(
name = "linux_armhf",
values = {"cpu": "armhf"},
visibility = ["//visibility:public"],
)
config_setting(
name = "linux_x86_64",
values = {"cpu": "k8"},
@ -703,8 +709,8 @@ tf_cc_shared_object(
"//tensorflow/c:version_script.lds",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:distributed_tensorflow_dependencies",
"//tensorflow/core:tensorflow",
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
],
)


@ -186,10 +186,6 @@ struct TF_Server {
namespace tensorflow {
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
TF_Buffer* out);


@ -240,11 +240,6 @@ tf_cuda_cc_test(
"c_api_remote_test.cc",
],
extra_copts = tfe_xla_copts(),
tags = [
"guitar",
"multi_gpu",
"no_oss",
],
deps = [
":c_api",
":c_api_experimental",
@ -372,6 +367,22 @@ tf_cuda_cc_test(
],
)
cc_library(
name = "custom_device_testutil",
testonly = True,
srcs = ["custom_device_testutil.cc"],
hdrs = ["custom_device_testutil.h"],
visibility = ["//tensorflow:internal"],
deps = [
":c_api",
":c_api_experimental",
":c_api_test_util",
"//tensorflow/c:c_api",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
)
tf_cc_test(
name = "custom_device_test",
size = "small",
@ -382,6 +393,7 @@ tf_cc_test(
":c_api",
":c_api_experimental",
":c_api_test_util",
":custom_device_testutil",
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/cc/profiler",


@ -1587,6 +1587,7 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
// require TFE_Op* and just convert it internally to a NameAttrValue, so
// consider adding an overload to the C API to make this case easier.
TFE_OpSetAttrFunction(op, attr_name, func_op);
TFE_DeleteOp(func_op);
} break;
case tensorflow::AttrValue::kList:
TF_FALLTHROUGH_INTENDED;
@ -1684,6 +1685,8 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
};
} // namespace
extern "C" {
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
const char* device_name, void* device_info,
TF_Status* status) {
@ -1694,3 +1697,5 @@ void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
status->status =
context->RegisterCustomDevice(device_name, std::move(custom_device));
}
} // extern "C"


@ -515,9 +515,11 @@ typedef struct TFE_CustomDevice {
// This API is highly experimental, and in particular is expected to change when
// it starts supporting operations with attributes and when tf.function support
// is added.
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
const char* device_name, void* device_info,
TF_Status* status);
TF_CAPI_EXPORT extern void TFE_RegisterCustomDevice(TFE_Context* ctx,
TFE_CustomDevice device,
const char* device_name,
void* device_info,
TF_Status* status);
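As a rough usage sketch of the declaration above (not part of this diff): the caller fills in a TFE_CustomDevice value and hands it to TFE_RegisterCustomDevice together with a fully qualified device name. The callback fields of TFE_CustomDevice are not shown in this hunk and are deliberately left unset here; RegisterLoggingDevice, the device name string, and the null device_info are illustrative assumptions.
#include "tensorflow/c/eager/c_api_experimental.h"
void RegisterLoggingDevice(TFE_Context* ctx, TF_Status* status) {
  TFE_CustomDevice device = {};  // callbacks omitted; a real device must supply them
  void* device_info = nullptr;   // opaque per-device state owned by the caller
  TFE_RegisterCustomDevice(
      ctx, device, "/job:localhost/replica:0/task:0/device:CUSTOM:0",
      device_info, status);
}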
TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx,
const char* function_name,


@ -129,7 +129,45 @@ void TestRemoteExecute(bool async) {
TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
void TestRemoteExecuteSilentCopies(bool async, bool remote) {
string MatMulFunction() {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
" signature {"
" name: 'MatMulFunction'"
" input_arg {"
" name: 'a'"
" type: DT_FLOAT"
" }"
" input_arg {"
" name: 'b'"
" type: DT_FLOAT"
" }"
" output_arg {"
" name: 'm'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'matmul'"
" op: 'MatMul'"
" input: 'a'"
" input: 'b'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'm'"
" value: 'matmul:product'"
" }",
&def));
return def.SerializeAsString();
}
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
@ -169,12 +207,36 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Handles are on task0 (local), and task2, but op is on task1.
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task2);
TFE_Op* matmul = nullptr;
if (func) {
string function_def = MatMulFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
matmul = TFE_NewOp(ctx, "MatMulFunction", status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(matmul, h0_task0, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(matmul, h1_task2, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
} else {
// Handles are on task0 (local), and task2, but op is on task1.
matmul = MatMulOp(ctx, h0_task0, h1_task2);
}
if (remote) {
TFE_OpSetDevice(matmul, task1_name, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
} else if (!async) {
// Set the local device to CPU to easily validate mirroring
string cpu_device_name;
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle);
// The input handles should never change since they have been mirrored.
ASSERT_FALSE(remote_arg->HasLocalMirror(nullptr));
}
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
@ -182,12 +244,10 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TODO(gjn): Add support for waiting on async local mirrors
if (!async) {
if (!remote && !async) {
auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle);
tensorflow::EagerOperation* op =
tensorflow::OperationFromInterface(matmul->operation);
// The input handles should never change since they have been mirrored.
ASSERT_EQ(op->Inputs()[1], remote_arg);
ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr));
}
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
@ -217,6 +277,9 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteExecutor(executor);
if (func) {
TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
}
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
@ -227,16 +290,22 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
}
TEST(CAPI, RemoteExecuteSilentCopies) {
TestRemoteExecuteSilentCopies(false, true);
TestRemoteExecuteSilentCopies(false, true, false);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
TestRemoteExecuteSilentCopies(true, true);
TestRemoteExecuteSilentCopies(true, true, false);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) {
TestRemoteExecuteSilentCopies(true, true, true);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocal) {
TestRemoteExecuteSilentCopies(false, false);
TestRemoteExecuteSilentCopies(false, false, false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) {
TestRemoteExecuteSilentCopies(true, false);
TestRemoteExecuteSilentCopies(true, false, false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) {
TestRemoteExecuteSilentCopies(true, false, true);
}
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {


@ -78,11 +78,18 @@ void BM_Execute(int iters, int async) {
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
TFE_Op* matmul = MatMulOp(ctx, m, m);
TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
TFE_OpReset(matmul, "MatMul", nullptr, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(matmul, m, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(matmul, m, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
}
@ -113,11 +120,15 @@ void BM_Execute_Identity(int iters, int async) {
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
TFE_Op* identity = IdentityOp(ctx, m);
TFE_Op* identity = TFE_NewOp(ctx, "Identity", status);
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
TFE_OpReset(identity, "Identity", nullptr, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(identity, m, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Execute(identity, &retvals[0], &num_retvals, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
}
@ -405,6 +416,11 @@ void TensorHandleSilentCopy(bool async,
hcpu, ctx, gpu_device_name.c_str(), status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
auto cpu_arg = tensorflow::TensorHandleFromInterface(hcpu->handle);
auto gpu_arg = tensorflow::TensorHandleFromInterface(hgpu->handle);
auto gpu_device = absl::get<tensorflow::Device*>(gpu_arg->device());
ASSERT_FALSE(cpu_arg->HasLocalMirror(gpu_device));
TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
if (cpu_op) {
string cpu_device_name;
@ -420,15 +436,8 @@ void TensorHandleSilentCopy(bool async,
TFE_Execute(matmul, &retvals[0], &num_retvals, status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
// Validate if the input was replaced with a different TensorHandle
auto arg0 = tensorflow::TensorHandleFromInterface(hcpu->handle);
auto arg1 = tensorflow::TensorHandleFromInterface(hgpu->handle);
tensorflow::EagerOperation* op =
tensorflow::OperationFromInterface(matmul->operation);
// The input handles should never change since they have been mirrored.
EXPECT_EQ(op->Inputs()[0], arg0);
EXPECT_EQ(op->Inputs()[1], arg1);
// The CPU handle should have been copied and have a mirror on the GPU
ASSERT_TRUE(cpu_arg->HasLocalMirror(gpu_device));
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(retvals[0]);
@ -626,17 +635,6 @@ void ExecuteAdd(bool async, bool forward_input, bool tfrt) {
}
int num_retvals = 1;
if (async) {
// Enqueue dummy ops so we backlog async execution & actually test async.
for (int i = 0; i < 10000; ++i) {
TFE_TensorHandle* dummy = nullptr;
TFE_Execute(add_op, &dummy, &num_retvals, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(dummy);
}
}
TFE_TensorHandle* retval = nullptr;
TFE_Execute(add_op, &retval, &num_retvals, status);
EXPECT_EQ(1, num_retvals);


@ -38,96 +38,159 @@ typedef void (*ExecuteOperation)(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs,
TF_OutputList* o, TF_ExecutionContext* ctx,
TF_Status* s);
struct TF_ExecutionContext {
explicit TF_ExecutionContext() {}
absl::variant<TFE_Context*, TF_GraphContext*> ctx;
ExecuteOperation execution_callback;
};
// Needed to implement our own version of RTTI since dynamic_cast is not
// supported in mobile builds.
enum ExecutionContextKind { GraphContext, EagerContext };
explicit TF_ExecutionContext(ExecutionContextKind kind) : k(kind) {}
ExecutionContextKind getKind() const { return k; }
struct TF_AbstractTensor {
absl::variant<TFE_TensorHandle*, TF_GraphTensor*> t;
};
virtual void ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs,
TF_OutputList* o, TF_Status* s) = 0;
virtual TF_AbstractOp* CreateOperation() = 0;
virtual void RegisterFunction(TF_AbstractFunction* func, TF_Status* s) = 0;
virtual ~TF_ExecutionContext() {}
struct TF_AbstractOp {
string op_type;
string op_name;
private:
const ExecutionContextKind k;
};
TF_ExecutionContext* TF_NewExecutionContext() {
return new TF_ExecutionContext();
}
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete c; }
TF_AbstractOp* TF_NewAbstractOp() {
TF_AbstractOp* op = new TF_AbstractOp;
return op;
template <typename T, typename S>
T* dynamic_cast_helper(S source) {
if (source->getKind() != T::kKind) {
return nullptr;
}
return tensorflow::down_cast<T*>(source);
}
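A brief illustration of how this kind-tag helper is meant to be used (AsEagerContext is a hypothetical wrapper name; the concrete TF_EagerContext class is defined further down in this file): the cast either yields the derived type or reports an error, without relying on dynamic_cast.
TF_EagerContext* AsEagerContext(TF_ExecutionContext* ctx, TF_Status* s) {
  auto* eager_ctx = dynamic_cast_helper<TF_EagerContext>(ctx);
  if (eager_ctx == nullptr) {
    TF_SetStatus(s, TF_INVALID_ARGUMENT,
                 "Expected an eager execution context.");
  }
  return eager_ctx;
}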
void TF_DeleteAbstractOp(TF_AbstractOp* op) { delete op; }
TF_AbstractTensor* TF_NewAbstractTensor() {
TF_AbstractTensor* t = new TF_AbstractTensor;
return t;
}
void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { delete t; }
struct TF_GraphContext {
TF_Graph* graph;
// TODO(srbs): Handle captures.
};
TF_GraphContext* TF_NewGraphContext(TF_Graph* g) {
auto ctx = new TF_GraphContext;
ctx->graph = g;
return ctx;
}
void TF_DeleteGraphContext(TF_GraphContext* ctx) { delete ctx; }
class TF_GraphContext;
class TF_EagerContext;
struct TF_GraphTensor {
TF_Output output;
TF_GraphContext* ctx;
};
TF_GraphTensor* TF_NewGraphTensor(TF_GraphContext* ctx, TF_Output output,
TF_Status* s) {
TF_GraphTensor* t = new TF_GraphTensor;
t->output = output;
t->ctx = ctx;
return t;
}
TF_Output TF_GraphTensorToOutput(const TF_GraphTensor* const t, TF_Status* s) {
return t->output;
}
void TF_DeleteGraphTensor(TF_GraphTensor* t) { delete t; }
void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t,
TF_Status* s) {
at->t = t;
}
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
TF_Status* s) {
if (!absl::holds_alternative<TFE_TensorHandle*>(at->t)) {
string msg = absl::StrCat("Not an eager tensor handle.",
reinterpret_cast<uintptr_t>(at));
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return nullptr;
struct TF_AbstractTensor {
absl::variant<TFE_TensorHandle*, TF_GraphTensor*> t;
~TF_AbstractTensor() {
if (absl::holds_alternative<TFE_TensorHandle*>(t)) {
TFE_DeleteTensorHandle(absl::get<TFE_TensorHandle*>(t));
} else if (absl::holds_alternative<TF_GraphTensor*>(t)) {
delete absl::get<TF_GraphTensor*>(t);
}
}
return absl::get<TFE_TensorHandle*>(at->t);
};
struct TF_AbstractOp {
// Needed to implement our own version of RTTI since dynamic_cast is not
// supported in mobile builds.
enum AbstractOpKind { GraphOp, EagerOp };
explicit TF_AbstractOp(AbstractOpKind kind) : k(kind) {}
AbstractOpKind getKind() const { return k; }
virtual void SetOpType(const char* const op_type, TF_Status* s) = 0;
virtual void SetOpName(const char* const op_name, TF_Status* s) = 0;
virtual void SetAttrType(const char* const attr_name, TF_DataType value,
TF_Status* s) = 0;
virtual ~TF_AbstractOp() {}
private:
const AbstractOpKind k;
};
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
return c->CreateOperation();
}
void TF_AbstractTensorSetGraphTensor(TF_AbstractTensor* at, TF_GraphTensor* t,
TF_Status* s) {
at->t = t;
}
TF_GraphTensor* TF_AbstractTensorGetGraphTensor(TF_AbstractTensor* at,
TF_Status* s) {
if (!absl::holds_alternative<TF_GraphTensor*>(at->t)) {
string msg = absl::StrCat("Not an graph tensor handle.");
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return nullptr;
void TF_DeleteAbstractOp(TF_AbstractOp* op) { delete op; }
class TF_GraphOp : public TF_AbstractOp {
public:
explicit TF_GraphOp(TF_Graph* g) : TF_AbstractOp(kKind), g_(g) {}
void SetOpType(const char* const op_type, TF_Status* s) override {
if (op_) {
TF_SetStatus(
s, TF_FAILED_PRECONDITION,
absl::StrCat("SetOpType called on already built op.").c_str());
return;
}
if (op_name_ != nullptr) {
op_.reset(TF_NewOperation(g_, op_type, op_name_));
op_name_ = nullptr;
} else {
op_type_ = op_type;
}
}
return absl::get<TF_GraphTensor*>(at->t);
}
void SetOpName(const char* const op_name, TF_Status* s) override {
if (op_) {
TF_SetStatus(
s, TF_FAILED_PRECONDITION,
absl::StrCat("SetOpName called on already built op.").c_str());
return;
}
if (op_type_ != nullptr) {
op_.reset(TF_NewOperation(g_, op_type_, op_name));
op_type_ = nullptr;
} else {
op_name_ = op_name;
}
}
void SetAttrType(const char* const attr_name, TF_DataType value,
TF_Status* s) override {
if (!op_) {
TF_SetStatus(
s, TF_FAILED_PRECONDITION,
"op_type and op_name must be specified before specifying attrs.");
return;
}
TF_SetAttrType(op_.get(), attr_name, value);
}
~TF_GraphOp() override {}
static constexpr AbstractOpKind kKind = GraphOp;
private:
friend class TF_GraphContext; // For access to op_.
TF_Graph* g_;
std::unique_ptr<TF_OperationDescription> op_;
// Hold `op_type` and `op_name` till both are available since we need both
// to build a graph operation.
const char* op_type_ = nullptr;
const char* op_name_ = nullptr;
};
class TF_EagerOp : public TF_AbstractOp {
public:
explicit TF_EagerOp(TFE_Context* ctx) : TF_AbstractOp(kKind), ctx_(ctx) {}
void SetOpType(const char* const op_type, TF_Status* s) override {
op_ = TFE_NewOp(ctx_, op_type, s);
}
void SetOpName(const char* const op_name, TF_Status* s) override {
// Name is ignored in eager mode.
}
void SetAttrType(const char* const attr_name, TF_DataType value,
TF_Status* s) override {
if (op_ == nullptr) {
TF_SetStatus(s, TF_FAILED_PRECONDITION,
"op_type must be specified before specifying attrs.");
return;
}
TFE_OpSetAttrType(op_, attr_name, value);
}
~TF_EagerOp() override { TFE_DeleteOp(op_); }
static constexpr AbstractOpKind kKind = EagerOp;
private:
friend class TF_EagerContext; // For access to op_.
TFE_Op* op_ = nullptr;
TFE_Context* ctx_;
};
bool IsEagerTensor(const TF_AbstractTensor* const t) {
return absl::holds_alternative<TFE_TensorHandle*>(t->t);
@ -138,6 +201,221 @@ struct TF_OutputList {
int expected_num_outputs = -1;
};
struct TF_AbstractFunction {
TF_Function* func = nullptr;
~TF_AbstractFunction() { TF_DeleteFunction(func); }
};
class TF_EagerContext : public TF_ExecutionContext {
public:
TF_EagerContext() : TF_ExecutionContext(kKind) {}
void Build(TFE_ContextOptions* options, TF_Status* status) {
eager_ctx_ = TFE_NewContext(options, status);
}
TF_AbstractOp* CreateOperation() override {
// TODO(srbs): Should the lifetime of this op be tied to the context.
return new TF_EagerOp(eager_ctx_);
}
void ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_Status* s) override {
auto* eager_op = dynamic_cast_helper<TF_EagerOp>(op);
if (eager_op == nullptr) {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Unable to cast TF_AbstractOp to TF_EagerOp.");
return;
}
auto* tfe_op = eager_op->op_;
if (TF_GetCode(s) != TF_OK) return;
for (int i = 0; i < num_inputs; ++i) {
if (!IsEagerTensor(inputs[i])) {
TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor.");
return;
}
TFE_OpAddInput(tfe_op, absl::get<TFE_TensorHandle*>(inputs[i]->t), s);
if (TF_GetCode(s) != TF_OK) return;
}
if (o->expected_num_outputs == -1) {
string msg =
"The number of outputs must be provided in eager mode. Use "
"TF_OutputListSetNumOutputs.";
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return;
}
tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals;
int num_retvals = o->expected_num_outputs;
retvals.resize(num_retvals);
TFE_Execute(tfe_op, retvals.data(), &num_retvals, s);
if (TF_GetCode(s) != TF_OK) {
return;
}
o->outputs.clear();
o->outputs.reserve(num_retvals);
for (int i = 0; i < num_retvals; ++i) {
auto* t = new TF_AbstractTensor();
t->t = retvals[i];
o->outputs.push_back(t);
}
}
void RegisterFunction(TF_AbstractFunction* func, TF_Status* s) override {
TFE_ContextAddFunction(eager_ctx_, func->func, s);
}
~TF_EagerContext() override { TFE_DeleteContext(eager_ctx_); }
static constexpr ExecutionContextKind kKind = EagerContext;
private:
friend TFE_Context* TF_ExecutionContextGetTFEContext(
TF_ExecutionContext* ctx);
TFE_Context* eager_ctx_;
};
void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { delete t; }
TF_GraphContext* GetGraphContext(TF_AbstractTensor const* t) {
return absl::get<TF_GraphTensor*>(t->t)->ctx;
}
class TF_GraphContext : public TF_ExecutionContext {
public:
TF_GraphContext()
: TF_ExecutionContext(kKind), graph_(new TF_Graph(), TF_DeleteGraph) {}
TF_AbstractOp* CreateOperation() override {
// TODO(srbs): Should the lifetime of this op be tied to the context.
return new TF_GraphOp(graph_.get());
}
void ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_Status* s) override {
auto* graph_op = dynamic_cast_helper<TF_GraphOp>(op);
if (graph_op == nullptr) {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Unable to cast TF_AbstractOp to TF_GraphOp.");
return;
}
auto* tf_opdesc = graph_op->op_.release();
for (int i = 0; i < num_inputs; ++i) {
auto* input = inputs[i];
if (IsEagerTensor(input)) {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Capturing eager tensors is not supported yet.");
return;
} else {
if (GetGraphContext(input) != this) {
TF_SetStatus(
s, TF_INVALID_ARGUMENT,
"Capturing tensors from other graphs is not supported yet.");
return;
}
TF_AddInput(tf_opdesc, absl::get<TF_GraphTensor*>(input->t)->output);
}
}
auto* operation = TF_FinishOperation(tf_opdesc, s);
// TF_FinishOperation deletes `tf_opdesc` so clear its reference.
graph_op->op_ = nullptr;
if (TF_GetCode(s) != TF_OK) return;
int num_outputs = TF_OperationNumOutputs(operation);
o->outputs.clear();
o->outputs.reserve(num_outputs);
for (int i = 0; i < num_outputs; ++i) {
auto* t = new TF_AbstractTensor;
TF_GraphTensor* graph_t = new TF_GraphTensor;
graph_t->ctx = this;
graph_t->output = {operation, i};
t->t = graph_t;
o->outputs.push_back(t);
}
}
TF_Function* ToFunction(const char* fn_name, int num_inputs,
const TF_AbstractTensor* inputs, int num_outputs,
const TF_AbstractTensor* outputs,
TF_Status* status) const {
std::vector<TF_Output> graph_inputs;
graph_inputs.resize(num_inputs);
std::vector<TF_Output> graph_outputs;
graph_outputs.resize(num_outputs);
for (int i = 0; i < num_inputs; i++) {
graph_inputs[i] = absl::get<TF_GraphTensor*>(inputs[i].t)->output;
}
for (int i = 0; i < num_outputs; i++) {
graph_outputs[i] = absl::get<TF_GraphTensor*>(outputs[i].t)->output;
}
return TF_GraphToFunction(graph_.get(), fn_name, 0, -1, nullptr,
graph_inputs.size(), graph_inputs.data(),
graph_outputs.size(), graph_outputs.data(),
nullptr, nullptr, fn_name, status);
}
void RegisterFunction(TF_AbstractFunction* func, TF_Status* s) override {
TF_SetStatus(s, TF_UNIMPLEMENTED,
"Registering graph functions has not been implemented yet.");
}
~TF_GraphContext() override {}
static constexpr ExecutionContextKind kKind = GraphContext;
private:
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
};
struct TF_GraphContextOptions {};
struct TF_EagerContextOptions {
explicit TF_EagerContextOptions(TFE_ContextOptions* options)
: options(options) {}
TFE_ContextOptions* options; // Not owned.
};
struct TF_ExecutionContextOptions {
absl::variant<TF_GraphContextOptions*, TF_EagerContextOptions*> options;
~TF_ExecutionContextOptions() {
if (absl::holds_alternative<TF_GraphContextOptions*>(options)) {
delete absl::get<TF_GraphContextOptions*>(options);
} else if (absl::holds_alternative<TF_EagerContextOptions*>(options)) {
delete absl::get<TF_EagerContextOptions*>(options);
}
}
};
TF_ExecutionContextOptions* TF_NewGraphContextOptions() {
auto* options = new TF_ExecutionContextOptions();
options->options = new TF_GraphContextOptions();
return options;
}
void TF_DeleteExecutionContextOptions(TF_ExecutionContextOptions* options) {
delete options;
}
TF_ExecutionContextOptions* TF_NewEagerContextOptions(
TFE_ContextOptions* tfe_options) {
auto* options = new TF_ExecutionContextOptions();
options->options = new TF_EagerContextOptions(tfe_options);
return options;
}
TF_ExecutionContext* TF_NewExecutionContext(TF_ExecutionContextOptions* options,
TF_Status* s) {
if (absl::holds_alternative<TF_EagerContextOptions*>(options->options)) {
auto* ctx = new TF_EagerContext();
ctx->Build(absl::get<TF_EagerContextOptions*>(options->options)->options,
s);
return ctx;
} else {
return new TF_GraphContext();
}
}
TF_OutputList* TF_NewOutputList() { return new TF_OutputList; }
void TF_DeleteOutputList(TF_OutputList* o) { delete o; }
void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs,
@ -149,113 +427,74 @@ TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) {
return o->outputs[i];
}
void ExecuteOperationEager(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_ExecutionContext* ctx, TF_Status* s) {
auto* tfe_op =
TFE_NewOp(absl::get<TFE_Context*>(ctx->ctx), op->op_type.c_str(), s);
if (TF_GetCode(s) != TF_OK) return;
for (int i = 0; i < num_inputs; ++i) {
if (!IsEagerTensor(inputs[i])) {
TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor.");
return;
}
TFE_OpAddInput(tfe_op, absl::get<TFE_TensorHandle*>(inputs[i]->t), s);
if (TF_GetCode(s) != TF_OK) return;
}
if (o->expected_num_outputs == -1) {
string msg =
"The number of outputs must be provided in eager mode. Use "
"TF_OutputListSetNumOutputs.";
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return;
}
tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals;
int num_retvals = o->expected_num_outputs;
retvals.resize(num_retvals);
TFE_Execute(tfe_op, retvals.data(), &num_retvals, s);
TFE_DeleteOp(tfe_op);
if (TF_GetCode(s) != TF_OK) {
return;
}
o->outputs.clear();
o->outputs.reserve(num_retvals);
for (int i = 0; i < num_retvals; ++i) {
auto* t = TF_NewAbstractTensor();
t->t = retvals[i];
o->outputs.push_back(t);
}
}
TF_GraphContext* GetGraphContext(TF_AbstractTensor const* t) {
return absl::get<TF_GraphTensor*>(t->t)->ctx;
}
void ExecuteOperationGraph(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_ExecutionContext* ctx, TF_Status* s) {
TF_GraphContext* graph_ctx = absl::get<TF_GraphContext*>(ctx->ctx);
TF_Graph* g = graph_ctx->graph;
auto* tf_opdesc =
TF_NewOperation(g, op->op_type.c_str(), op->op_name.c_str());
for (int i = 0; i < num_inputs; ++i) {
auto* input = inputs[i];
if (IsEagerTensor(input)) {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Capturing eager tensors is not supported yet.");
return;
} else {
if (GetGraphContext(input) != graph_ctx) {
TF_SetStatus(
s, TF_INVALID_ARGUMENT,
"Capturing tensors from other graphs is not supported yet.");
return;
}
TF_AddInput(tf_opdesc, absl::get<TF_GraphTensor*>(input->t)->output);
}
}
auto* operation = TF_FinishOperation(tf_opdesc, s);
if (TF_GetCode(s) != TF_OK) return;
int num_outputs = TF_OperationNumOutputs(operation);
o->outputs.clear();
o->outputs.reserve(num_outputs);
for (int i = 0; i < num_outputs; ++i) {
auto* t = TF_NewAbstractTensor();
TF_GraphTensor* output_t = TF_NewGraphTensor(graph_ctx, {operation, i}, s);
if (TF_GetCode(s) != TF_OK) {
return;
}
t->t = output_t;
o->outputs.push_back(t);
}
}
void TF_ExecutionContextSetEagerContext(TF_ExecutionContext* context,
TFE_Context* eager_context,
TF_Status* s) {
context->ctx = eager_context;
context->execution_callback = &ExecuteOperationEager;
}
void TF_ExecutionContextSetGraphContext(TF_ExecutionContext* context,
TF_GraphContext* graph_context,
TF_Status* s) {
context->ctx = graph_context;
context->execution_callback = &ExecuteOperationGraph;
}
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
TF_Status* s) {
op->op_type = op_type;
op->SetOpType(op_type, s);
}
void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
TF_Status* s) {
op->op_name = op_name;
op->SetOpName(op_name, s);
}
void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
TF_DataType value, TF_Status* s) {
op->SetAttrType(attr_name, value, s);
}
void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_ExecutionContext* ctx, TF_Status* s) {
ctx->execution_callback(op, num_inputs, inputs, o, ctx, s);
ctx->ExecuteOperation(op, num_inputs, inputs, o, s);
}
TF_AbstractFunction* TF_ExecutionContextToFunction(
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
const TF_AbstractTensor* inputs, int num_outputs,
const TF_AbstractTensor* outputs, TF_Status* status) {
auto* graph_ctx = dynamic_cast_helper<const TF_GraphContext>(fn_body);
if (graph_ctx == nullptr) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"fn_body is not a TF_GraphContext.");
return nullptr;
}
TF_AbstractFunction* func = new TF_AbstractFunction;
func->func = graph_ctx->ToFunction(fn_name, num_inputs, inputs, num_outputs,
outputs, status);
return func;
}
void TF_DeleteAbstractFunction(TF_AbstractFunction* func) { delete func; }
void TF_ExecutionContextRegisterFunction(TF_ExecutionContext* ctx,
TF_AbstractFunction* func,
TF_Status* s) {
ctx->RegisterFunction(func, s);
}
// Temporary APIs till we figure out how to create scalar valued Eager
// tensors and how to get value out of eager abstract tensors.
TF_AbstractTensor* TF_NewAbstractTensor() {
TF_AbstractTensor* t = new TF_AbstractTensor;
return t;
}
void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t,
TF_Status* s) {
at->t = t;
}
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
TF_Status* s) {
if (!absl::holds_alternative<TFE_TensorHandle*>(at->t)) {
string msg = absl::StrCat("Not an eager tensor handle.",
reinterpret_cast<uintptr_t>(at));
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return nullptr;
}
return absl::get<TFE_TensorHandle*>(at->t);
}
TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext* ctx) {
return dynamic_cast_helper<TF_EagerContext>(ctx)->eager_ctx_;
}


@ -15,8 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_
#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/tf_status.h"
#ifdef __cplusplus
extern "C" {
@ -41,32 +41,19 @@ typedef struct TF_AbstractTensor TF_AbstractTensor;
// could contain the op type and other attributes.
typedef struct TF_AbstractOp TF_AbstractOp;
TF_ExecutionContext* TF_NewExecutionContext();
// `TF_ExecutionContextOptions` defines what type of `TF_ExecutionContext` is
// created. It can be used to pass context-specific params.
typedef struct TF_ExecutionContextOptions TF_ExecutionContextOptions;
void TF_DeleteExecutionContextOptions(TF_ExecutionContextOptions*);
TF_ExecutionContext* TF_NewExecutionContext(TF_ExecutionContextOptions*,
TF_Status* s);
void TF_DeleteExecutionContext(TF_ExecutionContext*);
TF_AbstractOp* TF_NewAbstractOp();
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* ctx);
void TF_DeleteAbstractOp(TF_AbstractOp*);
TF_AbstractTensor* TF_NewAbstractTensor();
void TF_DeleteAbstractTensor(TF_AbstractTensor*);
// -----------------------------------------------------------------------------
// APIs for Eager and graph modes
// -----------------------------------------------------------------------------
// Keeps track of the current graph and other state e.g. captures etc.
typedef struct TF_GraphContext TF_GraphContext;
TF_GraphContext* TF_NewGraphContext(TF_Graph*);
void TF_DeleteGraphContext(TF_GraphContext*);
// `eager_context` must outlive `context`.
void TF_ExecutionContextSetEagerContext(TF_ExecutionContext* context,
TFE_Context* eager_context, TF_Status*);
// `graph_context` must outlive `context`.
void TF_ExecutionContextSetGraphContext(TF_ExecutionContext* context,
TF_GraphContext* graph_context,
TF_Status*);
// TODO(srbs): Add APIs for specifying attrs etc.
// `op_type` must outlive `op`.
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
@ -74,25 +61,9 @@ void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
// `op_name` must outlive `op`.
void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
TF_Status* s);
// Wrapper for TF_Output but contains a pointer to TF_GraphContext as well.
typedef struct TF_GraphTensor TF_GraphTensor;
TF_GraphTensor* TF_NewGraphTensor(TF_GraphContext* c, TF_Output t,
TF_Status* s);
TF_Output TF_GraphTensorToOutput(const TF_GraphTensor* const t, TF_Status* s);
void TF_DeleteGraphTensor(TF_GraphTensor* t);
// `t` must outlive `at`.
void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t,
TF_Status* s);
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
TF_Status* s);
// `t` must outlive `at`.
void TF_AbstractTensorSetGraphTensor(TF_AbstractTensor* at, TF_GraphTensor* t,
TF_Status* s);
TF_GraphTensor* TF_AbstractTensorGetGraphTensor(TF_AbstractTensor* at,
TF_Status* s);
// `attr_name` must outlive `op`.
void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
TF_DataType value, TF_Status* s);
// TF_OutputList just lets us not specify the number of outputs of an operation
// beforehand. This forces a memory allocation in the runtime, which is bad, but
@ -104,6 +75,17 @@ void TF_OutputListSetNumOutputs(TF_OutputList* o, int, TF_Status*);
int TF_OutputListNumOutputs(TF_OutputList* o);
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i);
// Stores a function representation that can be used for execution or for
// setting functional attributes of other composite ops e.g. control flow.
typedef struct TF_AbstractFunction TF_AbstractFunction;
TF_AbstractFunction* TF_ExecutionContextToFunction(
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
const TF_AbstractTensor* inputs, int num_outputs,
const TF_AbstractTensor* outputs, TF_Status* status);
void TF_DeleteAbstractFunction(TF_AbstractFunction*);
void TF_ExecutionContextRegisterFunction(TF_ExecutionContext*,
TF_AbstractFunction*, TF_Status*);
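The following sketch, condensed from the TestBasicGraph flow further below, shows the intended round trip: trace ops in a graph execution context, package the trace as a TF_AbstractFunction, and register it with an eager context. Error checks and cleanup are elided, `s` is assumed to be a valid TF_Status*, and the op/function names are arbitrary.
// Trace a Placeholder followed by an Add in a graph execution context.
TF_ExecutionContextOptions* g_opts = TF_NewGraphContextOptions();
TF_ExecutionContext* graph_ctx = TF_NewExecutionContext(g_opts, s);
TF_AbstractOp* ph = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(ph, "Placeholder", s);
TF_AbstractOpSetOpName(ph, "x", s);
TF_AbstractOpSetAttrType(ph, "dtype", TF_FLOAT, s);
TF_OutputList* ph_outs = TF_NewOutputList();
TF_ExecuteOperation(ph, 0, nullptr, ph_outs, graph_ctx, s);
TF_AbstractTensor* x = TF_OutputListGet(ph_outs, 0);
TF_AbstractOp* add = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(add, "Add", s);
TF_AbstractOpSetOpName(add, "add", s);
TF_AbstractTensor* inputs[2] = {x, x};
TF_OutputList* add_outs = TF_NewOutputList();
TF_ExecuteOperation(add, 2, inputs, add_outs, graph_ctx, s);
TF_AbstractTensor* y = TF_OutputListGet(add_outs, 0);
// Package the trace as a function and hand it to an eager context.
TF_AbstractFunction* fn =
    TF_ExecutionContextToFunction(graph_ctx, "double", 1, x, 1, y, s);
TF_ExecutionContextRegisterFunction(eager_exec_ctx, fn, s);  // eager context built elsewhere
TF_DeleteAbstractFunction(fn);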
// TF_ExecuteOperation will, if in eager mode, execute, if in graph mode, maybe
// capture some inputs and then add a node in the graph, and after
// execution/node creation it'll go and record things that happened in any tape
@ -112,6 +94,23 @@ void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_ExecutionContext* ctx, TF_Status* s);
// -----------------------------------------------------------------------------
// APIs specific to Eager and graph modes
// -----------------------------------------------------------------------------
TF_ExecutionContextOptions* TF_NewGraphContextOptions();
TF_ExecutionContextOptions* TF_NewEagerContextOptions(TFE_ContextOptions*);
// Temporary APIs till we figure out how to create scalar valued Eager
// tensors and how to get value out of eager abstract tensors.
TF_AbstractTensor* TF_NewAbstractTensor();
void TF_AbstractTensorSetEagerTensor(
TF_AbstractTensor* at, TFE_TensorHandle* t,
TF_Status* s); // `at` takes ownership of `t`.
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
TF_Status* s);
TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext*);
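For the eager path, the declarations above combine roughly as follows, condensed from the TestBasicEager test. Error checks and final cleanup are elided, and MakeScalarHandle stands in for however the caller obtains a TFE_TensorHandle (the tests use TestScalarTensorHandle).
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
TF_Status* s = TF_NewStatus();
TFE_ContextOptions* tfe_opts = TFE_NewContextOptions();
TF_ExecutionContextOptions* opts = TF_NewEagerContextOptions(tfe_opts);
TF_ExecutionContext* ctx = TF_NewExecutionContext(opts, s);
TFE_DeleteContextOptions(tfe_opts);
// Wrap an eager tensor handle; `at` takes ownership of `t`.
TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx);
TFE_TensorHandle* t = MakeScalarHandle(eager_ctx, 2.0f);
TF_AbstractTensor* at = TF_NewAbstractTensor();
TF_AbstractTensorSetEagerTensor(at, t, s);
// Build and run an Add; eager mode requires the output count up front.
TF_AbstractOp* add = TF_NewAbstractOp(ctx);
TF_AbstractOpSetOpType(add, "Add", s);
TF_AbstractTensor* inputs[2] = {at, at};
TF_OutputList* o = TF_NewOutputList();
TF_OutputListSetNumOutputs(o, 1, s);
TF_ExecuteOperation(add, 2, inputs, o, ctx, s);
TFE_TensorHandle* result =
    TF_AbstractTensorGetEagerTensor(TF_OutputListGet(o, 0), s);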
#ifdef __cplusplus
} /* end extern "C" */
#endif


@ -33,26 +33,25 @@ namespace tensorflow {
namespace {
TEST(UnifedCAPI, TestBasicEager) {
TF_ExecutionContext* ctx = TF_NewExecutionContext();
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* eager_ctx = TFE_NewContext(opts, status.get());
TF_ExecutionContextOptions* options = TF_NewEagerContextOptions(opts);
TF_ExecutionContext* ctx = TF_NewExecutionContext(options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
// Enter the eager context.
TF_ExecutionContextSetEagerContext(ctx, eager_ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract input tensor.
TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx);
TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f);
TF_AbstractTensor* at = TF_NewAbstractTensor();
TF_AbstractTensorSetEagerTensor(at, t, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract operation.
auto* op = TF_NewAbstractOp();
auto* op = TF_NewAbstractOp(ctx);
TF_AbstractOpSetOpType(op, "Add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
@ -69,7 +68,6 @@ TEST(UnifedCAPI, TestBasicEager) {
// Clean up operation and inputs.
TF_DeleteAbstractOp(op);
TF_DeleteAbstractTensor(at);
TFE_DeleteTensorHandle(t);
// Verify the results.
ASSERT_EQ(1, TF_OutputListNumOutputs(o));
@ -83,100 +81,98 @@ TEST(UnifedCAPI, TestBasicEager) {
TF_DeleteTensor(result_tensor);
TF_DeleteAbstractTensor(result);
TFE_DeleteTensorHandle(result_t);
TF_DeleteOutputList(o);
TFE_DeleteContext(eager_ctx);
TF_DeleteExecutionContext(ctx);
TF_DeleteExecutionContextOptions(options);
}
TEST(UnifedCAPI, TestBasicGraph) {
TF_ExecutionContext* ctx = TF_NewExecutionContext();
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
// Enter a graph context.
TF_Graph* g = TF_NewGraph();
TF_GraphContext* graph_context = TF_NewGraphContext(g);
TF_ExecutionContextSetGraphContext(ctx, graph_context, status.get());
TF_ExecutionContextOptions* options = TF_NewGraphContextOptions();
TF_ExecutionContext* graph_ctx =
TF_NewExecutionContext(options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph.
auto* placeholder_op = TF_NewOperation(g, "Placeholder", "Placeholder");
TF_SetAttrType(placeholder_op, "dtype", TF_FLOAT);
auto* operation = TF_FinishOperation(placeholder_op, status.get());
auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_Output placeholder_t = {operation, 0};
TF_GraphTensor* graph_t =
TF_NewGraphTensor(graph_context, placeholder_t, status.get());
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractTensor* t = TF_NewAbstractTensor();
TF_AbstractTensorSetGraphTensor(t, graph_t, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract operation.
auto* op = TF_NewAbstractOp();
TF_AbstractOpSetOpType(op, "Add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(op, "my_add", status.get());
TF_AbstractOpSetAttrType(placeholder_op, "dtype", TF_FLOAT, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {t, t};
TF_OutputList* o = TF_NewOutputList();
TF_OutputList* placeholder_outputs = TF_NewOutputList();
// Execute.
TF_ExecuteOperation(op, 2, inputs, o, ctx, status.get());
TF_ExecuteOperation(placeholder_op, 0, nullptr, placeholder_outputs,
graph_ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(1, TF_OutputListNumOutputs(placeholder_outputs));
TF_AbstractTensor* placeholder_t = TF_OutputListGet(placeholder_outputs, 0);
// Delete placeholder op.
TF_DeleteAbstractOp(placeholder_op);
// Build an abstract operation.
auto* add_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(add_op, "Add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(add_op, "my_add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {placeholder_t, placeholder_t};
TF_OutputList* add_outputs = TF_NewOutputList();
// Execute.
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractTensor* output_t = TF_OutputListGet(add_outputs, 0);
// Clean up operation and inputs.
TF_DeleteAbstractOp(op);
TF_DeleteAbstractTensor(t);
TF_DeleteGraphTensor(graph_t);
TF_DeleteAbstractOp(add_op);
TF_AbstractTensor* result = TF_OutputListGet(o, 0);
TF_GraphTensor* result_graph_tensor =
TF_AbstractTensorGetGraphTensor(result, status.get());
TF_DeleteAbstractTensor(result);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_Output result_output =
TF_GraphTensorToOutput(result_graph_tensor, status.get());
TF_DeleteGraphTensor(result_graph_tensor);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
string fn_name = "double";
TF_Function* f = TF_GraphToFunction(
g, fn_name.c_str(), 0, -1, nullptr, 1, &placeholder_t, 1, &result_output,
nullptr, nullptr, fn_name.c_str(), status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractFunction* func = TF_ExecutionContextToFunction(
graph_ctx, fn_name.c_str(), 1, placeholder_t, 1, output_t, status.get());
TF_DeleteAbstractTensor(placeholder_t);
TF_DeleteAbstractTensor(output_t);
// Build an eager context to run the function.
// Build eager context.
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* eager_ctx = TFE_NewContext(opts, status.get());
TF_ExecutionContextOptions* eager_ctx_options =
TF_NewEagerContextOptions(opts);
TF_ExecutionContext* eager_execution_ctx =
TF_NewExecutionContext(eager_ctx_options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
// Build the abstract op to run the function.
TFE_ContextAddFunction(eager_ctx, f, status.get());
TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOp* fn_op = TF_NewAbstractOp();
// Build the abstract op to run the function.
TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx);
TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract input tensor.
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f);
TF_AbstractTensor* input_t = TF_NewAbstractTensor();
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(eager_execution_ctx);
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f);
TF_AbstractTensorSetEagerTensor(input_t, input_eager, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Enter the eager context.
TF_ExecutionContextSetEagerContext(ctx, eager_ctx, status.get());
TF_OutputListSetNumOutputs(add_outputs, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_OutputListSetNumOutputs(o, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_ExecuteOperation(fn_op, 1, &input_t, o, ctx, status.get());
TF_ExecuteOperation(fn_op, 1, &input_t, add_outputs, eager_execution_ctx,
status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(1, TF_OutputListNumOutputs(o));
TF_AbstractTensor* final_result = TF_OutputListGet(o, 0);
ASSERT_EQ(1, TF_OutputListNumOutputs(add_outputs));
TF_AbstractTensor* final_result = TF_OutputListGet(add_outputs, 0);
TFE_TensorHandle* final =
TF_AbstractTensorGetEagerTensor(final_result, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
@ -185,19 +181,201 @@ TEST(UnifedCAPI, TestBasicGraph) {
float* f_value = static_cast<float*>(TF_TensorData(f_t));
ASSERT_EQ(*f_value, 4.0);
TF_DeleteOutputList(o);
TF_DeleteOutputList(add_outputs);
TF_DeleteOutputList(placeholder_outputs);
TF_DeleteAbstractOp(fn_op);
TF_DeleteAbstractTensor(input_t);
TFE_DeleteTensorHandle(input_eager);
TF_DeleteAbstractTensor(final_result);
TFE_DeleteTensorHandle(final);
TF_DeleteTensor(f_t);
TF_DeleteFunction(f);
TF_DeleteAbstractFunction(func);
TF_DeleteExecutionContext(graph_ctx);
TF_DeleteExecutionContext(eager_execution_ctx);
TF_DeleteExecutionContextOptions(eager_ctx_options);
TF_DeleteExecutionContextOptions(options);
}
TEST(UnifedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContextOptions* options = TF_NewEagerContextOptions(opts);
TF_ExecutionContext* ctx = TF_NewExecutionContext(options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
TF_AbstractFunction* func = TF_ExecutionContextToFunction(
ctx, nullptr, 0, nullptr, 0, nullptr, status.get());
ASSERT_EQ(nullptr, func);
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
TF_DeleteGraphContext(graph_context);
TF_DeleteGraph(g);
TFE_DeleteContext(eager_ctx);
TF_DeleteExecutionContext(ctx);
TF_DeleteExecutionContextOptions(options);
}
TEST(UnifedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContextOptions* options = TF_NewGraphContextOptions();
TF_ExecutionContext* graph_ctx =
TF_NewExecutionContext(options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph.
auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// This should fail.
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
ASSERT_EQ(TF_FAILED_PRECONDITION, TF_GetCode(status.get()));
TF_DeleteAbstractOp(placeholder_op);
TF_DeleteExecutionContext(graph_ctx);
TF_DeleteExecutionContextOptions(options);
}
TEST(UnifedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContextOptions* options = TF_NewGraphContextOptions();
TF_ExecutionContext* graph_ctx =
TF_NewExecutionContext(options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph.
auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// This should fail.
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
ASSERT_EQ(TF_FAILED_PRECONDITION, TF_GetCode(status.get()));
TF_DeleteAbstractOp(placeholder_op);
TF_DeleteExecutionContext(graph_ctx);
TF_DeleteExecutionContextOptions(options);
}
TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) {
// Build an Eager context.
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContextOptions* options = TF_NewEagerContextOptions(opts);
TF_ExecutionContext* ctx = TF_NewExecutionContext(options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an Eager operation.
auto* op = TF_NewAbstractOp(ctx);
TF_AbstractOpSetOpType(op, "Add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract input tensor.
TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx);
TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f);
TF_AbstractTensor* at = TF_NewAbstractTensor();
TF_AbstractTensorSetEagerTensor(at, t, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {at, at};
TF_OutputList* o = TF_NewOutputList();
TF_OutputListSetNumOutputs(o, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build a Graph context.
TF_ExecutionContextOptions* graph_options = TF_NewGraphContextOptions();
TF_ExecutionContext* graph_ctx =
TF_NewExecutionContext(graph_options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Execute eager op using graph context.
TF_ExecuteOperation(op, 2, inputs, o, graph_ctx, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
// Clean up operation and inputs.
TF_DeleteAbstractOp(op);
TF_DeleteAbstractTensor(at);
TF_DeleteOutputList(o);
TF_DeleteExecutionContext(ctx);
TF_DeleteExecutionContextOptions(options);
TF_DeleteExecutionContext(graph_ctx);
TF_DeleteExecutionContextOptions(graph_options);
}
TEST(UnifedCAPI, TestExecutingGraphOpInEagerModeRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContextOptions* options = TF_NewGraphContextOptions();
TF_ExecutionContext* graph_ctx =
TF_NewExecutionContext(options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph.
auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetAttrType(placeholder_op, "dtype", TF_FLOAT, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_OutputList* placeholder_outputs = TF_NewOutputList();
// Execute.
TF_ExecuteOperation(placeholder_op, 0, nullptr, placeholder_outputs,
graph_ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(1, TF_OutputListNumOutputs(placeholder_outputs));
TF_AbstractTensor* placeholder_t = TF_OutputListGet(placeholder_outputs, 0);
// Delete placeholder op.
TF_DeleteAbstractOp(placeholder_op);
// Build an abstract operation.
auto* add_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(add_op, "Add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(add_op, "my_add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {placeholder_t, placeholder_t};
TF_OutputList* add_outputs = TF_NewOutputList();
// Build eager context.
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContextOptions* eager_ctx_options =
TF_NewEagerContextOptions(opts);
TF_ExecutionContext* eager_execution_ctx =
TF_NewExecutionContext(eager_ctx_options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
// Execute.
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, eager_execution_ctx,
status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
// Clean up operation and inputs.
TF_DeleteAbstractTensor(placeholder_t);
TF_DeleteAbstractOp(add_op);
TF_DeleteOutputList(add_outputs);
TF_DeleteOutputList(placeholder_outputs);
TF_DeleteExecutionContext(graph_ctx);
TF_DeleteExecutionContext(eager_execution_ctx);
TF_DeleteExecutionContextOptions(eager_ctx_options);
TF_DeleteExecutionContextOptions(options);
}
} // namespace
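
The tests above all follow the same skeleton: a TF_Status owned by a std::unique_ptr with TF_DeleteStatus as its deleter, an execution context built from context options, and an abstract op whose type and name are each set exactly once before execution or deletion. A condensed sketch of that skeleton is shown below; it relies on the same headers as the test file, and the wrapper function name is purely illustrative, not part of this change.

// Illustrative sketch of the setup pattern shared by the tests above.
void GraphSetupSketch() {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TF_ExecutionContextOptions* options = TF_NewGraphContextOptions();
  TF_ExecutionContext* graph_ctx =
      TF_NewExecutionContext(options, status.get());
  if (TF_GetCode(status.get()) == TF_OK) {
    auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
    TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
    TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
    // Calling either setter a second time returns TF_FAILED_PRECONDITION,
    // which is exactly what the tests above assert.
    TF_DeleteAbstractOp(placeholder_op);
    TF_DeleteExecutionContext(graph_ctx);
  }
  TF_DeleteExecutionContextOptions(options);
}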

View File

@ -16,134 +16,16 @@ limitations under the License.
// A simple logging device to test custom device registration.
#include <memory>
#include "absl/strings/match.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/custom_device_testutil.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/test.h"
TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
@ -276,9 +158,7 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
tensorflow::string(
TFE_TensorHandleBackingDeviceName(var_value, status.get())));
TFE_TensorHandle* var_value_unpacked =
UnpackTensorHandle(var_value, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> resolved_value(
TFE_TensorHandleResolve(var_value_unpacked, status.get()),
@ -296,7 +176,7 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
TEST(CUSTOM_DEVICE, AccessVariableOnCustomDevice) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
@ -346,16 +226,21 @@ TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
// Read the variable's value.
op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpAddInput(op.get(), var_handle, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
executed = false;
num_retvals = 1;
TFE_TensorHandle* var_value = nullptr;
TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
ASSERT_EQ(
tensorflow::string(name),
tensorflow::string(TFE_TensorHandleDeviceName(var_value, status.get())));
TFE_DeleteTensorHandle(var_value);
// Free the backing buffer for the variable.
op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
TFE_OpAddInput(op.get(), var_handle, status.get());
@ -366,6 +251,79 @@ TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
TEST(CUSTOM_DEVICE, InputBasedPlacement) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const char* custom0 = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
const char* custom1 = "/job:localhost/replica:0/task:0/device:CUSTOM:1";
bool arrived = false;
bool executed = false;
RegisterLoggingDevice(context.get(), custom0, &arrived, &executed,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
RegisterLoggingDevice(context.get(), custom1, &arrived, &executed,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcpu(
TestMatrixTensorHandle(context.get()), TFE_DeleteTensorHandle);
ASSERT_FALSE(arrived);
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcustom0(
TFE_TensorHandleCopyToDevice(hcpu.get(), context.get(), custom0,
status.get()),
TFE_DeleteTensorHandle);
ASSERT_TRUE(arrived);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
arrived = false;
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcustom1(
TFE_TensorHandleCopyToDevice(hcpu.get(), context.get(), custom1,
status.get()),
TFE_DeleteTensorHandle);
ASSERT_TRUE(arrived);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Base case: two CPU inputs executes fine.
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> matmul(
MatMulOp(context.get(), hcpu.get(), hcpu.get()), TFE_DeleteOp);
TFE_TensorHandle* retval;
int num_retvals = 1;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_DeleteTensorHandle(retval);
// Custom device: inputs in same custom device works.
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcustom0.get()));
num_retvals = 1;
executed = false;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
TFE_DeleteTensorHandle(retval);
// Custom device: inputs in different custom devices fails.
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcustom1.get()));
num_retvals = 1;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
ASSERT_NE(TF_OK, TF_GetCode(status.get()));
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom1));
// Custom device: mix of custom/physical fails.
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get()));
num_retvals = 1;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
ASSERT_NE(TF_OK, TF_GetCode(status.get()));
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
ASSERT_TRUE(
absl::StrContains(TF_Message(status.get()), "[]")); // kVariantDeviceNull
}
TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
@ -394,5 +352,3 @@ TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
<< TF_Message(status.get());
}

View File

@ -0,0 +1,172 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// A simple logging device to test custom device registration.
#include <memory>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/test.h"
namespace {
struct LoggingDevice {
tensorflow::string device_name;
tensorflow::string underlying_device;
// Set to true whenever a TensorHandle is copied onto the device
bool* arrived_flag;
// Set to true whenever an operation is executed
bool* executed_flag;
};
struct LoggedTensor {
TFE_TensorHandle* tensor;
LoggedTensor() = delete;
explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
};
void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
delete reinterpret_cast<LoggedTensor*>(data);
}
TFE_TensorHandle* MakeLoggedTensorHandle(
TFE_Context* context, const tensorflow::string& logging_device_name,
std::unique_ptr<LoggedTensor> t, TF_Status* status) {
std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
if (TF_GetCode(status) != TF_OK) return nullptr;
for (int i = 0; i < shape.size(); ++i) {
shape[i] = TFE_TensorHandleDim(t->tensor, i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
auto dtype = TFE_TensorHandleDataType(t->tensor);
return TFE_NewTensorHandleFromDeviceMemory(
context, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
}
TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status, void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
tensor, context, dev->underlying_device.c_str(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
auto dst = std::make_unique<LoggedTensor>(t);
*(dev->arrived_flag) = true;
return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst),
status);
}
TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
const char* target_device_name,
TF_Status* status,
void* device_info) {
TF_SetStatus(status, TF_INTERNAL,
"Trying to copy a tensor out of a logging device.");
return nullptr;
}
void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
TFE_TensorHandle** inputs, const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s,
void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_Op* op(TFE_NewOp(context, operation_name, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddAttrs(op, attributes);
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
for (int j = 0; j < num_inputs; ++j) {
TFE_TensorHandle* input = inputs[j];
const char* input_device = TFE_TensorHandleDeviceName(input, s);
if (TF_GetCode(s) != TF_OK) return;
if (dev->device_name == input_device) {
LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
TFE_TensorHandleDevicePointer(input, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddInput(op, t->tensor, s);
} else {
TFE_OpAddInput(op, input, s);
}
if (TF_GetCode(s) != TF_OK) return;
}
std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
TFE_Execute(op, op_outputs.data(), num_outputs, s);
TFE_DeleteOp(op);
if (TF_GetCode(s) != TF_OK) return;
std::vector<TFE_TensorHandle*> unwrapped_outputs;
for (auto* handle : op_outputs) {
unwrapped_outputs.push_back(handle);
}
for (int i = 0; i < *num_outputs; ++i) {
auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
outputs[i] = MakeLoggedTensorHandle(context, dev->device_name,
std::move(logged_tensor), s);
}
*(dev->executed_flag) = true;
}
void DeleteLoggingDevice(void* device_info) {
delete reinterpret_cast<LoggingDevice*>(device_info);
}
} // namespace
void RegisterLoggingDevice(TFE_Context* context, const char* name,
bool* arrived_flag, bool* executed_flag,
TF_Status* status) {
TFE_CustomDevice custom_device;
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
custom_device.delete_device = &DeleteLoggingDevice;
custom_device.execute = &LoggingDeviceExecute;
LoggingDevice* device = new LoggingDevice;
device->arrived_flag = arrived_flag;
device->executed_flag = executed_flag;
device->device_name = name;
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
TFE_RegisterCustomDevice(context, custom_device, name, device, status);
}
TFE_TensorHandle* UnpackTensorHandle(TFE_TensorHandle* logged_tensor_handle,
TF_Status* status) {
return reinterpret_cast<LoggedTensor*>(
TFE_TensorHandleDevicePointer(logged_tensor_handle, status))
->tensor;
}
void AllocateLoggingDevice(const char* name, bool* arrived_flag,
bool* executed_flag, TFE_CustomDevice** device,
void** device_info) {
TFE_CustomDevice* custom_device = new TFE_CustomDevice;
custom_device->copy_tensor_to_device = &CopyToLoggingDevice;
custom_device->copy_tensor_from_device = &CopyTensorFromLoggingDevice;
custom_device->delete_device = &DeleteLoggingDevice;
custom_device->execute = &LoggingDeviceExecute;
*device = custom_device;
LoggingDevice* logging_device = new LoggingDevice;
logging_device->arrived_flag = arrived_flag;
logging_device->executed_flag = executed_flag;
logging_device->device_name = name;
logging_device->underlying_device =
"/job:localhost/replica:0/task:0/device:CPU:0";
*device_info = reinterpret_cast<void*>(logging_device);
}

View File

@ -0,0 +1,36 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_
#define TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_
// A simple logging device to test custom device registration.
#include <memory>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_status.h"
void RegisterLoggingDevice(TFE_Context* context, const char* name,
bool* arrived_flag, bool* executed_flag,
TF_Status* status);
void AllocateLoggingDevice(const char* name, bool* arrived_flag,
bool* executed_flag, TFE_CustomDevice** device,
void** device_info);
TFE_TensorHandle* UnpackTensorHandle(TFE_TensorHandle* logged_tensor_handle,
TF_Status* status);
#endif // TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_
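
These helpers are exercised by the custom device tests earlier in this change. As a rough usage sketch (the wrapper function name below is hypothetical, and per-call error handling is reduced to a single status check):

// Illustrative only: register a logging device and observe its flags.
void LoggingDeviceSketch(TFE_Context* context, TF_Status* status) {
  const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  bool arrived = false;
  bool executed = false;
  RegisterLoggingDevice(context, name, &arrived, &executed, status);
  if (TF_GetCode(status) != TF_OK) return;
  // `arrived` flips to true once a tensor is copied onto the device, and
  // `executed` flips to true once an operation runs on it. UnpackTensorHandle
  // recovers the wrapped TFE_TensorHandle from a handle placed on the device.
}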

View File

@ -42,7 +42,28 @@ class AbstractOperationInterface {
virtual Status Reset(const char* op, const char* raw_device_name) = 0;
virtual const string& Name() const = 0;
// Returns the operation's device name.
//
// The value returned may be different from the one set by SetDeviceName, but
// it will be compatible with it: the name will be updated by device placement
// logic to refer to the specific device chosen.
//
// Example: If one calls `op->SetDeviceName("/device:GPU")`, the value
// returned by DeviceName should be "/device:GPU:*" until a particular GPU is
// chosen for the operation by the device placement logic in the
// executor. After that, the value returned by DeviceName will be a full
// device name such as "/job:localhost/replica:0/task:0/device:GPU:1".
virtual const string& DeviceName() const = 0;
// Sets the operation device name.
//
// The given `name` must be parseable by DeviceNameUtils::ParseFullName, and
// the result will be used as a constraint for device placement. See the
// documentation for DeviceName for more details.
//
// The value will override the previous value - that is, no "merging" of
// existing and given constraints will be performed.
virtual Status SetDeviceName(const char* name) = 0;
virtual Status AddInput(AbstractTensorHandleInterface* input) = 0;

View File

@ -0,0 +1,38 @@
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
)
package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "parallel_device",
srcs = ["parallel_device.cc"],
hdrs = ["parallel_device.h"],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:variant",
],
)
tf_cc_test(
name = "parallel_device_test",
srcs = ["parallel_device_test.cc"],
deps = [
":parallel_device",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)

View File

@ -0,0 +1,597 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/parallel_device/parallel_device.h"
#include <memory>
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
namespace tensorflow {
namespace eager {
namespace {
// Functor for making unique_ptrs slightly more ergonomic. Using
// decltype(delete_fn) in the unique_ptr's second template argument requires
// passing a function pointer to delete_fn when constructing the unique_ptr.
class TensorHandleDeleter {
public:
void operator()(TFE_TensorHandle* to_delete) const {
TFE_DeleteTensorHandle(to_delete);
}
};
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
class OpDeleter {
public:
void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }
};
using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;
class ExecutorDeleter {
public:
void operator()(TFE_Executor* to_delete) const {
TFE_DeleteExecutor(to_delete);
}
};
using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
class ParallelTensor;
using MaybeParallelTensorOwned =
absl::variant<std::unique_ptr<ParallelTensor>, TensorHandlePtr>;
using MaybeParallelTensorUnowned =
absl::variant<ParallelTensor*, TFE_TensorHandle*>;
// Creates a vector of `count` new executors (threads).
std::vector<ExecutorPtr> MakeExecutors(size_t count) {
std::vector<ExecutorPtr> executors;
executors.reserve(count);
for (int i = 0; i < count; ++i) {
executors.emplace_back(TFE_NewExecutor(true /* is_async */));
}
return executors;
}
// A representation of the custom device passed in and out of the TFE custom
// device APIs, providing context about the parallel device to
// ParallelDeviceExecute.
class ParallelDevice {
public:
ParallelDevice(const std::string& name,
const std::vector<std::string>& devices);
// Helper to copy a tensor handle from another device once for each component
// of the ParallelDevice.
//
// Sets a bad status and returns a nullptr if `tensor` is already on the
// ParallelDevice, or if the individual copies fail.
std::unique_ptr<ParallelTensor> CopyToParallelDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status) const;
// Takes a description of a single operation being executed on the
// ParallelDevice, and in turn runs one operation per component device with
// its corresponding inputs from the input ParallelTensors (or
// implicitly-mirrored tensors on other devices). Wraps the resulting
// per-device and per-output TFE_TensorHandles into one ParallelTensor per
// output of the original operation.
//
// `inputs` are either ParallelTensors, i.e. already on the ParallelDevice, or
// un-replicated TFE_TensorHandles on other devices. TPUReplicatedInput
// requires non-parallel tensors, and TPUReplicatedOutput requires a parallel
// tensor, but other operations will implicitly broadcast non-parallel input
// tensors across the ParallelDevice's component devices.
//
// Two special-cased operations, TPUReplicatedInput and TPUReplicatedOutput,
// pack and un-pack parallel tensors respectively. Only TPUReplicatedOutput
// causes `Execute` to return non-parallel tensors.
//
// Attributes are forwarded to executed operations unmodified.
//
// The returned optional has a value if and only if `status` evaluates to
// TF_OK.
absl::optional<std::vector<MaybeParallelTensorOwned>> Execute(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const;
// Implements the parallel case for `Execute`, where all of the outputs of the
// operation are ParallelTensors, and all inputs are either ParallelTensors or
// should be implicitly broadcast. This means the operation is not
// TPUReplicatedInput or TPUReplicatedOutput.
//
// The returned optional has a value if and only if `status` evaluates to
// TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
// if sanity checks on dtypes/metadata fail.
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ExecuteParallelOperation(TFE_Context* context,
std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name,
const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const;
const std::string& device_name() const { return device_name_; }
private:
// The name of the parallel device
// (e.g. "/job:localhost/replica:0/task:0/device:CUSTOM:0")
const std::string device_name_;
// A sequence of device names, indicating which devices replicated operations
// are forwarded to.
const std::vector<std::string> underlying_devices_;
// A sequence of TFE_Executors, one per device, for executing operations in
// parallel.
const std::vector<ExecutorPtr> executors_;
};
// The internal representation of a TFE_TensorHandle placed on a
// ParallelDevice. Contains a tuple of tensors, one on each of the
// `underlying_devices_` of the ParallelDevice.
class ParallelTensor {
public:
// Construct a ParallelTensor from TensorHandles placed on the component
// devices of a ParallelDevice.
static std::unique_ptr<ParallelTensor> FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status);
// Helper to wrap a ParallelTensor into a TFE_TensorHandle which contains it.
static TensorHandlePtr AsTensorHandle(TFE_Context* context,
std::unique_ptr<ParallelTensor> t,
TF_Status* status);
size_t num_tensors() const { return tensors_.size(); }
TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }
private:
ParallelTensor(const ParallelDevice& device,
std::vector<TensorHandlePtr> tensors,
std::vector<int64_t> shape, const TF_DataType dtype)
: device_(device),
tensors_(std::move(tensors)),
shape_(std::move(shape)),
dtype_(dtype) {}
const ParallelDevice& device_;
const std::vector<TensorHandlePtr> tensors_;
const std::vector<int64_t> shape_;
const TF_DataType dtype_;
};
ParallelDevice::ParallelDevice(const std::string& name,
const std::vector<std::string>& devices)
: device_name_(name),
underlying_devices_(devices),
executors_(MakeExecutors(underlying_devices_.size())) {}
std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
const char* current_device = TFE_TensorHandleDeviceName(tensor, status);
if (device_name_ == current_device) {
std::string message(absl::StrCat(
"Tried to copy a TensorHandle to its existing device: ", device_name_));
TF_SetStatus(status, TF_INTERNAL, message.c_str());
return nullptr;
}
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (const std::string& underlying_device_name : underlying_devices_) {
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
tensor, context, underlying_device_name.c_str(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
components.emplace_back(t);
}
return ParallelTensor::FromTensorHandles(*this, std::move(components),
status);
}
absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const {
absl::optional<std::vector<MaybeParallelTensorOwned>> result;
// TODO(allenl): We should remove "TPU" from these op names at the very least,
// or consider other ways of packing/unpacking parallel tensors.
if (operation_name == std::string("TPUReplicatedInput")) {
// Special-cased operation for packing per-device tensors into one parallel
// tensor.
if (inputs.size() != underlying_devices_.size()) {
std::string message(absl::StrCat(
"The parallel device ", device_name_, " expected ",
underlying_devices_.size(), " inputs to TPUReplicatedInput, but got ",
inputs.size()));
TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
return result;
}
std::vector<TensorHandlePtr> components;
components.reserve(inputs.size());
for (int i = 0; i < inputs.size(); ++i) {
if (absl::holds_alternative<ParallelTensor*>(inputs[i])) {
std::string message(absl::StrCat(
"Expected all inputs to TPUReplicatedInput to be non-parallel "
"TensorHandles. The input ",
i,
" was a parallel tensor (already "
"placed on the parallel device)."));
TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
return result;
}
components.emplace_back(TFE_TensorHandleCopySharingTensor(
absl::get<TFE_TensorHandle*>(inputs[i]), status));
}
std::vector<MaybeParallelTensorOwned> result_content;
result_content.reserve(1);
result_content.push_back(ParallelTensor::FromTensorHandles(
*this, std::move(components), status));
if (TF_GetCode(status) != TF_OK) return result;
result.emplace(std::move(result_content));
return result;
} else if (operation_name == std::string("TPUReplicatedOutput")) {
// Special-cased operation for un-packing one parallel tensor into
// per-device tensors.
OpPtr op(TFE_NewOp(context, operation_name, status));
TFE_OpAddAttrs(op.get(), attributes);
int expected_outputs = TFE_OpGetOutputLength(op.get(), "outputs", status);
if (TF_GetCode(status) != TF_OK) return result;
if (expected_outputs != underlying_devices_.size()) {
std::string message(absl::StrCat(
"The parallel device ", device_name_, " expected ",
underlying_devices_.size(),
" outputs for TPUReplicatedOutput, but got ", expected_outputs));
TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
return result;
}
if (absl::holds_alternative<TFE_TensorHandle*>(inputs[0])) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"Expected the input to "
"TPUReplicatedOutput to be a parallel tensor (placed on the "
"parallel device).");
return result;
}
ParallelTensor* t = absl::get<ParallelTensor*>(inputs[0]);
std::vector<MaybeParallelTensorOwned> outputs;
outputs.reserve(t->num_tensors());
for (int i = 0; i < t->num_tensors(); ++i) {
TensorHandlePtr this_output(
TFE_TensorHandleCopySharingTensor(t->tensor(i), status));
outputs.emplace_back(std::move(this_output));
if (TF_GetCode(status) != TF_OK) return result;
}
result.emplace(std::move(outputs));
return result;
}
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
maybe_parallel_results(
ExecuteParallelOperation(context, std::move(inputs), operation_name,
attributes, expected_max_outputs, status));
if (!maybe_parallel_results.has_value()) return result;
std::vector<std::unique_ptr<ParallelTensor>> parallel_results(
std::move(maybe_parallel_results.value()));
std::vector<MaybeParallelTensorOwned> result_content;
result_content.reserve(parallel_results.size());
for (std::unique_ptr<ParallelTensor>& parallel_result : parallel_results) {
result_content.push_back(
MaybeParallelTensorOwned(std::move(parallel_result)));
}
result.emplace(std::move(result_content));
return result;
}
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::ExecuteParallelOperation(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const {
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
// Compute per-device per-output tensors
std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
per_device_output_tensors.reserve(underlying_devices_.size());
// TODO(allenl): Add a TFE_ExecuteWithExecutor API so we don't have to keep
// setting the thread-local executor like this.
TFE_Executor* previous_executor(TFE_ContextGetExecutorForThread(context));
auto reset_executor = gtl::MakeCleanup([context, previous_executor]() {
TFE_ContextSetExecutorForThread(context, previous_executor);
TFE_DeleteExecutor(previous_executor);
});
int first_op_output_count;
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
TFE_Executor* executor = executors_[device_index].get();
// Note that the `reset_executor` cleanup sets the thread's executor back to
// the value before this function ran.
TFE_ContextSetExecutorForThread(context, executor);
OpPtr op(TFE_NewOp(context, operation_name, status));
if (TF_GetCode(status) != TF_OK) return result;
TFE_OpSetDevice(op.get(), underlying_devices_[device_index].c_str(),
status);
TFE_OpAddAttrs(op.get(), attributes);
for (int input_index = 0; input_index < inputs.size(); ++input_index) {
if (absl::holds_alternative<TFE_TensorHandle*>(inputs[input_index])) {
// Non-parallel tensors are implicitly broadcast, i.e. set as the input
// to each parallel operation.
//
// TODO(allenl): There may be smarter ways to do this copy in some
// cases, i.e. with a collective broadcast. We'll need to be careful
// about things that are taken as inputs on the host or on their
// existing device (for multi-device functions).
TFE_OpAddInput(op.get(),
absl::get<TFE_TensorHandle*>(inputs[input_index]),
status);
if (TF_GetCode(status) != TF_OK) return result;
} else {
// Parallel tensors are divided between operations by device.
TFE_OpAddInput(op.get(),
absl::get<ParallelTensor*>(inputs[input_index])
->tensor(device_index),
status);
if (TF_GetCode(status) != TF_OK) return result;
}
}
std::vector<TFE_TensorHandle*> op_outputs(expected_max_outputs);
int real_num_outputs = expected_max_outputs;
// For nested devices, the inner device sees the async executor we've
// set. Inner parallel devices will just overwrite this with their own and
// then set it back to ours before returning. This means parallel devices
// which consist of several aliased parallel devices would hypothetically
// deadlock if the outer parallel device ran one collective with a group
// size equal to the total number of aliased physical devices. Currently
// physical devices cannot participate in a single collective reduction
// multiple times, so this would fail earlier.
//
// TODO(allenl): Keep a map from outer executor to list of inner executors
// rather than a single list of executors so aliased nested parallel devices
// don't re-use an executor.
TFE_Execute(op.get(), op_outputs.data(), &real_num_outputs, status);
if (device_index == 0) {
first_op_output_count = real_num_outputs;
} else {
if (real_num_outputs != first_op_output_count) {
TF_SetStatus(status, TF_INTERNAL,
"Parallel ops produced different numbers of tensors.");
return result;
}
}
if (TF_GetCode(status) != TF_OK) return result;
std::vector<TensorHandlePtr> this_outputs;
this_outputs.reserve(real_num_outputs);
for (int output_num = 0; output_num < real_num_outputs; ++output_num) {
this_outputs.emplace_back(op_outputs[output_num]);
}
per_device_output_tensors.push_back(std::move(this_outputs));
}
// For each output of the original operation, pack the per-device
// TensorHandles we've computed into a single parallel TensorHandle.
std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;
per_device_outputs.reserve(first_op_output_count);
for (int i = 0; i < first_op_output_count; ++i) {
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (int j = 0; j < underlying_devices_.size(); ++j) {
components.push_back(std::move(per_device_output_tensors[j][i]));
}
per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
*this, std::move(components), status));
if (TF_GetCode(status) != TF_OK) return result;
}
result.emplace(std::move(per_device_outputs));
return result;
}
std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status) {
TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
std::vector<int64_t> shape(
TFE_TensorHandleNumDims(components[0].get(), status));
if (TF_GetCode(status) != TF_OK) return nullptr;
for (int i = 0; i < shape.size(); ++i) {
shape[i] = TFE_TensorHandleDim(components[0].get(), i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
// Verify that the TensorHandle's shape and dtype match all of the component
// shapes and dtypes.
for (TensorHandlePtr& component : components) {
for (int i = 0; i < shape.size(); ++i) {
int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
if (tensor_dim != shape[i]) {
// TODO(allenl): Allow shapes to differ.
TF_SetStatus(status, TF_UNIMPLEMENTED,
"Components of a ParallelTensor must currently all have "
"the same shape");
return nullptr;
}
if (TFE_TensorHandleDataType(component.get()) != dtype) {
TF_SetStatus(status, TF_INTERNAL,
"Components of a ParallelTensor must all have "
"the same dtype");
return nullptr;
}
}
}
return std::unique_ptr<ParallelTensor>(new ParallelTensor(
parallel_device, std::move(components), std::move(shape), dtype));
}
// Used as an argument to TFE_NewTensorHandleFromDeviceMemory, indicating how
// ParallelTensors wrapped in TFE_TensorHandles should be cleaned up once their
// reference counts drop to zero.
void ParallelTensorDeallocator(void* data, size_t len, void* arg) {
delete reinterpret_cast<ParallelTensor*>(data);
}
TensorHandlePtr ParallelTensor::AsTensorHandle(
TFE_Context* context, std::unique_ptr<ParallelTensor> t,
TF_Status* status) {
// The resulting TensorHandle owns an opaque pointer to "device memory", which
// for a ParallelDevice is really a ParallelTensor. When the TensorHandle is
// deleted, it will call ParallelTensorDeallocator to free the struct.
ParallelTensor* t_released = t.release();
return TensorHandlePtr(TFE_NewTensorHandleFromDeviceMemory(
context, t_released->device_.device_name().c_str(), t_released->dtype_,
t_released->shape_.data(), t_released->shape_.size(), t_released, 1,
&ParallelTensorDeallocator, nullptr, status));
}
// For TFE_CustomDevice::copy_tensor_to_device in the parallel device
// registration.
//
// Replicates a single TFE_TensorHandle, producing a TFE_TensorHandle containing
// a ParallelTensor with one copy of `tensor` for each device in the
// ParallelDevice.
//
// Since this function is used to satisfy the TFE_CustomDevice C API,
// device_info is passed in using a C-style generic. It must always be a
// ParallelDevice.
TFE_TensorHandle* CopyToParallelDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status, void* device_info) {
ParallelDevice* dev = reinterpret_cast<ParallelDevice*>(device_info);
std::unique_ptr<ParallelTensor> parallel_tensor(
dev->CopyToParallelDevice(context, tensor, status));
if (TF_GetCode(status) != TF_OK) return nullptr;
return ParallelTensor::AsTensorHandle(context, std::move(parallel_tensor),
status)
.release();
}
// For TFE_CustomDevice::copy_tensor_from_device in the parallel device
// registration.
//
// Currently this is an error, and un-packing ParallelTensors must be performed
// explicitly by running a TPUReplicatedOutput operation on the parallel device.
//
// TODO(allenl): There are some use-cases that are only supported by copying to
// host at the moment (e.g. debug print on a tensor, .numpy(), etc.). We either
// need to return something here or address these use-cases one by one.
TFE_TensorHandle* CopyTensorFromParallelDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
const char* target_device_name,
TF_Status* status,
void* device_info) {
TF_SetStatus(status, TF_INTERNAL,
"Trying to copy a tensor out of a parallel device.");
return nullptr;
}
// For TFE_CustomDevice::execute in the parallel device registration.
//
// Since this function is used to satisfy the TFE_CustomDevice C API,
// device_info is passed in using a C-style generic. It must always be a
// ParallelDevice.
void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
TFE_TensorHandle** inputs,
const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* status,
void* device_info) {
ParallelDevice* dev = reinterpret_cast<ParallelDevice*>(device_info);
std::vector<MaybeParallelTensorUnowned> typed_inputs;
typed_inputs.reserve(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
const char* tensor_handle_device =
TFE_TensorHandleDeviceName(inputs[i], status);
if (TF_GetCode(status) != TF_OK) return;
if (dev->device_name() == tensor_handle_device) {
// We assume that any tensors already placed on this device are
// ParallelTensors.
typed_inputs.emplace_back(reinterpret_cast<ParallelTensor*>(
TFE_TensorHandleDevicePointer(inputs[i], status)));
if (TF_GetCode(status) != TF_OK) return;
} else {
typed_inputs.emplace_back(inputs[i]);
}
}
absl::optional<std::vector<MaybeParallelTensorOwned>> maybe_typed_outputs(
dev->Execute(context, std::move(typed_inputs), operation_name, attributes,
*num_outputs, status));
if (TF_GetCode(status) != TF_OK) return;
if (!maybe_typed_outputs.has_value()) {
TF_SetStatus(status, TF_INTERNAL, "OK status but no value was returned.");
return;
}
std::vector<MaybeParallelTensorOwned> typed_outputs(
std::move(maybe_typed_outputs.value()));
if (typed_outputs.size() > *num_outputs) {
TF_SetStatus(status, TF_INTERNAL,
"The allocated output buffer was too small.");
return;
}
for (int i = 0; i < typed_outputs.size(); ++i) {
MaybeParallelTensorOwned typed_output(std::move(typed_outputs[i]));
if (absl::holds_alternative<TensorHandlePtr>(typed_output)) {
outputs[i] = absl::get<TensorHandlePtr>(typed_output).release();
} else {
outputs[i] = ParallelTensor::AsTensorHandle(
context,
std::move(absl::get<std::unique_ptr<ParallelTensor>>(
typed_output)),
status)
.release();
if (TF_GetCode(status) != TF_OK) return;
}
}
*num_outputs = typed_outputs.size();
}
// For TFE_CustomDevice::delete_device in the parallel device registration.
//
// Since this function is used to satisfy the TFE_CustomDevice C API,
// device_info is passed in using a C-style generic. It must always be a
// ParallelDevice.
void DeleteParallelDevice(void* device_info) {
delete reinterpret_cast<ParallelDevice*>(device_info);
}
} // namespace
void RegisterParallelDevice(TFE_Context* context, const char* device_name,
const char** underlying_devices,
int num_underlying_devices, TF_Status* status) {
TFE_CustomDevice custom_device;
custom_device.copy_tensor_to_device = &CopyToParallelDevice;
custom_device.copy_tensor_from_device = &CopyTensorFromParallelDevice;
custom_device.delete_device = &DeleteParallelDevice;
custom_device.execute = &ParallelDeviceExecute;
std::vector<std::string> underlying_devices_vector;
underlying_devices_vector.reserve(num_underlying_devices);
for (int device_index = 0; device_index < num_underlying_devices;
++device_index) {
underlying_devices_vector.push_back(underlying_devices[device_index]);
}
ParallelDevice* d =
new ParallelDevice(device_name, underlying_devices_vector);
TFE_RegisterCustomDevice(context, custom_device, device_name, d, status);
}
} // namespace eager
} // namespace tensorflow

View File

@ -0,0 +1,62 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
#include "tensorflow/c/eager/c_api.h"
namespace tensorflow {
namespace eager {
// Register a parallel device named `device_name` which forwards operations to
// `underlying_devices`, maintaining "parallel tensors" with components placed
// on each underlying device.
//
// For example if `device_name` is
// "/job:localhost/replica:0/task:0/device:CUSTOM:0"
// and `underlying_devices` is
// {"/job:localhost/replica:0/task:0/device:GPU:0",
// "/job:localhost/replica:0/task:0/device:GPU:1"}
// Then executing an operation on CUSTOM:0 will execute it on GPU:0 and GPU:1.
//
// Implicit copies onto `device_name` are allowed, replicating the value once
// per device in `underlying_devices`. Implicit copies off of the device throw
// an error.
//
// All component tensors must have the same dtype. Currently they must also have
// the same shape, although this requirement may be relaxed in the future.
//
// `device_name` must not name an existing physical or custom device (see
// the documentation for TFE_RegisterCustomDevice for more information).
//
// Tensors may be copied on or off the device explicitly using
// TPUReplicatedInput and TPUReplicatedOutput respectively. For example, with
// two component devices, running `x = TPUReplicatedInput(inputs=[a, b])` on the
// parallel device creates a parallel tensor `x` with `a` on the first of
// `underlying_devices` and `b` on the second. Running `a_unpacked, b_unpacked =
// TPUReplicatedOutput(input=x, num_replicas=2)` un-packs the parallel tensor
// into its components.
//
// `context` owns the parallel device. `underlying_devices` must stay valid
// while the parallel device is in use.
void RegisterParallelDevice(TFE_Context* context, const char* device_name,
const char** underlying_devices,
int num_underlying_devices, TF_Status* status);
} // namespace eager
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
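
As a rough usage sketch of this API (the wrapper function name below is hypothetical; the device strings are taken from the example in the header comment above):

// Illustrative only: register a parallel device backed by two local GPUs.
void ParallelDeviceSketch(TFE_Context* context, TF_Status* status) {
  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  const char* underlying_devices[] = {
      "/job:localhost/replica:0/task:0/device:GPU:0",
      "/job:localhost/replica:0/task:0/device:GPU:1"};
  tensorflow::eager::RegisterParallelDevice(
      context, device_name, underlying_devices,
      /*num_underlying_devices=*/2, status);
  // After registration, copying a tensor onto CUSTOM:0 implicitly replicates
  // it to GPU:0 and GPU:1, while TPUReplicatedInput and TPUReplicatedOutput
  // pack and un-pack per-device values explicitly.
}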

View File

@ -0,0 +1,917 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/parallel_device/parallel_device.h"
#include <array>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/core/platform/test.h"
// NOTE(allenl): These tests currently go through TFE_Execute and so are
// integration testing rather than purely testing the parallel device. They
// correspond fairly well to the implementation, but testing the C++ directly is
// another option.
// Functor for making unique_ptr to TFE_TensorHandle slightly more
// ergonomic. Using decltype(TFE_DeleteTensorHandle) in the unique_ptr's second
// template argument requires passing a function pointer to
// TFE_DeleteTensorHandle when constructing the unique_ptr.
class TensorHandleDeleter {
public:
void operator()(TFE_TensorHandle* to_delete) {
TFE_DeleteTensorHandle(to_delete);
}
};
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
// A helper for performing common operations on variables. A much more
// restricted stand-in for tf.Variable in Python.
class Variable {
public:
// Construct a Variable from a resource-dtype TFE_TensorHandle and an
// indication of the dtype of the variable's value.
//
// Note that creating this resource-dtype handle can fail, so `Create` is a
// separate static method which returns a status.
Variable(TFE_TensorHandle* handle, TF_DataType type)
: handle_(handle), type_(type) {}
// Helper for constructing a resource handle and wrapping it in a `Variable`
// object.
static Variable* Create(TFE_Context* context, TF_DataType type,
const int64_t* dims, const int num_dims,
const char* device, TF_Status* status);
// Dereferences the backing buffer for the variable. Note that since this can
// fail (it runs operations), it must be called explicitly and the resulting
// `status` checked.
void Destroy(TFE_Context* context, TF_Status* status);
// Reads from the variable.
TensorHandlePtr Read(TFE_Context* context, TF_Status* status);
// Assigns a new value to the variable.
void Assign(TFE_Context* context, TFE_TensorHandle* value, TF_Status* status);
// Adds `value` to the existing value of the variable.
void AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
TF_Status* status);
private:
// Helper for running any single-argument assignment ops (Assign, AssignAdd,
// AssignSub, ...).
void GeneralAssignment(const char* op_name, TFE_Context* context,
TFE_TensorHandle* value, TF_Status* status);
  // A handle for the resource-dtype tensor pointing to the variable's
// buffer.
TFE_TensorHandle* handle_;
// The dtype of the variable's buffer (input dtype for assignments, output
// dtype of read operations).
TF_DataType type_;
};
Variable* Variable::Create(TFE_Context* context, TF_DataType type,
const int64_t* dims, const int num_dims,
const char* device, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "VarHandleOp", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op.get(), "dtype", type);
TFE_OpSetAttrShape(op.get(), "shape", dims, num_dims, status);
TFE_OpSetAttrString(op.get(), "container", "", 0);
// Use the special GUID for no buffer sharing
//
// TODO(allenl): Should we provide a better API for this? AFAIK this is the
// only reasonable way to make variables with no aliasing using the eager C
// API.
std::string no_sharing = "cd2c89b7-88b7-44c8-ad83-06c2a9158347";
TFE_OpSetAttrString(op.get(), "shared_name", no_sharing.c_str(),
no_sharing.length());
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_TensorHandle* var_handle = nullptr;
int num_retvals = 1;
TFE_Execute(op.get(), &var_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return new Variable(var_handle, type);
}
void Variable::Destroy(TFE_Context* context, TF_Status* status) {
// Free the backing buffer for the variable.
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "DestroyResourceOp", status), &TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpAddInput(op.get(), handle_, status);
if (TF_GetCode(status) != TF_OK) return;
const char* device = TFE_TensorHandleDeviceName(handle_, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return;
int num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return;
// Delete the variable handle itself.
TFE_DeleteTensorHandle(handle_);
}
TensorHandlePtr Variable::Read(TFE_Context* context, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "ReadVariableOp", status), &TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op.get(), handle_, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
const char* device = TFE_TensorHandleDeviceName(handle_, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op.get(), "dtype", type_);
int num_retvals = 1;
TFE_TensorHandle* var_value = nullptr;
TFE_Execute(op.get(), &var_value, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return TensorHandlePtr(var_value);
}
void Variable::GeneralAssignment(const char* op_name, TFE_Context* context,
TFE_TensorHandle* value, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, op_name, status), &TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetAttrType(op.get(), "dtype", type_);
TFE_OpAddInput(op.get(), handle_, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpAddInput(op.get(), value, status);
if (TF_GetCode(status) != TF_OK) return;
const char* device = TFE_TensorHandleDeviceName(handle_, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op.get(), device, status);
int num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return;
}
void Variable::AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
TF_Status* status) {
GeneralAssignment("AssignAddVariableOp", context, value, status);
}
void Variable::Assign(TFE_Context* context, TFE_TensorHandle* value,
TF_Status* status) {
GeneralAssignment("AssignVariableOp", context, value, status);
}
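// A minimal usage sketch for the Variable wrapper above (illustrative only; it
// mirrors what BasicTestsForTwoDevices does below and assumes `context`,
// `device`, `initial_value`, and `status` already exist and are valid):
//
//   Variable* var = Variable::Create(context, TF_FLOAT, /* Scalar */ {}, 0,
//                                    device, status);
//   var->Assign(context, initial_value.get(), status);
//   TensorHandlePtr current = var->Read(context, status);
//   var->Destroy(context, status);
//   delete var;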
// Passed to `TF_NewTensor` to indicate how an array of floats should be
// deleted.
static void FloatDeallocator(void* data, size_t, void* arg) {
delete[] static_cast<float*>(data);
}
// Creates a TFE_TensorHandle with value `v`.
TensorHandlePtr FloatTensorHandle(float v, TF_Status* status) {
const int num_bytes = sizeof(float);
float* values = new float[1];
values[0] = v;
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes, &FloatDeallocator,
nullptr),
TF_DeleteTensor);
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
}
// Creates a rank-one TFE_TensorHandle with value `v`.
TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
TF_Status* status) {
const int num_bytes = v.size() * sizeof(float);
float* values = new float[v.size()];
memcpy(values, v.data(), num_bytes);
int64_t dims = v.size();
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(TF_FLOAT, &dims, 1 /* num_dims */, values, num_bytes,
&FloatDeallocator, nullptr),
TF_DeleteTensor);
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
}
// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
template <std::size_t num_replicas>
void ExtractPerDeviceValues(
TFE_Context* context, TFE_TensorHandle* input,
std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "TPUReplicatedOutput", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetAttrInt(op.get(), "num_replicas", num_replicas);
TFE_OpAddInput(op.get(), input, status);
if (TF_GetCode(status) != TF_OK) return;
const char* device = TFE_TensorHandleDeviceName(input, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_TensorHandle* result_handles[num_replicas];
int num_retvals = num_replicas;
TFE_Execute(op.get(), result_handles, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return;
for (int i = 0; i < num_replicas; ++i) {
(*components)[i].reset(result_handles[i]);
}
}
// Helper to pack `num_replicas` TFE_TensorHandles into one parallel handle.
template <std::size_t num_replicas>
TensorHandlePtr CreatePerDeviceValues(
TFE_Context* context,
const std::array<TFE_TensorHandle*, num_replicas>& components,
const char* device, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "TPUReplicatedInput", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrInt(op.get(), "N", num_replicas);
for (int i = 0; i < num_replicas; ++i) {
TFE_OpAddInput(op.get(), components[i], status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return TensorHandlePtr(result_handle);
}
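// Usage sketch for the pack/unpack helpers above (illustrative; `one`, `two`,
// `context`, `status`, and `device_name` are assumed to exist, and
// `device_name` must name a registered parallel device with two components):
//
//   std::array<TFE_TensorHandle*, 2> components{one.get(), two.get()};
//   TensorHandlePtr packed =
//       CreatePerDeviceValues(context, components, device_name, status);
//   std::array<TensorHandlePtr, 2> unpacked;
//   ExtractPerDeviceValues(context, packed.get(), &unpacked, status);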
TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
TFE_TensorHandle* second, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "Mul", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op.get(), first, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op.get(), second, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
const char* first_device = TFE_TensorHandleDeviceName(first, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(op.get(), first_device, status);
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return TensorHandlePtr(result_handle);
}
// Assert that `handle` is equal to `expected_value`.
void AssertScalarFloatEq(TFE_TensorHandle* handle, float expected_value) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(expected_value,
*static_cast<float*>(TF_TensorData(value_zero.get())));
}
// Create and modify a variable placed on a parallel device which composes
// `first_device` and `second_device`.
void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
const char* second_device) {
// Register the custom device
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::array<const char*, 2> underlying_devices{first_device, second_device};
tensorflow::eager::RegisterParallelDevice(
context, device_name, underlying_devices.data(),
underlying_devices.size(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a variable handle (uninitialized to start) placed on the parallel
// device.
std::function<void(Variable*)> variable_deleter = [&](Variable* to_delete) {
to_delete->Destroy(context, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
delete to_delete;
};
std::unique_ptr<Variable, decltype(variable_deleter)> variable(
Variable::Create(context, TF_FLOAT, /* Scalar */ {}, 0, device_name,
status.get()),
variable_deleter);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Assign an initial value to the variable, implicitly mirroring it to each
// component device.
{
TensorHandlePtr initial_value = FloatTensorHandle(20., status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
variable->Assign(context, initial_value.get(), status.get());
}
// Read from the variable and verify that we have a parallel tensor.
{
TensorHandlePtr read = variable->Read(context, status.get());
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context, read.get(), &components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(components[0].get(), 20.);
AssertScalarFloatEq(components[1].get(), 20.);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
// Add a parallel tensor with different values on each device to the variable.
{
TensorHandlePtr value_one(FloatTensorHandle(3., status.get()));
TensorHandlePtr value_two(FloatTensorHandle(-2., status.get()));
std::array<TFE_TensorHandle*, 2> components{value_one.get(),
value_two.get()};
TensorHandlePtr combined_value =
CreatePerDeviceValues(context, components, device_name, status.get());
variable->AssignAdd(context, combined_value.get(), status.get());
}
// Read the variable and verify that each component has the right modified
// value.
{
TensorHandlePtr read = variable->Read(context, status.get());
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context, read.get(), &components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(components[0].get(), 23.);
AssertScalarFloatEq(components[1].get(), 18.);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
}
TEST(PARALLEL_DEVICE, TestBasicCPU) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
TF_CreateConfig(
/*xla*/ false,
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
2),
TF_DeleteBuffer);
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
status.get());
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
BasicTestsForTwoDevices(context.get(),
"/job:localhost/replica:0/task:0/device:CPU:0",
"/job:localhost/replica:0/task:0/device:CPU:1");
}
TEST(PARALLEL_DEVICE, TestBasicCPUAliased) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
BasicTestsForTwoDevices(context.get(),
"/job:localhost/replica:0/task:0/device:CPU:0",
"/job:localhost/replica:0/task:0/device:CPU:0");
}
TEST(PARALLEL_DEVICE, TestBasicTPUAliased) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Skip the test if no TPU is available.
std::unique_ptr<TF_DeviceList, decltype(&TF_DeleteDeviceList)> devices(
TFE_ContextListDevices(context.get(), status.get()), TF_DeleteDeviceList);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
bool has_tpu = false;
for (int device_index = 0; device_index < TF_DeviceListCount(devices.get());
++device_index) {
std::string device_type =
TF_DeviceListType(devices.get(), device_index, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
if (device_type == "TPU") {
has_tpu = true;
break;
}
}
if (has_tpu) {
BasicTestsForTwoDevices(context.get(),
"/job:localhost/replica:0/task:0/device:TPU:0",
"/job:localhost/replica:0/task:0/device:TPU:0");
}
}
TEST(PARALLEL_DEVICE, TestExplicitCopies) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
TF_CreateConfig(
/*xla*/ false,
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
2),
TF_DeleteBuffer);
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
status.get());
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::vector<const char*> underlying_devices;
const char* first_device_name =
"/job:localhost/replica:0/task:0/device:CPU:0";
underlying_devices.push_back(first_device_name);
const char* second_device_name =
"/job:localhost/replica:0/task:0/device:CPU:1";
underlying_devices.push_back(second_device_name);
tensorflow::eager::RegisterParallelDevice(
context.get(), device_name, underlying_devices.data(),
underlying_devices.size(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TensorHandlePtr cpu_value(FloatTensorHandle(3., status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Copying on to a parallel device is OK.
TensorHandlePtr device_value(TFE_TensorHandleCopyToDevice(
cpu_value.get(), context.get(), device_name, status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const char* backing_device =
TFE_TensorHandleBackingDeviceName(device_value.get(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(std::string(device_name), backing_device);
// Un-pack the parallel tensor to verify that the copy was successful.
{
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context.get(), device_value.get(), &components,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// The value of the original tensor is replicated on each device.
AssertScalarFloatEq(components[0].get(), 3.);
AssertScalarFloatEq(components[1].get(), 3.);
// Verify that the mirrors are placed on the component devices.
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
// Copies off of parallel devices must be explicit.
TensorHandlePtr copy_back(TFE_TensorHandleCopyToDevice(
device_value.get(), context.get(), first_device_name, status.get()));
ASSERT_EQ(TF_GetCode(status.get()), TF_INTERNAL);
}
TEST(PARALLEL_DEVICE, TestDifferentShapes) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
TF_CreateConfig(
/*xla*/ false,
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
2),
TF_DeleteBuffer);
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
status.get());
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::vector<const char*> underlying_devices;
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
tensorflow::eager::RegisterParallelDevice(
context.get(), device_name, underlying_devices.data(),
underlying_devices.size(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create two vectors with different lengths
std::vector<float> size_two_value{1., 2.};
std::vector<float> size_three_value{1., 2., 3.};
TensorHandlePtr size_two(
VectorFloatTensorHandle(size_two_value, status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TensorHandlePtr size_three(
VectorFloatTensorHandle(size_three_value, status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Try to combine these values into a single parallel tensor.
std::array<TFE_TensorHandle*, 2> components{size_two.get(), size_three.get()};
TensorHandlePtr combined_value = CreatePerDeviceValues(
context.get(), components, device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_UNIMPLEMENTED)
<< TF_Message(status.get());
}
TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
TF_CreateConfig(
/*xla*/ false,
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
3),
TF_DeleteBuffer);
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
status.get());
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a parallel device with two CPUs
const char* first_device_name =
"/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::vector<const char*> first_underlying_devices{
"/job:localhost/replica:0/task:0/device:CPU:0",
"/job:localhost/replica:0/task:0/device:CPU:1"};
tensorflow::eager::RegisterParallelDevice(
context.get(), first_device_name, first_underlying_devices.data(),
first_underlying_devices.size(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a second parallel device with the first parallel device and one
// additional CPU.
const char* second_device_name =
"/job:localhost/replica:0/task:0/device:CUSTOM:1";
std::vector<const char*> second_underlying_devices{
"/job:localhost/replica:0/task:0/device:CUSTOM:0",
"/job:localhost/replica:0/task:0/device:CPU:2"};
tensorflow::eager::RegisterParallelDevice(
context.get(), second_device_name, second_underlying_devices.data(),
second_underlying_devices.size(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a tensor on the first parallel device
TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TFE_TensorHandle*, 2> components{value_one.get(), value_two.get()};
TensorHandlePtr first_combined_value = CreatePerDeviceValues(
context.get(), components, first_device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Nest the first parallel tensor into a second
TensorHandlePtr value_three(FloatTensorHandle(3., status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
components[0] = first_combined_value.get();
components[1] = value_three.get();
TensorHandlePtr second_combined_value = CreatePerDeviceValues(
context.get(), components, second_device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TensorHandlePtr value_three(FloatTensorHandle(3., status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TensorHandlePtr multiply_result(Multiply(context.get(),
                                           second_combined_value.get(),
                                           value_three.get(), status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Un-pack the parallel tensor to verify that the operation was
// successful. The resulting structure should be:
// second_device{first_device{1. * 3., 2. * 3.}, 3. * 3.}.
std::array<TensorHandlePtr, 2> second_components;
ExtractPerDeviceValues(context.get(), multiply_result.get(),
&second_components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(second_components[1].get(), 9.);
// Verify that the mirrors are placed on the component devices.
std::string first_device = TFE_TensorHandleBackingDeviceName(
second_components[0].get(), status.get());
ASSERT_EQ(second_underlying_devices[0], first_device);
std::string second_device = TFE_TensorHandleBackingDeviceName(
second_components[1].get(), status.get());
ASSERT_EQ(second_underlying_devices[1], second_device);
// Un-pack the first parallel device's tensor too
std::array<TensorHandlePtr, 2> first_components;
ExtractPerDeviceValues(context.get(), second_components[0].get(),
&first_components, status.get());
AssertScalarFloatEq(first_components[0].get(), 3.);
AssertScalarFloatEq(first_components[1].get(), 6.);
first_device = TFE_TensorHandleBackingDeviceName(first_components[0].get(),
status.get());
ASSERT_EQ(first_underlying_devices[0], first_device);
second_device = TFE_TensorHandleBackingDeviceName(first_components[1].get(),
status.get());
ASSERT_EQ(first_underlying_devices[1], second_device);
}
TEST(PARALLEL_DEVICE, TestInvalidPacking) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::vector<const char*> underlying_devices;
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
tensorflow::eager::RegisterParallelDevice(
context.get(), device_name, underlying_devices.data(),
underlying_devices.size(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
{
// Try to pack two TensorHandles onto a parallel device with a single
// component.
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TFE_TensorHandle*, 2> components{value_one.get(),
value_two.get()};
TensorHandlePtr combined_value = CreatePerDeviceValues(
context.get(), components, device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
<< TF_Message(status.get());
}
{
// Try to extract the wrong number of components from a parallel tensor
std::array<TFE_TensorHandle*, 1> correct_components{value_one.get()};
TensorHandlePtr combined_value = CreatePerDeviceValues(
context.get(), correct_components, device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TensorHandlePtr, 2> incorrect_components;
ExtractPerDeviceValues(context.get(), combined_value.get(),
&incorrect_components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
<< TF_Message(status.get());
}
{
// Try to pass a ParallelTensor to TPUReplicatedInput
std::array<TFE_TensorHandle*, 1> correct_components{value_one.get()};
TensorHandlePtr combined_value = CreatePerDeviceValues(
context.get(), correct_components, device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TFE_TensorHandle*, 1> incorrect_components{combined_value.get()};
TensorHandlePtr recombined_value = CreatePerDeviceValues(
context.get(), incorrect_components, device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
<< TF_Message(status.get());
}
{
// Try to pass a non-parallel tensor to TPUReplicatedOutput
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context.get(), "TPUReplicatedOutput", status.get()),
TFE_DeleteOp);
if (TF_GetCode(status.get()) != TF_OK) return;
TFE_OpSetAttrInt(op.get(), "num_replicas", 1);
TFE_OpAddInput(op.get(), value_one.get(), status.get());
if (TF_GetCode(status.get()) != TF_OK) return;
TFE_OpSetDevice(op.get(), device_name, status.get());
if (TF_GetCode(status.get()) != TF_OK) return;
TFE_TensorHandle* result_handles;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handles, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
<< TF_Message(status.get());
}
}
TensorHandlePtr CollectiveSum(TFE_Context* context, TFE_TensorHandle* input,
int group_size, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "CollectiveReduce", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
const char* device = TFE_TensorHandleDeviceName(input, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op.get(), "T", TFE_TensorHandleDataType(input));
TFE_OpSetAttrInt(op.get(), "group_size", group_size);
TFE_OpSetAttrInt(op.get(), "group_key", 0);
TFE_OpSetAttrInt(op.get(), "instance_key", 0);
const std::string merge_op("Add");
TFE_OpSetAttrString(op.get(), "merge_op", merge_op.c_str(),
merge_op.length());
const std::string final_op("Id");
TFE_OpSetAttrString(op.get(), "final_op", final_op.c_str(),
final_op.length());
TFE_OpSetAttrIntList(op.get(), "subdiv_offsets", nullptr, 0);
TFE_OpAddInput(op.get(), input, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return TensorHandlePtr(result_handle);
}
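// When this op runs through the parallel device (as in TestCollective below),
// it executes once per component device and those executions form a collective
// group of size `group_size`. With components 1. and 2., each component of the
// reduced value therefore holds 1. + 2. = 3.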
TEST(PARALLEL_DEVICE, TestCollective) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
TF_CreateConfig(
/*xla*/ false,
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
2),
TF_DeleteBuffer);
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
status.get());
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::vector<const char*> underlying_devices;
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
tensorflow::eager::RegisterParallelDevice(
context.get(), device_name, underlying_devices.data(),
underlying_devices.size(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a tensor on the parallel device
TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TFE_TensorHandle*, 2> components{value_one.get(), value_two.get()};
TensorHandlePtr parallel_value = CreatePerDeviceValues(
context.get(), components, device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Run a collective sum, so each component should now be the same.
TensorHandlePtr reduced(
CollectiveSum(context.get(), parallel_value.get(), 2, status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TensorHandlePtr, 2> result_components;
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(result_components[0].get(), 3.);
AssertScalarFloatEq(result_components[1].get(), 3.);
}
void RegisterCollectiveMulFunction(TFE_Context* context,
const char* function_name, int group_size,
TF_Status* status) {
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> body(TF_NewGraph(),
TF_DeleteGraph);
TF_OperationDescription* placeholder_desc =
TF_NewOperation(body.get(), "Placeholder", "Placeholder");
TF_SetAttrType(placeholder_desc, "dtype", TF_FLOAT);
TF_Operation* placeholder_op = TF_FinishOperation(placeholder_desc, status);
if (TF_GetCode(status) != TF_OK) return;
TF_Output x{placeholder_op, 0};
TF_OperationDescription* reduce_desc =
TF_NewOperation(body.get(), "CollectiveReduce", "CollectiveReduce");
TF_SetAttrType(reduce_desc, "T", TF_FLOAT);
TF_SetAttrInt(reduce_desc, "group_size", group_size);
TF_SetAttrInt(reduce_desc, "group_key", 0);
TF_SetAttrInt(reduce_desc, "instance_key", 0);
const std::string merge_op("Mul");
TF_SetAttrString(reduce_desc, "merge_op", merge_op.c_str(),
merge_op.length());
const std::string final_op("Id");
TF_SetAttrString(reduce_desc, "final_op", final_op.c_str(),
final_op.length());
TF_SetAttrIntList(reduce_desc, "subdiv_offsets", nullptr, 0);
TF_AddInput(reduce_desc, x);
TF_Operation* reduce_op = TF_FinishOperation(reduce_desc, status);
if (TF_GetCode(status) != TF_OK) return;
TF_Operation* operations[]{placeholder_op, reduce_op};
TF_Output y{reduce_op, 0};
const char* output_name = "y";
std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)> function(
TF_GraphToFunction(
/* fn_body */ body.get(), /* fn_name */ function_name,
/* append_hash_to_fn_name */ 0, /* num_opers */ 2,
/* opers */ operations, /* ninputs */ 1, /* inputs */ &x,
/* noutputs */ 1, /* outputs */ &y, /* output_names */ &output_name,
/* opts */ nullptr, /* description */ "", /* status */ status),
TF_DeleteFunction);
if (TF_GetCode(status) != TF_OK) return;
TFE_ContextAddFunction(context, function.get(), status);
}
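// As exercised by TestFunction below: calling the registered `test_reduce_mul`
// function on the parallel device runs its body once per component device, and
// the components (7. and 9.) are combined with Mul, so both components of the
// result end up as 7. * 9. = 63.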
TEST(PARALLEL_DEVICE, TestFunction) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
TF_CreateConfig(
/*xla*/ false,
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
2),
TF_DeleteBuffer);
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
status.get());
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::vector<const char*> underlying_devices;
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
tensorflow::eager::RegisterParallelDevice(
context.get(), device_name, underlying_devices.data(),
underlying_devices.size(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const char* function_name = "test_reduce_mul";
RegisterCollectiveMulFunction(context.get(), function_name, 2, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TensorHandlePtr value_one(FloatTensorHandle(7., status.get()));
TensorHandlePtr value_two(FloatTensorHandle(9., status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TFE_TensorHandle*, 2> components{value_one.get(), value_two.get()};
TensorHandlePtr parallel_value = CreatePerDeviceValues(
context.get(), components, device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context.get(), function_name, status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetDevice(op.get(), device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpAddInput(op.get(), parallel_value.get(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* raw_result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &raw_result_handle, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TensorHandlePtr reduced(raw_result_handle);
std::array<TensorHandlePtr, 2> result_components;
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(result_components[0].get(), 7. * 9.);
AssertScalarFloatEq(result_components[1].get(), 7. * 9.);
std::string first_device = TFE_TensorHandleBackingDeviceName(
result_components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device = TFE_TensorHandleBackingDeviceName(
result_components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
View File
@ -119,6 +119,9 @@ inline Tensor& TensorFromInterface(AbstractTensorInterface* tensor) {
return down_cast<TensorInterface*>(tensor)->Tensor();
}
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
} // namespace tensorflow
#endif // TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
View File
@ -156,6 +156,7 @@ cc_library(
":array_grad",
":data_flow_grad",
":image_grad",
":manip_grad",
":math_grad",
":nn_grad",
],
@ -494,6 +495,32 @@ tf_cc_test(
],
)
cc_library(
name = "manip_grad",
srcs = ["gradients/manip_grad.cc"],
deps = [
":cc_ops",
":grad_op_registry",
":gradients",
],
alwayslink = 1,
)
tf_cc_test(
name = "gradients_manip_grad_test",
srcs = ["gradients/manip_grad_test.cc"],
deps = [
":array_ops",
":cc_ops",
":gradient_checker",
":manip_grad",
":testutil",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
# Generates separate libraries for array_ops and math_ops to reduce the dependency count of targets that depend on only these
tf_gen_op_wrappers_cc(
name = "math_ops",
View File
@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/manip_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
namespace tensorflow {
namespace ops {
namespace {
Status RollGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
auto shift = op.input(1);
auto axis = op.input(2);
auto grad_op = Roll(scope, grad_inputs[0], Neg(scope, shift), axis);
grad_outputs->push_back(grad_op);
grad_outputs->push_back(NoGradient());
grad_outputs->push_back(NoGradient());
return scope.status();
}
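// Quick sanity check of the rule above: Roll is a pure permutation, so the
// backward pass simply applies the inverse permutation, i.e. a roll by -shift
// along the same axis. For example, with x = [1, 2, 3, 4]:
//   y  = Roll(x, /*shift=*/1, /*axis=*/0) = [4, 1, 2, 3]
//   dy = [g0, g1, g2, g3]
//   dx = Roll(dy, -1, 0)                  = [g1, g2, g3, g0]
// The `shift` and `axis` inputs are integers, so they receive NoGradient().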
REGISTER_GRADIENT_OP("Roll", RollGrad);
} // namespace
} // namespace ops
} // namespace tensorflow
View File
@ -0,0 +1,51 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/framework/gradient_checker.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/manip_ops.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace tensorflow {
namespace {
using ops::Placeholder;
using ops::Roll;
class ManipGradTest : public ::testing::Test {
protected:
ManipGradTest() : scope_(Scope::NewRootScope()) {}
void RunTest(const Output& x, const TensorShape& x_shape, const Output& y,
const TensorShape& y_shape) {
TF_ASSERT_OK(scope_.status());
float max_error;
TF_ASSERT_OK((ComputeGradientError<float, float, float>(
scope_, {x}, {x_shape}, {y}, {y_shape}, &max_error)));
EXPECT_LT(max_error, 1e-4);
}
Scope scope_;
};
TEST_F(ManipGradTest, RollGrad) {
TensorShape shape({5, 4, 3});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
auto y = Roll(scope_, x, {2, 1}, {0, 1});
RunTest(x, shape, y, shape);
}
} // namespace
} // namespace tensorflow
View File
@ -16,6 +16,12 @@ cc_library(
deps = ["//tensorflow/core:test_main"],
)
filegroup(
name = "quantize_header",
srcs = ["quantize.h"],
visibility = ["//visibility:public"],
)
cc_library(
name = "tfcompile_lib",
srcs = [
@ -27,6 +33,7 @@ cc_library(
"codegen.h",
"compile.h",
"flags.h",
"quantize.h",
],
defines = if_llvm_aarch64_available(["TF_LLVM_AARCH64_AVAILABLE=1"]),
visibility = ["//tensorflow/python:__pkg__"],
@ -37,7 +44,6 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"//tensorflow/compiler/mlir/lite/quantization/xla:quantize",
"//tensorflow/compiler/tf2xla",
"//tensorflow/compiler/tf2xla:mlir_tf2xla",
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
View File
@ -24,7 +24,7 @@ limitations under the License.
#include "llvm-c/Target.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"
#include "tensorflow/compiler/aot/quantize.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/client/client_library.h"
@ -46,6 +46,14 @@ limitations under the License.
namespace tensorflow {
namespace tfcompile {
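// The quantization hook lives in an llvm::ManagedStatic so that tfcompile no
// longer needs a hard dependency on the MLIR/XLA quantizer: a backend that
// provides one registers it (typically from a static initializer), and only
// the first registration takes effect.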
static llvm::ManagedStatic<QuantizeXlaFn> quantize_xla;
bool RegisterQuantizeFn(const QuantizeXlaFn& fn) {
if (*quantize_xla) return false;
*quantize_xla = fn;
return true;
}
namespace {
// Compiles the XLA computation into executable code.
@ -116,9 +124,11 @@ Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
} else {
return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
}
if (flags.experimental_quantize) {
TF_RETURN_IF_ERROR(mlir::xla_hlo::XlaQuantize(config, &computation));
if (flags.experimental_quantize && *quantize_xla) {
TF_RETURN_IF_ERROR((*quantize_xla)(config, &computation));
}
if (!flags.out_session_module.empty()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
computation.Snapshot());
View File
@ -13,21 +13,29 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_
#ifndef TENSORFLOW_COMPILER_AOT_QUANTIZE_H_
#define TENSORFLOW_COMPILER_AOT_QUANTIZE_H_
#include <functional>
#include <iostream>
#include <ostream>
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
namespace mlir {
namespace xla_hlo {
namespace tensorflow {
namespace tfcompile {
// Quantizes the model in the computation.
tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config,
xla::XlaComputation* computation);
using QuantizeXlaFn = std::function<Status(const tf2xla::Config& config,
xla::XlaComputation* computation)>;
} // namespace xla_hlo
} // namespace mlir
// Set the static quantization function to the `fn` if it hasn't been set.
// Return false if the static function has been set.
bool RegisterQuantizeFn(const QuantizeXlaFn& fn);
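// Registration sketch (illustrative; `MyQuantize` is a hypothetical function
// matching the QuantizeXlaFn signature):
//
//   static bool registered = RegisterQuantizeFn(
//       [](const tf2xla::Config& config, xla::XlaComputation* computation) {
//         return MyQuantize(config, computation);
//       });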
#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_
} // namespace tfcompile
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_AOT_QUANTIZE_H_
View File
@ -321,7 +321,7 @@ xla::StatusOr<NodeDef> BuildXlaHostComputeNodeDef(
host_compute_builder.node_name());
// Copy all attributes.
for (auto attr : call_node->attrs()) {
for (const auto& attr : call_node->attrs()) {
host_compute_builder.Attr(attr.first, attr.second);
}
@ -346,7 +346,7 @@ xla::StatusOr<NodeDef> BuildXlaHostComputeNodeDef(
xla_token_input_nodes.emplace_back(kXlaTokenArgNodeName);
auto cluster_deps_it = cluster_deps.find(original_oc_name);
if (cluster_deps_it != cluster_deps.end()) {
for (auto dep : cluster_deps_it->second) {
for (const auto& dep : cluster_deps_it->second) {
xla_token_input_nodes.emplace_back(host_compute_node_name(dep));
}
}
@ -2359,7 +2359,7 @@ Status ExtractOutsideCompilationForFunction(
}
// For XlaHostCompute nodes with dependencies, add control edges between
// them so XlaCompiler can handle them in correct order.
for (auto iter : host_compute_nodes) {
for (const auto& iter : host_compute_nodes) {
Node* host_compute_node = iter.second;
std::vector<string> token_input_node_names;
TF_RETURN_IF_ERROR(GetNodeAttr(host_compute_node->def(),
@ -2479,7 +2479,7 @@ Status ExtractOutsideCompilation(
TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name));
for (auto shape_inference_graph_name : shape_inference_graphs) {
for (const auto& shape_inference_graph_name : shape_inference_graphs) {
TF_RETURN_IF_ERROR(RewriteShapeInferenceGraph(
shape_inference_graph_name, g, pivot_node, fld));
}
View File
@ -1161,7 +1161,7 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
std::vector<string> vall_ops = XlaOpRegistry::GetAllRegisteredOps();
absl::flat_hash_set<string> all_ops(vall_ops.begin(), vall_ops.end());
// Check that user's provided TF operation really exists.
for (auto s : whitelist) {
for (const auto& s : whitelist) {
if (!all_ops.contains(string(s))) {
return errors::InvalidArgument(
"The operation '", s,
@ -1475,7 +1475,7 @@ void MarkForCompilationPassImpl::VLogClusteringSummary() {
<< RatioToString(auto_clustering_info.clustered_node_count(),
graph_->num_nodes());
for (XlaAutoClusteringSummary::Cluster cluster :
for (const XlaAutoClusteringSummary::Cluster& cluster :
auto_clustering_info.clusters()) {
absl::string_view cluster_name = cluster.name();
int size = cluster.size();
@ -1891,6 +1891,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"DynamicStitch",
"Einsum",
"EmptyTensorList",
"EnsureShape",
"ExtractImagePatches",
"Igamma",
"IgammaGradA",
View File
@ -493,7 +493,7 @@ std::pair<string, AttrValue> impl::AttrLiteralHelper(
const std::pair<string, absl::Span<const string>>& string_list_attr) {
AttrValue attr_value;
AttrValue::ListValue* list = attr_value.mutable_list();
for (string s : string_list_attr.second) {
for (const string& s : string_list_attr.second) {
list->add_s(s);
}
return {string_list_attr.first, attr_value};
View File
@ -296,10 +296,10 @@ Status XlaCompilationCache::CompileSingleOp(
arg_shapes.push_back(absl::get<TensorShape>(arg.shape));
}
GraphDebugInfo debug_info;
return CompileGraphToXlaHlo(*graph, {arg_shapes.data(), arg_shapes.size()},
compile_options.use_tuple_arg,
*options.flib_def, debug_info,
options.shape_representation_fn, result);
return CompileGraphToXlaHlo(
*graph, {arg_shapes.data(), arg_shapes.size()},
options.device_type.type_string(), compile_options.use_tuple_arg,
*options.flib_def, debug_info, options.shape_representation_fn, result);
};
return CompileImpl(options, name, args, compile_op,
/*compile_threshold=*/absl::nullopt,
View File
@ -69,6 +69,7 @@ absl::optional<AllocatorStats> XlaDeviceAllocator::GetStats() {
tf_stats.bytes_reserved = se_stats->bytes_reserved;
tf_stats.peak_bytes_reserved = se_stats->peak_bytes_reserved;
tf_stats.bytes_reservable_limit = se_stats->bytes_reservable_limit;
tf_stats.largest_free_block_bytes = se_stats->largest_free_block_bytes;
return tf_stats;
}
View File
@ -479,6 +479,15 @@ Status XlaComputationLaunchContext::PopulateOutputs(
input_output_alias, output_num, ctx, i, shape, &output,
definition_event, stream, use_multiple_streams_));
} else {
auto program_shape =
kernel->computation->GetProgramShape().ValueOrDie();
if (program_shape.result().IsTuple() &&
program_shape.result().tuple_shapes(output_num).IsTuple()) {
return errors::Unimplemented(
"Support for TensorList or Stack crossing the XLA/TF boundary "
"is not implemented");
}
se::DeviceMemoryBase buffer = output.buffer({output_num});
Tensor output_tensor = GetOrCreateTensorForOutput(
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
View File
@ -40,11 +40,11 @@ cc_library(
srcs = ["tf_mlir_opt_main.cc"],
deps = [
":init_mlir",
":passes",
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"@llvm-project//llvm:support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MlirOptLib",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
@ -55,9 +55,15 @@ cc_library(
cc_library(
name = "passes",
visibility = [
":__subpackages__",
"//tensorflow/python:__subpackages__",
],
deps = [
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:QuantOps",
        # Link in the jit lib to register the JIT devices required to run the
        # xla-legalize-tf-with-tf2xla pass.
"//tensorflow/compiler/jit",
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
"//tensorflow/compiler/mlir/lite:tensorflow_lite_legalize_tf",
@ -65,33 +71,12 @@ cc_library(
"//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
"//tensorflow/compiler/mlir/lite/quantization:quantization_passes",
"//tensorflow/compiler/mlir/lite/quantization/tensorflow:tf_to_quant",
"//tensorflow/compiler/mlir/lite/quantization/xla:hlo_xla_quantization_passes",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_test_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
"//tensorflow/compiler/mlir/xla:buffer_assignment",
"//tensorflow/compiler/mlir/xla:hlo",
"//tensorflow/compiler/mlir/xla:hlo_legalize_to_lhlo",
"//tensorflow/compiler/mlir/xla:lhlo",
"//tensorflow/compiler/mlir/xla:lhlo_copy_removal",
"//tensorflow/compiler/mlir/xla:lhlo_fuse_linalg",
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine",
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_gpu",
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_parallel_loops",
"//tensorflow/compiler/mlir/xla:xla_dialect_registration",
"//tensorflow/compiler/mlir/xla:xla_legalize_control_flow",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
"//tensorflow/compiler/mlir/xla:xla_legalize_to_linalg",
"//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
"//tensorflow/compiler/mlir/xla:xla_lower",
"//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts",
"//tensorflow/compiler/mlir/xla:xla_test_passes",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:QuantOps",
],
)
@ -139,11 +124,14 @@ cc_library(
tf_cc_binary(
name = "tf-opt",
deps = [
":passes",
":tf_mlir_opt_main",
"//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass_registration",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
"//tensorflow/compiler/mlir/tfjs:tensorflow_js_dialect_registration",
"//tensorflow/compiler/mlir/xla:all_xla_passes_for_testing",
],
)
@ -170,7 +158,6 @@ tf_cc_binary(
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TranslateClParser",
"@llvm-project//mlir:Translation",
],
)
View File
@ -55,7 +55,7 @@ gentbl(
"ir/tfl_structs.cc.inc",
),
(
"-gen-op-doc",
"-gen-dialect-doc",
"g3doc/tfl_ops.md",
),
],
@ -307,7 +307,7 @@ cc_library(
"transforms/optimize_functional_ops.cc",
"transforms/prepare_composite_functions_tf.cc",
"transforms/prepare_tf.cc",
"transforms/runtime_type_verify.cc",
"transforms/runtime_verify.cc",
"transforms/split_merged_operands.cc",
"transforms/trim_functions_tf.cc",
"transforms/while_loop_outline.cc",
@ -512,7 +512,7 @@ cc_library(
],
deps = [
":tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:status",
@ -521,6 +521,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@flatbuffers",
"@llvm-project//llvm:analysis",
"@llvm-project//llvm:support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
@ -537,6 +538,15 @@ tf_native_cc_binary(
],
)
tf_native_cc_binary(
name = "json_to_flatbuffer",
srcs = ["json_to_flatbuffer.cc"],
deps = [
"//tensorflow/lite/schema:schema_fbs",
"@flatbuffers",
],
)
cc_library(
name = "emit_error_reporter",
srcs = [
@ -552,19 +562,16 @@ cc_library(
)
cc_library(
name = "flatbuffer_translate_lib",
name = "flatbuffer_export",
srcs = [
"flatbuffer_export.cc",
"flatbuffer_import.cc",
"utils/convert_type.cc",
],
hdrs = [
"flatbuffer_export.h",
"flatbuffer_export_flags.h",
"flatbuffer_import.h",
"utils/convert_type.h",
],
deps = [
":convert_type",
":flatbuffer_tflite_operator_lib",
":stateful_ops_utils",
":tensorflow_lite",
@ -582,14 +589,12 @@ cc_library(
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:logging",
"//tensorflow/core/platform:status",
"//tensorflow/lite:framework",
"//tensorflow/lite:schema_fbs_version",
"//tensorflow/lite:string_util",
"//tensorflow/lite/delegates/flex:whitelisted_flex_ops_lib",
"//tensorflow/lite/kernels/internal:kernel_utils",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/tools/versioning",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
@ -604,6 +609,78 @@ cc_library(
],
)
cc_library(
name = "flatbuffer_import",
srcs = [
"flatbuffer_import.cc",
],
hdrs = [
"flatbuffer_import.h",
],
deps = [
":convert_type",
":flatbuffer_tflite_operator_lib",
":tensorflow_lite",
":tensorflow_lite_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:status",
"//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Translation",
],
)
cc_library(
name = "convert_type",
srcs = [
"utils/convert_type.cc",
],
hdrs = [
"utils/convert_type.h",
],
deps = [
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:errors",
"//tensorflow/lite/schema:schema_fbs",
"@llvm-project//mlir:IR",
],
)
cc_library(
name = "flatbuffer_translate_lib",
hdrs = [
"flatbuffer_export.h",
"flatbuffer_export_flags.h",
"flatbuffer_import.h",
"utils/convert_type.h",
],
deps = [
":flatbuffer_export",
":flatbuffer_import",
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:protos_all_cc",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings",
"@llvm-project//mlir:IR",
],
)
cc_library(
name = "flatbuffer_translate_registeration",
srcs = [
View File
@ -36,7 +36,8 @@ struct PassConfig {
form_clusters(false),
unfold_batch_matmul(true),
legalize_tf_while(true),
shape_inference(true) {}
shape_inference(true),
runtime_verification(true) {}
// If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
// added, which produces TF Lite ops.
@ -65,6 +66,8 @@ struct PassConfig {
bool legalize_tf_while;
// Whether to do shape inference.
bool shape_inference;
// Whether to do TFLite runtime verification.
bool runtime_verification;
};
} // namespace TFL
View File
@ -246,6 +246,50 @@ static void EmitGetBuiltinOpCode(const std::vector<Record *> &defs,
"}\n";
}
// Emits functions that return the min/max operand numbers for a given tflite op
// name.
//
// Signature:
// llvm::MinMax mlir::OperandNumbersMinMax(llvm::StringRef op_name) {
//   if (op_name == "tfl.<op name>") {
// return {min, max};
// }
// ...
// return {0, 0};
// }
static void EmitOperandNumbers(const RecordKeeper &record_keeper,
const std::vector<Record *> &defs,
raw_ostream *ostream) {
raw_ostream &os = *ostream;
const auto attr_type = record_keeper.getClass("Attr");
const auto optional_tensor = record_keeper.getClass("TFL_TensorOfOrNone");
os << "llvm::MinMax mlir::OperandNumbersMinMax(llvm::StringRef op_name) {\n";
for (const auto *def : defs) {
auto op_name = def->getValueAsString("opName");
int tail_optional_tensor = 0, tensor_number_max = 0;
auto *arg_values = def->getValueAsDag("arguments");
for (int i = 0, e = arg_values->getNumArgs(); i < e; ++i) {
auto arg = arg_values->getArg(i);
auto *arg_def = dyn_cast<DefInit>(arg);
if (!arg_def) continue;
if (!arg_def->getDef()->isSubClassOf(attr_type)) {
tensor_number_max++;
if (arg_def->getDef()->isSubClassOf(optional_tensor)) {
tail_optional_tensor++;
} else {
tail_optional_tensor = 0;
}
}
}
const int tensor_number_min = tensor_number_max - tail_optional_tensor;
os << formatv(" if (op_name == \"tfl.{0}\") {{\n", op_name)
<< " return {" << tensor_number_min << ", " << tensor_number_max
<< "};\n }\n";
}
os << " return {0, 0};\n}\n";
}
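// Worked example of the min/max computation above, for a hypothetical op whose
// `arguments` are (tensor, tensor, TFL_TensorOfOrNone): tensor_number_max ends
// up as 3 with one trailing optional tensor, so the emitted branch would be
// roughly:
//   if (op_name == "tfl.some_op") {
//     return {2, 3};
//   }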
// Emits a builder function that returns the packed FlatBuffer object given
// a general mlir::Operation.
//
@ -374,6 +418,8 @@ static bool OperatorWritersMain(raw_ostream &os, RecordKeeper &records) {
EmitBuildOperator(defs, &os);
os << "\n\n";
EmitBuiltinOptionsToAttributes(records, defs, &os);
os << "\n\n";
EmitOperandNumbers(records, defs, &os);
return false;
}
@ -441,7 +487,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
mlir::tblgen::FmtContext verify_ctx;
os << "::mlir::LogicalResult " << op.getCppClassName()
<< "::VerifyTflRuntimeTypes(::mlir::Operation *op, bool "
<< "::VerifyTflRuntimeConstraints(::mlir::Operation *op, bool "
"failure_on_operand_type_mismatch) {\n";
os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
verify_ctx.withOp("top");
@ -450,7 +496,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
auto &value = op.getOperand(i);
    // Skip from the first variadic operand onwards for now. Otherwise the
    // getOperand index used below doesn't match.
if (value.isVariadic()) break;
if (value.isVariableLength()) break;
if (!value.name.empty())
verify_ctx.addSubst(value.name, formatv("op->getOperand({0})", i));
}
@ -458,7 +504,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
auto &value = op.getResult(i);
    // Skip from the first variadic result onwards for now. Otherwise the
    // getResult index used below doesn't match.
if (value.isVariadic()) break;
if (value.isVariableLength()) break;
if (!value.name.empty())
verify_ctx.addSubst(value.name, formatv("op->getResult({0})", i));
}
@ -466,6 +512,25 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
"operand");
GenOperandResultVerifier(os, def->getValueAsDag("results")->getArgs(),
"result");
for (auto &trait : op.getTraits()) {
if (!trait.getDef().isSubClassOf("GenInternalOpTrait")) {
continue;
}
if (trait.getDef().getValueAsString("trait") !=
"OpTrait::TFLRuntimeOpTrait") {
continue;
}
auto *val = trait.getDef().getValue("tflRuntimePredicate");
if (!val) continue;
mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue()));
os << tgfmt(
" if (!($0)) {\n "
" return ::mlir::LogicalResult::Failure;\n }\n",
&verify_ctx, tgfmt(pred.getCondition(), &verify_ctx));
}
os << " return top.verify();\n}\n";
}
View File
@ -9,7 +9,9 @@ cc_library(
name = "cost_estimators",
textual_hdrs = [
"estimator.h",
"cpu_estimators.h",
"gpu_estimators.h",
"hardware.h",
"arithmetic_count_util.h",
],
)
View File
@ -0,0 +1,45 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ARITHMETIC_COUNT_UTIL_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ARITHMETIC_COUNT_UTIL_H_
// For add/mul/div/sub and other broadcastable ops.
class ArithmeticCountUtilHelper {
public:
static bool GetArithmeticCountForBroadcastableOp(mlir::Operation* op,
int64_t* count) {
auto output = op->getResult(0);
auto output_type = output.getType().dyn_cast_or_null<RankedTensorType>();
if (!output_type || !output_type.hasStaticShape()) return false;
*count = output_type.getNumElements();
return true;
}
static bool GetInputTensorTotalSize(mlir::Operation* op, int64_t* count) {
int64_t total_count = 0;
for (auto input : op->getOperands()) {
auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
if (!input_type || !input_type.hasStaticShape()) {
return false;
}
total_count += input_type.getNumElements();
}
*count = total_count;
return true;
}
};
#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ARITHMETIC_COUNT_UTIL_H_
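The helper above only counts elements when every shape involved is fully static; otherwise callers fall back to a fixed default cost. A minimal standalone sketch of the same counting logic over plain shape vectors (no MLIR types; -1 stands in for a dynamic dimension):

#include <cstdint>
#include <iostream>
#include <vector>

// Returns false for dynamic shapes, mirroring the helper's bail-out behavior.
bool NumElements(const std::vector<int64_t>& shape, int64_t* count) {
  int64_t n = 1;
  for (int64_t d : shape) {
    if (d < 0) return false;  // dynamic dimension
    n *= d;
  }
  *count = n;
  return true;
}

bool InputTensorTotalSize(const std::vector<std::vector<int64_t>>& inputs,
                          int64_t* count) {
  int64_t total = 0;
  for (const auto& shape : inputs) {
    int64_t n;
    if (!NumElements(shape, &n)) return false;
    total += n;
  }
  *count = total;
  return true;
}

int main() {
  int64_t count = 0;
  // Broadcastable op: the count is the number of output elements.
  NumElements({1, 224, 224, 3}, &count);
  std::cout << "output elements: " << count << "\n";  // 150528
  // Copy-like op (concat/pack/reshape): sum of all input elements.
  InputTensorTotalSize({{1, 8, 8, 16}, {1, 8, 8, 16}}, &count);
  std::cout << "total input elements: " << count << "\n";  // 2048
  return 0;
}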

View File

@ -0,0 +1,103 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_CPU_ESTIMATORS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_CPU_ESTIMATORS_H_
// CPU
constexpr float kCPUArithmeticUnitCost = 1.0;
// This basically assumes pure load/store. This is just fake data.
constexpr float kCPUCopyUnitCost = 0.5;
constexpr float kCPUDefaultCost = 3.0f;
// Default values.
constexpr float kCPUDefaultFixedValuedCost = 10000.0;
// tfl.add
template <>
class TFLiteCostEstimator<AddOp, hardware::CPU> {
public:
static double GetCost(mlir::Operation* op) {
int64_t count;
if (ArithmeticCountUtilHelper::GetArithmeticCountForBroadcastableOp(op,
&count))
return kCPUArithmeticUnitCost * count;
return kCPUDefaultFixedValuedCost;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.mul
template <>
class TFLiteCostEstimator<MulOp, hardware::CPU> {
public:
static double GetCost(mlir::Operation* op) {
int64_t count;
if (ArithmeticCountUtilHelper::GetArithmeticCountForBroadcastableOp(op,
&count))
return kCPUArithmeticUnitCost * count;
return kCPUDefaultFixedValuedCost;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.concatenation
template <>
class TFLiteCostEstimator<ConcatenationOp, hardware::CPU> {
public:
static double GetCost(mlir::Operation* op) {
int64_t count;
if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count))
return kCPUCopyUnitCost * count;
return kCPUDefaultFixedValuedCost;
}
// TODO(renjieliu): We probably need to check for dynamic weights.
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.pack
template <>
class TFLiteCostEstimator<PackOp, hardware::CPU> {
public:
static double GetCost(mlir::Operation* op) {
int64_t count;
if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count))
return kCPUCopyUnitCost * count;
return kCPUDefaultFixedValuedCost;
}
// TODO(renjieliu): We probably need to check for dynamic weights.
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.reshape
template <>
class TFLiteCostEstimator<ReshapeOp, hardware::CPU> {
public:
static double GetCost(mlir::Operation* op) {
int64_t count;
if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count))
return kCPUCopyUnitCost * count;
return kCPUDefaultFixedValuedCost;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_CPU_ESTIMATORS_H_
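These estimators all follow the same pattern: multiply an element count by a per-hardware unit cost, and fall back to the large fixed cost when shapes are dynamic. A standalone sketch of that arithmetic, reusing the constants above with made-up element counts:

#include <cstdint>
#include <iostream>

// Unit costs mirror the constants in this header; the counts below are invented.
constexpr float kCPUArithmeticUnitCost = 1.0f;
constexpr float kCPUCopyUnitCost = 0.5f;
constexpr float kCPUDefaultFixedValuedCost = 10000.0f;

// Arithmetic ops (add/mul): cost scales with the number of output elements.
double AddCost(int64_t output_elements, bool static_shape) {
  if (!static_shape) return kCPUDefaultFixedValuedCost;
  return kCPUArithmeticUnitCost * output_elements;
}

// Copy-like ops (concat/pack/reshape): cost scales with total input elements.
double ConcatCost(int64_t total_input_elements, bool static_shape) {
  if (!static_shape) return kCPUDefaultFixedValuedCost;
  return kCPUCopyUnitCost * total_input_elements;
}

int main() {
  std::cout << AddCost(150528, true) << "\n";   // 150528
  std::cout << ConcatCost(2048, true) << "\n";  // 1024
  std::cout << AddCost(0, false) << "\n";       // 10000 fallback
  return 0;
}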

View File

@ -16,9 +16,19 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_
// tfl.add
// GPU
constexpr float kGPUArithmeticUnitCost = 0.2;
// The copy can be non-consecutive. This is just fake data.
constexpr float kGPUCopyUnitCost = 0.2;
constexpr float kGPUDefaultCost = 1.0f;
// Default values.
constexpr float kGPUDefaultFixedValuedCost = 10000.0;
// tfl.abs
template <>
class TFLiteCostEstimator<AddOp, hardware::GPU> {
class TFLiteCostEstimator<AbsOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
@ -29,6 +39,21 @@ class TFLiteCostEstimator<AddOp, hardware::GPU> {
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.add
template <>
class TFLiteCostEstimator<AddOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
int64_t count;
if (ArithmeticCountUtilHelper::GetArithmeticCountForBroadcastableOp(op,
&count))
return kGPUArithmeticUnitCost * count;
return kGPUDefaultFixedValuedCost;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.average_pool_2d
template <>
class TFLiteCostEstimator<AveragePool2DOp, hardware::GPU> {
@ -47,9 +72,10 @@ template <>
class TFLiteCostEstimator<ConcatenationOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
int64_t count;
if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count))
return kGPUCopyUnitCost * count;
return kGPUDefaultFixedValuedCost;
}
// TODO(renjieliu): We probably need to check for dynamic weights.
@ -149,6 +175,19 @@ class TFLiteCostEstimator<HardSwishOp, hardware::GPU> {
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.log
template <>
class TFLiteCostEstimator<LogOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.logistic
template <>
class TFLiteCostEstimator<LogisticOp, hardware::GPU> {
@ -201,6 +240,33 @@ class TFLiteCostEstimator<MaximumOp, hardware::GPU> {
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.max_unpooling_2d
template <>
class TFLiteCostEstimator<MaxUnpooling2DOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.mean
template <>
class TFLiteCostEstimator<MeanOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
// TODO(renjieliu): check for constraints.
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.minimum
template <>
class TFLiteCostEstimator<MinimumOp, hardware::GPU> {
@ -219,9 +285,11 @@ template <>
class TFLiteCostEstimator<MulOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
int64_t count;
if (ArithmeticCountUtilHelper::GetArithmeticCountForBroadcastableOp(op,
&count))
return kGPUArithmeticUnitCost * count;
return kGPUDefaultFixedValuedCost;
}
static bool IsSupported(mlir::Operation* op) { return true; }
@ -240,6 +308,32 @@ class TFLiteCostEstimator<PadOp, hardware::GPU> {
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.pow
template <>
class TFLiteCostEstimator<PowOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.prelu
template <>
class TFLiteCostEstimator<PReluOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.relu
template <>
class TFLiteCostEstimator<ReluOp, hardware::GPU> {
@ -269,6 +363,33 @@ class TFLiteCostEstimator<Relu6Op, hardware::GPU> {
// tfl.reshape
template <>
class TFLiteCostEstimator<ReshapeOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
int64_t count;
if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count))
return kGPUCopyUnitCost * count;
return kGPUDefaultFixedValuedCost;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.rsqrt
template <>
class TFLiteCostEstimator<RsqrtOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.sin
template <>
class TFLiteCostEstimator<SinOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
@ -305,6 +426,58 @@ class TFLiteCostEstimator<SoftmaxOp, hardware::GPU> {
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.space_to_depth
template <>
class TFLiteCostEstimator<SpaceToDepthOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.sqrt
template <>
class TFLiteCostEstimator<SqrtOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.square
template <>
class TFLiteCostEstimator<SquareOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.squared_difference
template <>
class TFLiteCostEstimator<SquaredDifferenceOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.strided_slice
template <>
class TFLiteCostEstimator<StridedSliceOp, hardware::GPU> {
@ -318,6 +491,19 @@ class TFLiteCostEstimator<StridedSliceOp, hardware::GPU> {
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.tanh
template <>
class TFLiteCostEstimator<TanhOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.transpose
template <>
class TFLiteCostEstimator<TransposeOp, hardware::GPU> {
@ -331,5 +517,18 @@ class TFLiteCostEstimator<TransposeOp, hardware::GPU> {
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.transpose_conv
template <>
class TFLiteCostEstimator<TransposeConvOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_

View File

@ -59,13 +59,11 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Support/Functional.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Translation.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/compiler/xla/statusor.h"
@ -579,6 +577,24 @@ StatusOr<Operation*> ConvertOp(
op_state.addTypes({type});
}
// While the last several tensors could be optional tensors for a tfl op, the
// number of input operands can vary. Get the min/max number of operands from
// the tflite op name.
// Also, since the above code special-handles the `tfl.reshape` op and adds an
// additional input, we put this code block here.
llvm::MinMax input_min_max = mlir::OperandNumbersMinMax(op_name);
int input_max_num = input_min_max.Max;
int op_input_num = op_state.operands.size();
if (input_max_num != 0 && input_max_num > op_input_num) {
// If the number of current inputs is fewer than the op definition allows,
// fill in with `none` values.
llvm::SmallVector<Value, 4> none_operands(
input_max_num - op_input_num,
builder.create<mlir::ConstantOp>(loc, builder.getNoneType(),
builder.getUnitAttr()));
op_state.addOperands(ArrayRef<Value>(none_operands));
}
if (op_name == "tfl.lstm") {
// TODO(b/147587779): add the right region if region is empty.
op_state.addRegion();
@ -658,8 +674,8 @@ template <typename ContainerType>
mlir::NamedAttribute BuildTFEntryFunctionAttribute(
const tflite::SubGraphT& subgraph, Builder* builder, const std::string name,
const ContainerType indices) {
llvm::SmallVector<std::string, 8> tensor_names = mlir::functional::map(
[&](int i) { return subgraph.tensors.at(i)->name; }, indices);
auto tensor_names = llvm::map_range(
indices, [&](int i) { return subgraph.tensors.at(i)->name; });
return builder->getNamedAttr(
name, builder->getStringAttr(llvm::join(tensor_names, ",")));
}
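The operand padding above can be pictured without MLIR: given the op's maximum operand count, any missing trailing (optional) operands are filled with placeholder `none` values. A minimal sketch under that assumption, with strings standing in for mlir::Value:

#include <iostream>
#include <string>
#include <vector>

// Pads the operand list up to the op's maximum operand count with "none".
void PadWithNone(std::vector<std::string>* operands, int input_max_num) {
  const int op_input_num = static_cast<int>(operands->size());
  if (input_max_num != 0 && input_max_num > op_input_num) {
    operands->insert(operands->end(), input_max_num - op_input_num, "none");
  }
}

int main() {
  std::vector<std::string> operands = {"input", "weights"};
  PadWithNone(&operands, 4);  // op declares up to 4 operands, last 2 optional
  for (const auto& o : operands) std::cout << o << " ";  // input weights none none
  std::cout << "\n";
  return 0;
}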

View File

@ -25,7 +25,7 @@ limitations under the License.
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AssumeBundleQueries.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
@ -55,6 +56,11 @@ void BuiltinOptionsToAttributes(
// NOLINTNEXTLINE
llvm::SmallVectorImpl<mlir::NamedAttribute> &attributes);
// While the last several tensors could be optional tensors for a tfl op, the
// number of input operands can vary. This function gets the min/max number of
// operands from the tflite op name.
llvm::MinMax OperandNumbersMinMax(llvm::StringRef op_name);
// Populates the array of mlir::NamedAttributes corresponding to the given
// custom_options.
// We use an out parameter per LLVM convention

View File

@ -86,7 +86,7 @@ def TFL_RuntimeVerification : OpInterface<"TflRuntimeVerifyOpInterface"> {
let methods = [
StaticInterfaceMethod<
[{Returns whether the op's operands/results are supported by runtime.}],
"LogicalResult", "VerifyTflRuntimeTypes",
"LogicalResult", "VerifyTflRuntimeConstraints",
(ins "Operation*":$op, "bool":$failure_on_operand_type_mismatch)
>,
];

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
@ -45,6 +46,30 @@ namespace mlir {
#include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc"
namespace TFL {
// Returns true when the two given types have the same shape or a broadcastable
// shape within the given rank. If either shape is non-static, this method
// returns true.
bool IsBinaryOperandsHaveSameShapesOrBroadcastableShape(Type lhs, Type rhs,
int max_bcast_rank) {
// Ignore shape checking on the non-static shapes for model compatibility.
auto lhs_shaped_type = lhs.dyn_cast<ShapedType>();
if (!lhs_shaped_type || !lhs_shaped_type.hasStaticShape()) return true;
auto rhs_shaped_type = rhs.dyn_cast<ShapedType>();
if (!rhs_shaped_type || !rhs_shaped_type.hasStaticShape()) return true;
if (lhs_shaped_type.getShape().equals(rhs_shaped_type.getShape()))
return true;
SmallVector<int64_t, 4> result_shape;
if (!OpTrait::util::getBroadcastedShape(lhs_shaped_type.getShape(),
rhs_shaped_type.getShape(),
result_shape)) {
return false;
}
return lhs_shaped_type.getRank() <= max_bcast_rank &&
rhs_shaped_type.getRank() <= max_bcast_rank;
}
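A standalone sketch of this compatibility rule (not the MLIR implementation): two static shapes pass when they are equal or NumPy-style broadcastable from the trailing dimensions, and both ranks stay within max_bcast_rank.

#include <cstdint>
#include <iostream>
#include <vector>

bool BroadcastableWithinRank(const std::vector<int64_t>& lhs,
                             const std::vector<int64_t>& rhs,
                             int max_bcast_rank) {
  if (lhs == rhs) return true;
  // Align from the trailing dimensions; each pair must match or contain a 1.
  auto l = lhs.rbegin(), r = rhs.rbegin();
  for (; l != lhs.rend() && r != rhs.rend(); ++l, ++r) {
    if (*l != *r && *l != 1 && *r != 1) return false;
  }
  return static_cast<int>(lhs.size()) <= max_bcast_rank &&
         static_cast<int>(rhs.size()) <= max_bcast_rank;
}

int main() {
  std::cout << BroadcastableWithinRank({1, 128, 128, 3}, {3}, 5) << "\n";    // 1
  std::cout << BroadcastableWithinRank({2, 3}, {4, 3}, 5) << "\n";           // 0
  std::cout << BroadcastableWithinRank({1, 2, 3, 4, 5, 6}, {6}, 5) << "\n";  // 0 (rank > 5)
  return 0;
}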
//===----------------------------------------------------------------------===//
// TensorFlowLiteDialect
//===----------------------------------------------------------------------===//
@ -315,7 +340,7 @@ Attribute ConstFoldUnaryOp(Type result_type, Attribute operand,
const int num_elements = result_shape_type.getNumElements();
new_values.reserve(num_elements);
for (APFloat old_value : dense_elements.getValues<APFloat>()) {
for (const APFloat &old_value : dense_elements.getValues<APFloat>()) {
new_values.push_back(calculate(old_value));
}
@ -843,7 +868,7 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
if (!shape_elements) return nullptr;
SmallVector<int64_t, 4> shape_data;
for (auto it : shape_elements.getValues<APInt>()) {
for (const auto &it : shape_elements.getValues<APInt>()) {
shape_data.push_back(it.getSExtValue());
}
result_type =
@ -1266,10 +1291,65 @@ static LogicalResult Verify(SplitVOp op) {
static LogicalResult Verify(LSTMOp op) {
auto operands = op.GetStatefulOperands();
if (operands.size() == 2 && operands[0] == 18 && operands[1] == 19) {
return success();
if (operands.size() != 2 || operands[0] != 18 || operands[1] != 19) {
return op.emitOpError("LSTMOp expected to have two stateful operands");
}
return op.emitError("LSTMOp expected to have two stateful operands");
const auto input_type = op.input().getType().cast<ShapedType>();
// Since TFLite runtime generally supports dynamic shape/rank, if `input_type`
// doesn't have static shape, we skip the shape check below.
if (!input_type.hasStaticShape()) return success();
// The input should be at least a 2D tensor since it will go through a fully
// connected layer.
if (!input_type.hasRank() || input_type.getRank() < 2)
return op.emitOpError(
"the first input operand should have more than 2 dimensions.");
const auto activation_state =
op.input_activation_state().getType().cast<ShapedType>();
const auto cell_state = op.input_cell_state().getType().cast<ShapedType>();
const auto input_to_output_weights =
op.input_to_output_weights().getType().cast<ShapedType>();
const auto recurrent_to_output_weights =
op.recurrent_to_output_weights().getType().cast<ShapedType>();
if (activation_state.hasStaticShape() && cell_state.hasStaticShape() &&
input_to_output_weights.hasStaticShape() &&
recurrent_to_output_weights.hasStaticShape()) {
const int n_input = input_type.getDimSize(input_type.getRank() - 1);
const int n_cell = input_to_output_weights.getDimSize(0);
const int n_output = recurrent_to_output_weights.getDimSize(1);
const int output_state_size = activation_state.getNumElements();
const int n_batch = input_type.getRank() == 2 ? input_type.getDimSize(0)
: input_type.getDimSize(1);
const int state_size = cell_state.getNumElements();
// Check whether the dimensions of the inputs match.
if ((output_state_size != n_batch * n_output) ||
(state_size != n_batch * n_cell) ||
(input_to_output_weights.getDimSize(1) != n_input) ||
(recurrent_to_output_weights.getRank() != 2) ||
(recurrent_to_output_weights.getDimSize(0) != n_cell) ||
(input_to_output_weights.getRank() != 2)) {
return op.emitOpError("inputs don't match with the dimensions.");
}
const bool is_layer_norm_lstm =
!op.forget_layer_norm_coefficients().getType().isa<NoneType>();
if (is_layer_norm_lstm) {
const auto forget_layer_norm_coefficients =
op.forget_layer_norm_coefficients().getType().cast<ShapedType>();
// If this lstm has layer normalization, this input value,
// "forget_layer_norm_coefficients", should be a 1D tensor.
if (forget_layer_norm_coefficients.getRank() != 1 ||
forget_layer_norm_coefficients.getDimSize(0) != n_cell)
return op.emitOpError(
"coefficient inputs have more than 2 dimensions or "
"don't match the dimension with input operand "
"`input_to_output_weights`.");
}
}
return success();
}
//===----------------------------------------------------------------------===//
@ -1742,7 +1822,7 @@ static LogicalResult Verify(TransposeOp op) {
int index = 0;
llvm::SmallVector<int64_t, 4> axes;
for (auto axis_int : perm.getValues<APInt>()) {
for (const auto &axis_int : perm.getValues<APInt>()) {
const int64_t axis = axis_int.getSExtValue();
if (axis < 0 || (input_type.hasRank() && axis >= input_type.getRank())) {
return op.emitOpError(

View File

@ -28,7 +28,6 @@ limitations under the License.
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project
#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project
#include "mlir/Interfaces/SideEffects.h" // from @llvm-project
#include "mlir/Support/Functional.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/lite/schema/schema_generated.h"
@ -54,6 +53,8 @@ class TensorFlowLiteDialect : public Dialect {
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc"
// Include all specialized estimators below this line
#include "tensorflow/compiler/mlir/lite/experimental/estimators/arithmetic_count_util.h"
#include "tensorflow/compiler/mlir/lite/experimental/estimators/cpu_estimators.h"
#include "tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h"
} // end namespace TFL

View File

@ -106,6 +106,22 @@ class DerivedShapeAttr<code body> : DerivedAttr<"ArrayRef<int64_t>", body>;
class DerivedTFLiteTypeAttr<code body> :
DerivedAttr<"tflite::TensorType", body>;
// TFL Runtime op trait predicate.
class TFL_RuntimePredOpTrait<string desc, Pred pred> :
GenInternalOpTrait<"TFLRuntimeOpTrait"> {
Pred tflRuntimePredicate = pred;
string tflRuntimeDescription = desc;
}
class TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<
int i, int j, int max_bcast_rank> :
TFL_RuntimePredOpTrait<"operand #" # i # " and operand #" # j #
" have the same shape or broadcastable shapes within the rank " #
max_bcast_rank,
CPred<"TFL::IsBinaryOperandsHaveSameShapesOrBroadcastableShape("
"$_op.getOperand(" # i # ").getType(), $_op.getOperand(" # j #
").getType(), " # max_bcast_rank # ")">>;
// These additional types/type constraints here are used to decouple the ops
// from runtime support for the ops. Prefer to use these types when defining
// new TF_Ops for uniformity.
@ -344,7 +360,10 @@ class TFL_ConvOp<string mnemonic, string opSummary, int index> :
// TFL op definitions.
//===----------------------------------------------------------------------===//
def TFL_AbsOp : TFL_Op<"abs", [
NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
NoSideEffect,
SameOperandsAndResultType,
NoQuantizableResult,
TFL_GpuTargetOp]> {
let summary = "Absolute value operator";
let description = [{
@ -360,10 +379,9 @@ an output element, this operation computes \\(y = |x|\\).
let hasFolder = 1;
}
def TFL_AddOp : TFL_Op<"add", [ResultsBroadcastableShape,
NoSideEffect,
Commutative,
TFL_GpuTargetOp]> {
def TFL_AddOp : TFL_Op<"add", [
TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>,
ResultsBroadcastableShape, NoSideEffect, Commutative, TFL_GpuTargetOp]> {
let summary = "Addition operator";
let description = [{
@ -371,11 +389,11 @@ def TFL_AddOp : TFL_Op<"add", [ResultsBroadcastableShape,
}];
let arguments = (
ins AnyTensor:$lhs,
AnyTensor:$rhs,
ins TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$lhs,
TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$rhs,
TFL_AFAttr:$fused_activation_function);
let results = (outs AnyTensor:$output);
let results = (outs TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$output);
let hasFolder = 1;
@ -432,7 +450,7 @@ retained with length 1.
}
def TFL_TransposeConvOp:
TFL_Op<"transpose_conv", [NoSideEffect]> {
TFL_Op<"transpose_conv", [NoSideEffect, TFL_GpuTargetOp]> {
let summary = "Transpose convolution operator";
let description = [{
@ -1527,7 +1545,10 @@ def TFL_LogisticOp: TFL_Op<"logistic", [
}
def TFL_LogOp: TFL_Op<"log", [
NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
NoSideEffect,
SameOperandsAndResultType,
NoQuantizableResult,
TFL_GpuTargetOp]> {
let summary = "Natural logarithm operator";
let description = [{
@ -1637,7 +1658,7 @@ def TFL_MaxPoolingWithArgMax2DOp :
}
def TFL_MaxUnpooling2DOp :
Op<TFL_Dialect, "max_unpooling_2d", [NoSideEffect]> {
Op<TFL_Dialect, "max_unpooling_2d", [NoSideEffect, TFL_GpuTargetOp]> {
let summary = "Max Unpool 2D";
let description = [{
@ -1690,7 +1711,7 @@ def TFL_MaximumOp : TFL_Op<"maximum", [
let hasOptions = 0;
}
def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect]> {
def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect, TFL_GpuTargetOp]> {
let summary = "Mean operator";
let description = [{
@ -2072,7 +2093,10 @@ def TFL_PadV2Op : TFL_Op<"padv2", [
let hasOptions = 1;
}
def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape,
NoSideEffect,
NoQuantizableResult,
TFL_GpuTargetOp]> {
let summary = "Power operator";
let description = [{
@ -2092,7 +2116,7 @@ def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape, NoSideEffect, NoQuanti
let builders = [TFL_BroadcastableBinaryBuilder];
}
def TFL_PReluOp : TFL_Op<"prelu", [NoSideEffect]> {
def TFL_PReluOp : TFL_Op<"prelu", [NoSideEffect, TFL_GpuTargetOp]> {
let summary = "Parameterized Relu operator";
let description = [{
@ -2141,6 +2165,17 @@ def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect,
let arguments = (ins TFL_TensorOf<[F32, QUI8, I8]>:$x);
let results = (outs TFL_TensorOf<[F32, QUI8, I8]>:$y);
// This builder doesn't work with quantized types, so it can only be used by
// non-quantization tablegen patterns. Currently, it is used by the
// elementwise-move reordering pattern in optimize_patterns.td.
let builders = [OpBuilder<
"Builder *, OperationState &state, Value input",
[{
state.addOperands({input});
state.addTypes(input.getType());
}]>
];
}
def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect,
@ -2157,6 +2192,17 @@ def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect,
let arguments = (ins TFL_TensorOf<[F32, QUI8, I8]>:$x);
let results = (outs TFL_TensorOf<[F32, QUI8, I8]>:$y);
// This builder doesn't work with quantized types, so it can only be used by
// non-quantization tablegen patterns. Currently, it is used by the
// elementwise-move reordering pattern in optimize_patterns.td.
let builders = [OpBuilder<
"Builder *, OperationState &state, Value input",
[{
state.addOperands({input});
state.addTypes(input.getType());
}]>
];
}
def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [NoSideEffect,
@ -2172,6 +2218,17 @@ def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [NoSideEffect,
let arguments = (ins TFL_TensorOf<[F32, QUI8, I8]>:$x);
let results = (outs TFL_TensorOf<[F32, QUI8, I8]>:$y);
// This builder doesn't work with quantized types, so it can only be used by
// non-quantization tablegen patterns. Currently, it is used by the
// elementwise-move reordering pattern in optimize_patterns.td.
let builders = [OpBuilder<
"Builder *, OperationState &state, Value input",
[{
state.addOperands({input});
state.addTypes(input.getType());
}]>
];
}
def TFL_ReshapeOp: TFL_Op<"reshape", [
@ -2223,7 +2280,10 @@ slice `i`, with the first `seq_lengths[i]` slices along dimension
let hasOptions = 1;
}
def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect,
SameOperandsAndResultType,
NoQuantizableResult,
TFL_GpuTargetOp]> {
let summary = "Reciprocal of square root operator";
let description = [{
@ -2371,7 +2431,10 @@ def TFL_SelectV2Op : TFL_Op<"select_v2", [NoSideEffect]> {
}
def TFL_SinOp: TFL_Op<"sin", [
NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
NoSideEffect,
SameOperandsAndResultType,
NoQuantizableResult,
TFL_GpuTargetOp]> {
let summary = "Sine operator";
let description = [{
@ -2413,7 +2476,10 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [
}
def TFL_SqrtOp: TFL_Op<"sqrt", [
NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
NoSideEffect,
SameOperandsAndResultType,
NoQuantizableResult,
TFL_GpuTargetOp]> {
let summary = "Square root operator";
let description = [{
@ -2428,7 +2494,10 @@ def TFL_SqrtOp: TFL_Op<"sqrt", [
}
def TFL_SquareOp: TFL_Op<"square", [
NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
NoSideEffect,
SameOperandsAndResultType,
NoQuantizableResult,
TFL_GpuTargetOp]> {
let summary = "Square operator";
let description = [{
@ -2472,7 +2541,10 @@ def TFL_SubOp : TFL_Op<"sub", [ResultsBroadcastableShape, NoSideEffect]> {
// TODO(jpienaar): Expand the kernel implementation to support all types besides
// I32 and F32.
def TFL_SquaredDifferenceOp : TFL_Op<"squared_difference", [
ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
ResultsBroadcastableShape,
NoSideEffect,
NoQuantizableResult,
TFL_GpuTargetOp]> {
let summary = "Squared difference operator";
let description = [{
@ -2499,7 +2571,8 @@ def TFL_TanhOp: TFL_Op<"tanh", [
// zero_point = central_value
// scale = 1. / (central_value - min_value)
FixedResultScale<Int8UniformQuantizedType<0, 78125, -7>>,
FixedResultScale<UInt8UniformQuantizedType<128, 78125, -7>>]> {
FixedResultScale<UInt8UniformQuantizedType<128, 78125, -7>>,
TFL_GpuTargetOp]> {
let summary = "Hyperbolic tangent operator";
let description = [{
@ -2509,6 +2582,17 @@ def TFL_TanhOp: TFL_Op<"tanh", [
let arguments = (ins TFL_TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$x);
let results = (outs TFL_TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$y);
// This builder doesn't work with quantized types, so it can only be used by
// non-quantization tablegen patterns. Currently, it is used by the
// elementwise-move reordering pattern in optimize_patterns.td.
let builders = [OpBuilder<
"Builder *, OperationState &state, Value input",
[{
state.addOperands({input});
state.addTypes(input.getType());
}]>
];
}
def TFL_TileOp: TFL_Op<"tile", [NoSideEffect, SameOperandsAndResultsScale,
@ -2694,7 +2778,8 @@ def TFL_SpaceToDepthOp: TFL_Op<"space_to_depth", [
NoSideEffect,
SameOperandsAndResultsScale,
PredOpTrait<"input and output must have same element type",
TCresVTEtIsSameAsOp<0, 0>>
TCresVTEtIsSameAsOp<0, 0>>,
TFL_GpuTargetOp
]> {
let summary = "SpaceToDepth operator";
@ -2957,14 +3042,13 @@ def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [
}];
let arguments = (ins
// TODO: add uint8 support when ready.
TFL_TensorOf<[F32, I32, I64]>:$input,
TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8]>:$input,
TFL_TensorOf<[I32, I64]>:$pad,
TFL_MirrorPaddingAttr:$mode
);
let results = (outs
TFL_TensorOf<[F32, I32, I64]>:$output
TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8]>:$output
);
let hasOptions = 1;

View File

@ -0,0 +1,63 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>
#include <cstddef>
#include <cstdio>
#include <iostream>
#include <string>
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "flatbuffers/idl.h" // from @flatbuffers
#include "flatbuffers/util.h" // from @flatbuffers
#include "tensorflow/lite/schema/schema_generated.h"
int main(int argc, char** argv) {
// load FlatBuffer schema (.fbs) and JSON from disk
if (argc < 3) {
std::cerr << "Missing input arguments. Usage:\n"
<< argv[0] << " <schema file> <json file>\n\n";
return 1;
}
const char* schema_path = argv[1];
const char* json_path = argv[2];
std::string schema;
std::string json;
const bool status =
flatbuffers::LoadFile(schema_path, /*binary=*/false, &schema) &&
flatbuffers::LoadFile(json_path, /*binary=*/false, &json);
if (!status) {
std::cerr << "couldn't load files!\n";
return 1;
}
// parse schema first, so we can use it to parse the data after
flatbuffers::Parser parser;
const bool schema_parse_result =
parser.Parse(schema.c_str()) && parser.Parse(json.c_str());
if (!schema_parse_result) {
std::cerr << "Parse error.\n";
return 1;
}
const size_t length = parser.builder_.GetSize();
const size_t n =
std::fwrite(parser.builder_.GetBufferPointer(), 1, length, stdout);
if (n != length) {
std::cerr << "print to stdout filed.\n";
return 1;
}
return 0;
}

View File

@ -88,7 +88,6 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
pass_config.lower_tensor_list_ops = true;
pass_config.shape_inference = false;
return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module),
pass_config, result);

View File

@ -16,9 +16,12 @@ limitations under the License.
#include <utility>
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/FileUtilities.h" // from @llvm-project
#include "mlir/Transforms/ViewOpGraph.h" // from @llvm-project
@ -41,6 +44,77 @@ limitations under the License.
namespace tensorflow {
Status HandleInputOutputArraysWithModule(const toco::ModelFlags& model_flags,
mlir::OwningModuleRef* module) {
mlir::FuncOp entry_function = nullptr;
for (auto func : module->get().getOps<mlir::FuncOp>()) {
if (auto tf_attrs =
func.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function")) {
// TODO(jaesung): There could be multiple entry functions. Handle such cases
// if the need arises.
if (entry_function != nullptr) {
return errors::InvalidArgument(
"There should be only one tf.entry_function");
}
entry_function = func;
}
}
if (entry_function == nullptr) {
return errors::InvalidArgument("no tf.entry_function found");
}
// Get the list of input Op names from the function attribute.
mlir::DictionaryAttr tf_attrs =
entry_function.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
llvm::SmallVector<llvm::StringRef, 4> function_input_names;
function_input_names.reserve(model_flags.input_arrays().size());
auto input_attr = tf_attrs.get("inputs");
if (!input_attr) {
return errors::InvalidArgument("no inputs attribute found");
}
auto input_names = input_attr.cast<mlir::StringAttr>().getValue();
input_names.split(function_input_names, ",");
if (function_input_names.size() != model_flags.input_arrays().size()) {
return errors::InvalidArgument(
"input array size mismatch: got ", function_input_names.size(),
", expected: ", model_flags.input_arrays().size());
}
llvm::StringSet<> function_input_names_set;
function_input_names_set.insert(function_input_names.begin(),
function_input_names.end());
for (const auto& input_array : model_flags.input_arrays()) {
if (function_input_names_set.count(input_array.name()) == 0) {
return errors::InvalidArgument("input array name (", input_array.name(),
") does not exist in the given graph");
}
}
// Get the list of output Op names from the function attribute.
llvm::SmallVector<llvm::StringRef, 4> function_output_names;
function_output_names.reserve(model_flags.output_arrays().size());
auto output_attr = tf_attrs.get("outputs");
if (!output_attr) {
return errors::InvalidArgument("no outputs attribute found");
}
auto output_names = output_attr.cast<mlir::StringAttr>().getValue();
output_names.split(function_output_names, ",");
if (function_output_names.size() != model_flags.output_arrays().size()) {
return errors::InvalidArgument(
"output array size mismatch: got ", function_output_names.size(),
", expected: ", model_flags.output_arrays().size());
}
llvm::StringSet<> function_output_names_set;
function_output_names_set.insert(function_output_names.begin(),
function_output_names.end());
for (const auto& output_array : model_flags.output_arrays()) {
if (function_output_names_set.count(output_array) == 0) {
return errors::InvalidArgument("output array name (", output_array,
") does not exist in the given graph");
}
}
return Status::OK();
}
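The validation above boils down to string handling: split the comma-separated "inputs"/"outputs" attribute, compare counts, and check set membership for every array named in the model flags. A minimal standalone sketch of that check, with plain strings standing in for the MLIR attributes and protos:

#include <iostream>
#include <set>
#include <sstream>
#include <string>
#include <vector>

std::vector<std::string> SplitCsv(const std::string& csv) {
  std::vector<std::string> out;
  std::stringstream ss(csv);
  std::string item;
  while (std::getline(ss, item, ',')) out.push_back(item);
  return out;
}

// Returns true when the attribute names cover every requested array name.
bool ValidateArrays(const std::string& attr_csv,
                    const std::vector<std::string>& requested) {
  const auto names = SplitCsv(attr_csv);
  if (names.size() != requested.size()) return false;  // size mismatch
  const std::set<std::string> name_set(names.begin(), names.end());
  for (const auto& r : requested) {
    if (name_set.count(r) == 0) return false;  // unknown array name
  }
  return true;
}

int main() {
  std::cout << ValidateArrays("x:0,y:0", {"x:0", "y:0"}) << "\n";  // 1
  std::cout << ValidateArrays("a,b", {"a", "c"}) << "\n";          // 0
  return 0;
}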
Status ConvertSavedModelToTFLiteFlatBuffer(
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
string* result) {
@ -77,11 +151,15 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
model_flags.saved_model_version(), tags,
exported_names, &context));
if (!model_flags.input_arrays().empty() ||
!model_flags.output_arrays().empty()) {
TF_RETURN_IF_ERROR(HandleInputOutputArraysWithModule(model_flags, &module));
}
mlir::TFL::PassConfig pass_config(quant_specs);
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
pass_config.lower_tensor_list_ops = true;
pass_config.shape_inference = true;
auto status = internal::ConvertMLIRToTFLiteFlatBuffer(
toco_flags, std::move(module), pass_config, result);

View File

@ -285,7 +285,7 @@ Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
if (pass_config.legalize_tf_while) {
pm.addPass(mlir::TFL::CreateWhileOutlinePass());
}
pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,

View File

@ -1,4 +1,4 @@
load("//tensorflow:tensorflow.bzl", "tf_native_cc_binary")
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_native_cc_binary")
load(
"//tensorflow/core/platform:build_config.bzl",
"tf_proto_library",
@ -14,7 +14,10 @@ package(
package_group(
name = "friends",
includes = ["//third_party/mlir:subpackages"],
packages = ["//tensorflow/compiler/mlir/..."],
packages = [
"//learning/brain/experimental/mlir/quantization/...",
"//tensorflow/compiler/mlir/...",
],
)
exports_files([
@ -112,11 +115,22 @@ tf_native_cc_binary(
],
)
cc_library(
name = "numerical_utils",
srcs = ["numerical_utils.cc"],
hdrs = ["numerical_utils.h"],
deps = [
"@com_google_absl//absl/types:optional",
],
)
cc_library(
name = "device_target",
srcs = ["device_target.cc"],
hdrs = ["device_target.h"],
deps = [
":numerical_utils",
"@com_google_absl//absl/types:optional",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
@ -139,3 +153,13 @@ cc_library(
"@llvm-project//mlir:Support",
],
)
tf_cc_test(
name = "numerical_utils_test",
srcs = ["numerical_utils_test.cc"],
deps = [
":numerical_utils",
"@com_google_absl//absl/types:optional",
"@com_google_googletest//:gtest_main",
],
)

View File

@ -15,12 +15,18 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/quantization/device_target.h"
#include <algorithm>
#include "absl/types/optional.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/numerical_utils.h"
namespace mlir {
namespace quant {
@ -39,7 +45,7 @@ DeviceTarget::DeviceTarget(MLIRContext* ctx) : ctx_(ctx) {
assert(qi8n_ == qi8n_);
}
Optional<KernelSpec> DeviceTarget::Get(QuantizeRegionOp op) const {
Optional<KernelSpec> DeviceTarget::GetKernelSpec(QuantizeRegionOp op) const {
auto kernel_specs_it = specs_.find(op.logical_kernel());
if (kernel_specs_it == specs_.end()) return llvm::None;
@ -50,9 +56,15 @@ Optional<KernelSpec> DeviceTarget::Get(QuantizeRegionOp op) const {
return kernel_specs_it->getValue().Find(signature);
}
ScaleDecomposeFn DeviceTarget::GetDecomposeFn(QuantizeRegionOp op) const {
auto kernel_specs_it = specs_.find(op.logical_kernel());
if (kernel_specs_it == specs_.end()) return ScaleDecomposeFn(nullptr);
return kernel_specs_it->second.GetDecomposeFn();
}
LogicalResult DeviceTarget::RegisterKernel(
llvm::StringRef kernel, const KernelSpecs::Signature& signature,
const ScaleFn& fn) {
const ScaleFn& fn, const ScaleDecomposeFn& dfn) {
return specs_[kernel].Add(signature, {ScaleConstraintType::CustomScale, fn});
}
@ -78,5 +90,49 @@ void DeviceTarget::AppendToSignature(ArrayAttr specs_attr,
}
}
LogicalResult DeviceTarget::DecomposeMultiplyAccumulateScale(
Operation* op, quant::QuantizedMultipliers* input_multipliers,
quant::QuantizedMultipliers* output_multipliers,
quant::QuantizedRanges* output_ranges) {
auto rop = llvm::dyn_cast<quant::QuantizeRegionOp>(op);
if (!rop) return failure();
llvm::SmallVector<Type, 4> input_specs, out_specs;
for (auto spec : rop.input_specs()) {
input_specs.push_back(spec.cast<TypeAttr>().getValue());
}
for (auto spec : rop.output_specs()) {
out_specs.push_back(spec.cast<TypeAttr>().getValue());
}
auto in_spec = input_specs[0].dyn_cast<quant::UniformQuantizedType>();
// TODO(fengliuai): handle the PerAxis QuantizedType.
auto w_spec = input_specs[1].dyn_cast<quant::UniformQuantizedType>();
auto b_spec = input_specs[2].dyn_cast<quant::UniformQuantizedType>();
auto o_spec = out_specs[0].dyn_cast<quant::UniformQuantizedType>();
if (!in_spec || !w_spec || !b_spec || !o_spec) return failure();
double scale_product = in_spec.getScale() * w_spec.getScale();
// The bias scale is expected to equal the product of the input and weight
// scales; bail out when it doesn't.
if (fabs(scale_product - b_spec.getScale()) >= 1e-6) return failure();
// input multipliers
input_multipliers->append(3, kUnitQuantizedMultiplier);
// output multipliers
double real_multiplier = o_spec.getScale() / scale_product;
output_multipliers->push_back(quant::QuantizeMultiplier(real_multiplier));
// output ranges
auto min = rop.getAttrOfType<FloatAttr>("min");
auto max = rop.getAttrOfType<FloatAttr>("max");
output_ranges->push_back(quant::CalculateQuantizedRange(
o_spec.getScale(), o_spec.getZeroPoint(),
(min ? absl::optional<double>(min.getValueAsDouble()) : absl::nullopt),
(max ? absl::optional<double>(max.getValueAsDouble()) : absl::nullopt),
o_spec.getStorageTypeMin(), o_spec.getStorageTypeMax()));
return success();
}
} // namespace quant
} // namespace mlir
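With concrete (invented) scales, the decomposition above is simple arithmetic: the bias scale must equal input_scale * weight_scale, and the effective output multiplier output_scale / (input_scale * weight_scale), as computed in the code above, is then turned into a fixed-point pair by QuantizeMultiplier. A small worked sketch:

#include <cmath>
#include <iostream>

int main() {
  // Invented scales, chosen only to make the arithmetic easy to follow.
  const double input_scale = 0.5, weight_scale = 0.25, output_scale = 0.1;
  const double bias_scale = input_scale * weight_scale;  // 0.125, as required
  const double real_multiplier = output_scale / (input_scale * weight_scale);
  std::cout << "bias scale: " << bias_scale << "\n";            // 0.125
  std::cout << "real multiplier: " << real_multiplier << "\n";  // 0.8
  // QuantizeMultiplier(0.8) would yield roughly {1717986918, 0}, since
  // 1717986918 * 2^-31 ~= 0.8 (see numerical_utils.cc below).
  return 0;
}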

View File

@ -17,13 +17,13 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_DEVICE_TARGET_H_
#include <functional>
#include <ostream>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
@ -33,6 +33,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/numerical_utils.h"
namespace mlir {
namespace quant {
@ -40,9 +41,17 @@ namespace quant {
class QuantizeContext;
using AdjacentOperations = llvm::SmallVectorImpl<Operation*>;
using QuantizedMultipliers = llvm::SmallVector<QuantizedMultiplier, 4>;
using QuantizedRanges = llvm::SmallVector<QuantizedRange, 4>;
using ScaleFn = std::function<LogicalResult(QuantizeContext*, Operation*,
AdjacentOperations*, bool*)>;
using ScaleDecomposeFn =
std::function<LogicalResult(Operation*, QuantizedMultipliers*,
QuantizedMultipliers*, QuantizedRanges*)>;
static const QuantizedMultiplier kUnitQuantizedMultiplier{1, 0};
enum class ScaleConstraintType {
OutputInputSameScale,
OutputInputFreeScale,
@ -73,12 +82,25 @@ class KernelSpecs {
}
}
ScaleDecomposeFn GetDecomposeFn() const { return decompose_fn_; }
// Adds the kernel signature with the kernel specification.
LogicalResult Add(const Signature& signature, const KernelSpec& spec) {
if (all_signatures_.insert({signature, spec}).second) return success();
return failure();
}
KernelSpecs& WithSignature(const KernelSpecs::Signature& signature,
const ScaleFn& fn) {
Add(signature, {ScaleConstraintType::CustomScale, fn});
return *this;
}
KernelSpecs& WithImpl(const ScaleDecomposeFn& dfn) {
decompose_fn_ = dfn;
return *this;
}
private:
// The signature is pattern match based.
struct SignatureInfo : public llvm::DenseMapInfo<Signature> {
@ -101,6 +123,10 @@ class KernelSpecs {
// Maps the signature to the kernel spec. Note that the matching is
// pattern match based.
llvm::DenseMap<Signature, KernelSpec, SignatureInfo> all_signatures_;
// A method to compute the effective multipliers. This is independent of the
// bits of the ports, so all the signatures share the same one here.
ScaleDecomposeFn decompose_fn_;
};
class DeviceTarget {
@ -108,19 +134,26 @@ class DeviceTarget {
explicit DeviceTarget(MLIRContext* ctx);
// Retrieves the kernel spec for the quant region op.
Optional<KernelSpec> Get(quant::QuantizeRegionOp op) const;
Optional<KernelSpec> GetKernelSpec(quant::QuantizeRegionOp op) const;
// Retrieves the scale decomposition function for the quant region op.
ScaleDecomposeFn GetDecomposeFn(quant::QuantizeRegionOp op) const;
protected:
// Adds the kernel spec with the custom scale function for the kernel.
LogicalResult RegisterKernel(llvm::StringRef kernel,
const KernelSpecs::Signature& signature,
const ScaleFn& fn);
const ScaleFn& fn, const ScaleDecomposeFn& dfn);
// Adds the kernel spec with the scale constraint type for the kernel.
LogicalResult RegisterKernel(llvm::StringRef kernel,
const KernelSpecs::Signature& signature,
const ScaleConstraintType constraint);
// Adds the kernel with the given name. Returns an existing one if it has
// been added before.
KernelSpecs& RegisterKernel(llvm::StringRef kernel) { return specs_[kernel]; }
// Converts a specification to a signature:
// - UniformQuantizedType -> AnyQuantizedType
// - AnyQuantizedType (int) -> AnyQuantizedType
@ -128,6 +161,13 @@ class DeviceTarget {
void AppendToSignature(ArrayAttr specs_attr,
KernelSpecs::Signature* signature) const;
// For "mulmat->add" type of kernels, convert the scales of all the ports to
// multipliers.
static LogicalResult DecomposeMultiplyAccumulateScale(
Operation* op, quant::QuantizedMultipliers* input_multipliers,
quant::QuantizedMultipliers* output_multipliers,
quant::QuantizedRanges* output_ranges);
// A set of parameters is required to build the signatures.
FloatType f32_;
IntegerType i8_;

View File

@ -33,7 +33,6 @@ limitations under the License.
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/Functional.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_info.pb.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h"
@ -55,7 +54,8 @@ namespace quant {
using QuantParamsEntry = QuantizationInfo::QuantParams;
namespace {
class ImportQuantStatsPass : public FunctionPass<ImportQuantStatsPass> {
class ImportQuantStatsPass
: public PassWrapper<ImportQuantStatsPass, FunctionPass> {
public:
explicit ImportQuantStatsPass(OperationToName op_to_name)
: op_to_name_(op_to_name) {}
@ -193,7 +193,7 @@ void ImportQuantStatsPass::runOnFunction() {
}
// Creates an instance of the default quant parameters pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
std::unique_ptr<OperationPass<FuncOp>> CreateImportQuantStatsPass(
OperationToName op_to_name, const std::string &stats_str) {
auto pass = absl::make_unique<ImportQuantStatsPass>(op_to_name);
if (pass->ParseQuantStats(stats_str)) return nullptr;
@ -203,7 +203,7 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
// Creates an instance pass to import quantization stats to the operations in
// the function. A custom method to get the name from the op is used because
// different dialect ops might have different ways to assign the name.
std::unique_ptr<OpPassBase<FuncOp>>
std::unique_ptr<OperationPass<FuncOp>>
CreateImportQuantStatsPassForTFControlDialect(const std::string &stats_str) {
auto get_name_func = [](Operation *op) {
Location loc = op->getLoc();

View File

@ -0,0 +1,82 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/quantization/numerical_utils.h"
#include <assert.h>
#include <algorithm>
#include <cmath>
#include <limits>
#include "absl/types/optional.h"
namespace mlir {
namespace quant {
// This method is adapted from TFLite:
// ["tensorflow/lite/kernels/internal/quantization_util.cc"]
QuantizedMultiplier QuantizeMultiplier(double double_multiplier) {
if (double_multiplier < 1e-6) {
return {0, 0};
}
int32_t shift;
const double q = frexp(double_multiplier, &shift);
auto q_fixed = static_cast<int64_t>(round(q * (1ll << 31)));
assert(q_fixed <= (1ll << 31));
if (q_fixed == (1ll << 31)) {
q_fixed /= 2;
++shift;
}
assert(q_fixed <= std::numeric_limits<int32_t>::max());
// A shift amount smaller than -31 would cause all bits to be shifted out
// and thus all results would be zero. We implement that instead with
// q_fixed==0, so as to avoid hitting issues with right-shift
// operations with shift amounts greater than 31. Note that this happens
// roughly when abs(double_multiplier) < 2^-31 and the present handling means
// that we're effectively flushing tiny double_multiplier's to zero.
// We could conceivably handle values in the range (roughly) [32, 63]
// as 'denormals' i.e. (shift==0, q_fixed < 2^30). In that point of view
// the present handling is just doing 'flush denormals to zero'. We could
// reconsider and actually generate nonzero denormals if a need arises.
if (shift < -31) {
shift = 0;
q_fixed = 0;
}
return {static_cast<int32_t>(q_fixed), shift};
}
QuantizedRange CalculateQuantizedRange(double scale, int32_t zero_point,
absl::optional<double> rmin,
absl::optional<double> rmax,
int32_t qmin, int32_t qmax) {
auto quantize = [scale, zero_point](float f) {
return zero_point + static_cast<int32_t>(std::round(f / scale));
};
if (rmin.has_value() && rmax.has_value()) {
return {std::max(qmin, quantize(rmin.value())),
std::min(qmax, quantize(rmax.value()))};
} else if (rmin.has_value()) {
return {std::max(qmin, quantize(rmin.value())), qmax};
} else if (rmax.has_value()) {
return {qmin, std::min(qmax, quantize(rmax.value()))};
} else {
return {qmin, qmax};
}
}
} // namespace quant
} // namespace mlir

View File

@ -0,0 +1,45 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_NUMERICAL_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_NUMERICAL_UTILS_H_
#include <cstdint>
#include <utility>
#include "absl/types/optional.h"
namespace mlir {
namespace quant {
using QuantizedMultiplier = std::pair<int32_t, int32_t>;
using QuantizedRange = std::pair<int32_t, int32_t>;
// Decomposes a double-precision multiplier into an integer multiplier and an
// exponent: double_multiplier = int_multiplier * 2 ^ (-31 + exponent), where
// int_multiplier is in the range [2^30, 2^31) (or zero for tiny multipliers).
QuantizedMultiplier QuantizeMultiplier(double double_multiplier);
// Calculates the effective quantized value range for the given scale and zero
// point. The range is the intersection of the ranges defined by [rmin, rmax]
// and [qmin, qmax].
QuantizedRange CalculateQuantizedRange(double scale, int32_t zero_point,
absl::optional<double> rmin,
absl::optional<double> rmax,
int32_t qmin, int32_t qmax);
} // namespace quant
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_NUMERICAL_UTILS_H_
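A few hand-computed values for the decomposition contract stated above (not taken from the library); recomposing each pair as int_multiplier * 2^(-31 + exponent) recovers the original multiplier, which is also how the unit test below checks round-trips:

#include <cmath>
#include <iostream>

int main() {
  // QuantizeMultiplier(0.5) -> {1073741824, 0}: 2^30 * 2^-31 = 0.5
  std::cout << 1073741824.0 * std::exp2(-31 + 0) << "\n";  // 0.5
  // QuantizeMultiplier(1.0) -> {1073741824, 1}: 2^30 * 2^-30 = 1.0
  std::cout << 1073741824.0 * std::exp2(-31 + 1) << "\n";  // 1
  // QuantizeMultiplier(10.0) -> {1342177280, 4}: frexp gives q = 0.625, shift = 4
  std::cout << 1342177280.0 * std::exp2(-31 + 4) << "\n";  // 10
  return 0;
}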

View File

@ -0,0 +1,114 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/quantization/numerical_utils.h"
#include <cmath>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/types/optional.h"
namespace mlir {
namespace quant {
namespace {
double ComposeScale(const QuantizedMultiplier& input) {
return input.first * exp2(-31 + input.second);
}
TEST(NumericalUtils, QuantizeMultiplier) {
// Decompose multiplier larger than 1.
ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e6)), 1.0e6);
ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e3)), 1.0e3);
ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(10.)), 10.);
ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(5.)), 5.);
ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(2.)), 2.);
// Decompose multiplier between 1.0 and 1e-6.
ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(0.0)), 0.0);
ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0)), 1.0);
ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-1)), 1.0e-1);
ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-2)), 1.0e-2);
ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-3)), 1.0e-3);
ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-4)), 1.0e-4);
ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-5)), 1.0e-5);
ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-6)), 1.0e-6);
// When scale is smaller than 1.0e-6, it is decomposed to {0, 0}.
ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-7)), 0.0);
ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-8)), 0.0);
}
TEST(NumericalUtils, ActivationRange) {
// zero point = 0
auto a =
CalculateQuantizedRange(1e-6, 0, absl::nullopt, absl::nullopt, -128, 127);
ASSERT_EQ(a.first, -128);
ASSERT_EQ(a.second, 127);
auto b = CalculateQuantizedRange(1e-6, 0, 0.0, absl::nullopt, -128, 127);
ASSERT_EQ(b.first, 0);
ASSERT_EQ(b.second, 127);
auto c = CalculateQuantizedRange(1e-6, 0, -1.0, 1.0, -128, 127);
ASSERT_EQ(c.first, -128);
ASSERT_EQ(c.second, 127);
auto d = CalculateQuantizedRange(1e-6, 0, 0.0, 6.0, -128, 127);
ASSERT_EQ(d.first, 0);
ASSERT_EQ(d.second, 127);
// zero point = 100
auto e = CalculateQuantizedRange(1e-6, 100, absl::nullopt, absl::nullopt,
-128, 127);
ASSERT_EQ(e.first, -128);
ASSERT_EQ(e.second, 127);
auto f = CalculateQuantizedRange(1e-6, 100, 0.0, absl::nullopt, -128, 127);
ASSERT_EQ(f.first, 100);
ASSERT_EQ(f.second, 127);
auto g = CalculateQuantizedRange(1e-6, 100, -1.0, 1.0, -128, 127);
ASSERT_EQ(g.first, -128);
ASSERT_EQ(g.second, 127);
auto h = CalculateQuantizedRange(1e-6, 100, 0.0, 6.0, -128, 127);
ASSERT_EQ(h.first, 100);
ASSERT_EQ(h.second, 127);
// zero point = -100
auto i = CalculateQuantizedRange(1e-6, -100, absl::nullopt, absl::nullopt,
-128, 127);
ASSERT_EQ(i.first, -128);
ASSERT_EQ(i.second, 127);
auto j = CalculateQuantizedRange(1e-6, -100, 0.0, absl::nullopt, -128, 127);
ASSERT_EQ(j.first, -100);
ASSERT_EQ(j.second, 127);
auto k = CalculateQuantizedRange(1e-6, -100, -1.0, 1.0, -128, 127);
ASSERT_EQ(k.first, -128);
ASSERT_EQ(k.second, 127);
auto l = CalculateQuantizedRange(1e-6, -100, 0.0, 6.0, -128, 127);
ASSERT_EQ(l.first, -100);
ASSERT_EQ(l.second, 127);
}
} // namespace
} // namespace quant
} // namespace mlir

View File

@ -67,7 +67,7 @@ std::vector<quant::QuantizeRegionOp> QuantizeContext::GetAllOps() {
LogicalResult QuantizeContext::Handle(
quant::QuantizeRegionOp op, llvm::SmallVectorImpl<Operation *> *new_items,
bool *changed) {
auto spec = target_spec_.Get(op);
auto spec = target_spec_.GetKernelSpec(op);
if (!spec.hasValue()) {
op.emitWarning(
"Couldn't find kernel from the registeration for quantization.");

View File

@ -27,13 +27,13 @@ using OperationToName = std::function<llvm::StringRef(Operation* op)>;
// Creates an instance of a pass to import quantization stats to the operations
// in the function. A custom method to get the name from the op is used because
// different dialect ops might have different ways to assign the name.
std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
std::unique_ptr<OperationPass<FuncOp>> CreateImportQuantStatsPass(
OperationToName op_to_name, const std::string& stats_str);
// Creates an instance of a pass to import quantization stats to the operations
// in the function. A custom method to get the name from the op is used because
// different dialect ops might have different ways to assign the name.
std::unique_ptr<OpPassBase<FuncOp>>
std::unique_ptr<OperationPass<FuncOp>>
CreateImportQuantStatsPassForTFControlDialect(const std::string& stats_str);
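// A minimal usage sketch (assumptions: a configured mlir::PassManager named
// `pm` and a stats string `stats_str`; the name-getter lambda is illustrative
// only, not the canonical naming scheme):
//   pm.addNestedPass<FuncOp>(CreateImportQuantStatsPass(
//       [](Operation* op) { return op->getName().getStringRef(); },
//       stats_str));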
} // namespace quant

View File

@ -79,7 +79,7 @@ TypeAttr RescaleQuantizedType(Type input, Attribute factor) {
SmallVector<double, 4> new_scales;
new_scales.reserve(scales.size());
auto scales_iter = scales.begin();
for (auto f : factor_values) {
for (const auto& f : factor_values) {
new_scales.push_back(*(scales_iter++) *
std::fabs(FloatAttr::getValueAsDouble(f)));
}

View File

@ -25,7 +25,7 @@ namespace mlir {
namespace TF {
// Legalize the tf ops to the quant ops, so the quantization passes can work.
std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFToQuantPass();
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFToQuantPass();
} // namespace TF
} // namespace mlir

View File

@ -27,7 +27,7 @@ namespace TF {
namespace {
// Legalize TF quantization emulation ops to the corresponding ops in the
// Quant dialect.
struct LegalizeTFToQuant : public FunctionPass<LegalizeTFToQuant> {
struct LegalizeTFToQuant : public PassWrapper<LegalizeTFToQuant, FunctionPass> {
explicit LegalizeTFToQuant() = default;
LegalizeTFToQuant(const LegalizeTFToQuant &) {}
@ -146,12 +146,12 @@ void LegalizeTFToQuant::runOnFunction() {
auto func = getFunction();
auto *ctx = func.getContext();
patterns.insert<PreparePerTensorFakeQuant, PreparePerChannelFakeQuant>(ctx);
applyPatternsGreedily(func, patterns);
applyPatternsAndFoldGreedily(func, patterns);
}
} // namespace
// Creates an instance of the TensorFlow dialect to QuantOps dialect pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFToQuantPass() {
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFToQuantPass() {
return std::make_unique<LegalizeTFToQuant>();
}

View File

@ -1,112 +0,0 @@
load(
"//third_party/mlir:tblgen.bzl",
"gentbl",
)
package(
default_visibility = [
":friends",
],
licenses = ["notice"], # Apache 2.0
)
package_group(
name = "friends",
includes = ["//third_party/mlir:subpackages"],
packages = [
"//tensorflow/compiler/aot/...",
"//tensorflow/compiler/mlir/...",
"//tensorflow/compiler/mlir/lite/...",
],
)
cc_library(
name = "hlo_xla_quantization_passes",
srcs = [
"cpu_kernel_fusion.cc",
"generated_cpu_kernel_fusion.inc",
"materialize.cc",
"op_quant_spec.inc",
"propagate.cc",
],
hdrs = [
"passes.h",
],
deps = [
":cpu_device_target",
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
"//tensorflow/compiler/mlir/lite/quantization:quantization_context",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/xla:hlo",
"//tensorflow/compiler/xla/client/lib:quantize",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
],
alwayslink = 1,
)
cc_library(
name = "cpu_device_target",
srcs = [
"cpu_device_target.cc",
],
hdrs = [
"cpu_device_target.h",
],
deps = [
"//tensorflow/compiler/mlir/lite/quantization:device_target",
"//tensorflow/compiler/mlir/lite/quantization:quantization_context",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:Support",
],
)
cc_library(
name = "quantize",
srcs = [
"quantize.cc",
],
hdrs = [
"quantize.h",
],
deps = [
"//tensorflow/compiler/mlir/xla:hlo",
"//tensorflow/compiler/mlir/xla:hlo_to_mlir_hlo",
"//tensorflow/compiler/tf2xla",
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/core/platform:status",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Transforms",
],
)
gentbl(
name = "cpu_kernel_fusion_inc_gen",
tbl_outs = [
(
"-gen-rewriters",
"generated_cpu_kernel_fusion.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "cpu_kernel_fusion.td",
td_srcs = [
"@llvm-project//mlir:StdOpsTdFiles",
"//tensorflow/compiler/mlir/xla:hlo_ops_td_files",
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
],
)

View File

@ -1,67 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_context.h"
namespace mlir {
namespace xla_hlo {
namespace ph = std::placeholders;
CpuDeviceTarget::CpuDeviceTarget(MLIRContext* ctx) : DeviceTarget(ctx) {
RegisterKernel("generic.concat", {qi8_, qi8_, qi8_},
quant::ScaleConstraintType::OutputInputSameScale);
  // TODO(fengliuai): All the combinations are required to be listed. We need
  // to improve this.
RegisterKernel("generic.reshape", {qi8_, any_},
quant::ScaleConstraintType::OutputInputSameScale);
RegisterKernel("generic.reshape", {any_, qi8_},
quant::ScaleConstraintType::OutputInputSameScale);
RegisterKernel("generic.mul", {qi8_, qi8_, qi8_},
quant::ScaleConstraintType::OutputInputFreeScale);
RegisterKernel("generic.mul_add", {qi8_, qi8n_, any_, qi8_},
std::bind(&CpuDeviceTarget::HandleMultiplyAccumulateScale,
this, ph::_1, ph::_2, ph::_3, ph::_4));
RegisterKernel("generic.matmul_add", {qi8_, qi8n_, any_, qi8_},
std::bind(&CpuDeviceTarget::HandleMultiplyAccumulateScale,
this, ph::_1, ph::_2, ph::_3, ph::_4));
}
LogicalResult CpuDeviceTarget::HandleMultiplyAccumulateScale(
quant::QuantizeContext* ctx, Operation* op,
quant::AdjacentOperations* new_items, bool* changed) {
auto bias_params = ctx->GetOperandParams(op, 2);
if (!EmptyParams(bias_params)) {
return success();
}
std::vector<quant::QuantParams> op_types{ctx->GetOperandParams(op, 0),
ctx->GetOperandParams(op, 1)};
auto bias_scale = GetUniformQuantizedTypeForBias(op_types);
if (bias_scale && ctx->SetOperandParams(op, 2, bias_scale)) {
*changed = true;
new_items->push_back(op->getOperand(2).getDefiningOp());
}
return success();
}
} // namespace xla_hlo
} // namespace mlir

View File

@ -1,40 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_CPU_DEVICE_TARGET_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_CPU_DEVICE_TARGET_H_
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/device_target.h"
namespace mlir {
namespace xla_hlo {
// Target specs for cpu kernels
class CpuDeviceTarget : public quant::DeviceTarget {
public:
explicit CpuDeviceTarget(MLIRContext* ctx);
private:
LogicalResult HandleMultiplyAccumulateScale(
quant::QuantizeContext* ctx, Operation* op,
quant::AdjacentOperations* new_items, bool* changed);
};
} // namespace xla_hlo
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_CPU_DEVICE_TARGET_H_

View File

@ -1,252 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <initializer_list>
#include <iterator>
#include <numeric>
#include <string>
#include "absl/memory/memory.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/xla/client/lib/quantize.h"
#define DEBUG_TYPE "quant-kernel-fusion"
constexpr int kFakeQuantOperandsNum = 5;
constexpr int kFakeQuantPerChannelOperandsNum = 6;
namespace mlir {
namespace xla_hlo {
namespace {
TypeAttr GetQuantSpec(Operation* op) {
auto fake_quant = llvm::dyn_cast_or_null<CustomCallOp>(op);
if (!fake_quant || fake_quant.getNumOperands() < kFakeQuantOperandsNum ||
fake_quant.getNumOperands() > kFakeQuantPerChannelOperandsNum ||
fake_quant.call_target_name() != "fake_quant_with_min_max_vars")
return {};
DenseFPElementsAttr min, max;
DenseIntElementsAttr bit_width, narrow_range, quant_dim;
if (!matchPattern(fake_quant.getOperand(1), m_Constant(&min)) ||
!matchPattern(fake_quant.getOperand(2), m_Constant(&max)) ||
!matchPattern(fake_quant.getOperand(3), m_Constant(&bit_width)) ||
!matchPattern(fake_quant.getOperand(4), m_Constant(&narrow_range)))
return {};
auto bit_width_val = (*bit_width.attr_value_begin()).cast<IntegerAttr>();
auto narrow_range_val = (*narrow_range.int_value_begin()).getSExtValue();
int quant_dim_val = -1;
if (fake_quant.getNumOperands() == kFakeQuantPerChannelOperandsNum &&
matchPattern(fake_quant.getOperand(kFakeQuantPerChannelOperandsNum - 1),
m_Constant(&quant_dim))) {
quant_dim_val = (*quant_dim.int_value_begin()).getSExtValue();
}
OpBuilder builder(op);
Type input_type =
fake_quant.getOperand(0).getType().cast<ShapedType>().getElementType();
return quant::GetQuantizedTypeAttr(
builder, input_type, min, max, quant_dim_val, bit_width_val,
builder.getBoolAttr(narrow_range_val), /*is_signed=*/true);
}
// Collects input values from outside for 'ops'.
void CollectInputs(llvm::ArrayRef<Operation*> ops,
llvm::SmallVectorImpl<Value>* inputs,
llvm::SmallVectorImpl<Attribute>* input_specs) {
for (Operation* op : ops) {
for (Value operand : op->getOperands()) {
if (std::find(inputs->begin(), inputs->end(), operand) != inputs->end()) {
continue;
}
if (Operation* def_op = operand.getDefiningOp()) {
if (std::find(ops.begin(), ops.end(), def_op) == ops.end()) {
inputs->push_back(operand);
}
} else { // argument value
inputs->push_back(operand);
}
}
}
for (Value input : *inputs) {
ShapedType input_type = input.getType().cast<ShapedType>();
if (TypeAttr spec = GetQuantSpec(input.getDefiningOp())) {
input_specs->push_back(spec);
} else {
input_specs->push_back(TypeAttr::get(input_type.getElementType()));
}
}
}
// Collects values that are produced by 'ops' and have use outside of 'ops'.
// TODO(fengliuai): if it is a single user and QDQ, write that to the specs.
void CollectRets(llvm::ArrayRef<Operation*> ops,
llvm::SmallVectorImpl<Value>* rets,
llvm::SmallVectorImpl<Type>* ret_types,
llvm::SmallVectorImpl<Attribute>* ret_specs) {
for (Operation* op : ops) {
for (Value result : op->getResults()) {
for (Operation* user : result.getUsers()) {
        // If there are any users outside of 'ops'
if (std::find(ops.begin(), ops.end(), user) == ops.end()) {
ShapedType ret_type = result.getType().cast<ShapedType>();
rets->push_back(result);
ret_types->push_back(ret_type);
if (TypeAttr spec = GetQuantSpec(user)) {
ret_specs->push_back(spec);
} else {
ret_specs->push_back(TypeAttr::get(ret_type.getElementType()));
}
break;
}
}
}
}
}
llvm::SmallVector<Value, 0> fuseOps(PatternRewriter* rewriter,
const std::initializer_list<Value>& results,
StringRef kernel) {
// Collect all the operations to be fused.
llvm::SmallVector<Operation*, 4> fused;
llvm::SmallVector<Location, 4> locs;
fused.reserve(results.size());
locs.reserve(results.size());
for (auto value : results) {
Operation* op = value.getDefiningOp();
fused.push_back(op);
locs.push_back(op->getLoc());
}
// Collect inputs from outside to 'ops'.
llvm::SmallVector<Value, 4> inputs;
llvm::SmallVector<Attribute, 4> input_specs;
CollectInputs(fused, &inputs, &input_specs);
// Collect outputs from 'ops' to outside.
llvm::SmallVector<Value, 4> rets;
llvm::SmallVector<Type, 4> ret_types;
llvm::SmallVector<Attribute, 4> ret_specs;
CollectRets(fused, &rets, &ret_types, &ret_specs);
// Create the region op with the return.
auto region = rewriter->create<quant::QuantizeRegionOp>(
rewriter->getFusedLoc(locs), ret_types, inputs,
rewriter->getArrayAttr(input_specs), rewriter->getArrayAttr(ret_specs),
kernel);
auto* body = new Block();
region.body().push_back(body);
OpBuilder builder = OpBuilder::atBlockEnd(body);
BlockAndValueMapping mapping;
// Make block arguments and add it to the block value mapping.
for (Value input : inputs) {
mapping.map(input, body->addArgument(input.getType()));
}
// Clone the operations 'ops' to the region.
for (Operation* op : fused) {
builder.clone(*op, mapping);
}
llvm::SmallVector<Value, 4> new_rets;
new_rets.reserve(rets.size());
for (auto ret : llvm::enumerate(rets)) {
Value new_ret = mapping.lookupOrNull(ret.value());
assert(new_ret && "couldn't find return value.");
new_rets.push_back(new_ret);
ret.value().replaceAllUsesWith(region.getResult(ret.index()));
}
builder.create<quant::ReturnOp>(builder.getUnknownLoc(), new_rets);
LLVM_DEBUG({
assert(region.verify().Success && "failed to create quant region.");
llvm::dbgs() << "\ncreated region: ";
region.print(llvm::dbgs());
llvm::dbgs() << "\n\n\n";
});
SmallVector<Value, 0> new_values(fused.back()->getNumResults());
return new_values;
}
struct CpuKernelFusionPass : public FunctionPass<CpuKernelFusionPass> {
explicit CpuKernelFusionPass() = default;
CpuKernelFusionPass(const CpuKernelFusionPass&) {}
void runOnFunction() override;
private:
LogicalResult fuseCpuKernels(Operation* op);
};
#include "tensorflow/compiler/mlir/lite/quantization/xla/generated_cpu_kernel_fusion.inc"
LogicalResult CpuKernelFusionPass::fuseCpuKernels(Operation* op) {
MLIRContext* ctx = op->getContext();
OwningRewritePatternList patterns;
populateWithGenerated(ctx, &patterns);
ConversionTarget target(*ctx);
target.addLegalDialect<quant::QuantizationDialect>();
target.addLegalOp<CallOp, ModuleOp, FuncOp, ModuleTerminatorOp,
::mlir::ReturnOp>();
return applyPartialConversion(op, target, patterns);
}
void CpuKernelFusionPass::runOnFunction() {
if (failed(fuseCpuKernels(getFunction()))) signalPassFailure();
}
} // namespace
// Creates an instance of the xla_hlo cpu kernel fusion pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateCpuKernelFusionPass() {
return std::make_unique<CpuKernelFusionPass>();
}
static PassRegistration<CpuKernelFusionPass> pass(
"xla-hlo-cpu-fusion", "Fuse xla hlo ops into cpu kernels");
} // namespace xla_hlo
} // namespace mlir

View File

@ -1,174 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This transformation pass quantizes the constants and rewrites the
// quantization ops with xla_hlo primitive ops.
#include <cstdint>
#include <iterator>
#include <numeric>
#include <string>
#include "absl/memory/memory.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/xla/client/lib/quantize.h"
//===----------------------------------------------------------------------===//
// The pass to materialize the quantization results by xla primitive ops.
//
namespace mlir {
namespace xla_hlo {
namespace {
// This pattern matches the "constant->qcast->dcast" pattern and replaces it
// with "quantized constant->xla_hlo.dequantize". If it only matches the
// "non-constant->qcast->dcast" pattern, it removes both the qcast and dcast.
// We chain the pattern as a whole to bypass the type checks of the normal
// xla_hlo ops.
// TODO(fengliuai): make this pass work for bf16 input.
class RewriteDequantize : public OpRewritePattern<quant::DequantizeCastOp> {
public:
explicit RewriteDequantize(int64_t size, MLIRContext *context)
: OpRewritePattern<quant::DequantizeCastOp>(context), size_(size) {}
LogicalResult matchAndRewrite(quant::DequantizeCastOp op,
PatternRewriter &rewriter) const override {
// quant.dcast
// xla_hlo dequantize only takes min/max, so let's recover them from
// the quantization parameters.
Value dcast = op.arg();
auto type = quant::QuantizedType::getQuantizedElementType(dcast.getType());
if (!type || !type.isa<quant::UniformQuantizedType>()) {
return failure();
}
auto qtype = type.cast<quant::UniformQuantizedType>();
double scale = qtype.getScale();
int64_t zero_point = qtype.getZeroPoint();
float min = scale * (qtype.getStorageTypeMin() - zero_point);
float max = scale * (qtype.getStorageTypeMax() - zero_point);
// quant.qcast
auto qcast =
llvm::dyn_cast_or_null<quant::QuantizeCastOp>(dcast.getDefiningOp());
if (!qcast) return failure();
// constant
DenseFPElementsAttr attr;
// If it isn't a floating-point constant or the size is too small, let's
// remove the quantization. Also the last dimension size should be a
    // multiple of 4, so the shape isn't broken during packing and unpacking.
if (!matchPattern(qcast.arg(), m_Constant(&attr)) ||
attr.getNumElements() <= size_ ||
attr.getType().getDimSize(attr.getType().getRank() - 1) % 4 != 0) {
op.getResult().replaceAllUsesWith(qcast.arg());
return success();
}
// TODO(fengliuai): implement transpose if it has high dimension.
// Create the quantized result
auto quantized_result =
quant::Quantize(attr, qtype).dyn_cast_or_null<DenseIntElementsAttr>();
if (!quantized_result) {
return failure();
}
    // Pack the uint8 bits to uint32. The shape is changed from
// [n0, n1, ..., nk] to [n0, n1, ..., nk / 4].
std::vector<uint8_t> raw_data;
for (auto d : quantized_result.getValues<uint8_t>()) {
raw_data.push_back(d);
}
// The packing might increase the data size by paddings.
auto packed_data = xla::PackToUint32<uint8_t>(raw_data);
auto packed_shape = attr.getType().getShape().vec();
int lower_dims = std::accumulate(
packed_shape.begin(),
std::next(packed_shape.begin(), packed_shape.size() - 1), 1,
std::multiplies<int>());
packed_shape[packed_shape.size() - 1] = packed_data.size() / lower_dims;
auto packed_type =
RankedTensorType::get(packed_shape, rewriter.getIntegerType(32));
auto packed_quantized_result =
DenseElementsAttr::get<uint32_t>(packed_type, packed_data);
auto quantized_constant =
rewriter.create<ConstantOp>(qcast.getLoc(), packed_quantized_result);
// Create the xla dequantize op with bf16 output
auto dequantized_type = RankedTensorType::get(attr.getType().getShape(),
rewriter.getBF16Type());
auto dequantize = rewriter.create<DequantizeOp>(
qcast.getLoc(), dequantized_type, quantized_constant,
rewriter.getF32FloatAttr(min), rewriter.getF32FloatAttr(max),
rewriter.getStringAttr("MIN_COMBINED"), rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false));
// Convert bf16 output back to f32
rewriter.replaceOpWithNewOp<ConvertOp>(op, op.getResult().getType(),
dequantize);
return success();
}
private:
int64_t size_;
};
// Materialize the quantization results by hlo primitive ops.
struct MaterializeToXlaPass : public FunctionPass<MaterializeToXlaPass> {
explicit MaterializeToXlaPass() = default;
MaterializeToXlaPass(const MaterializeToXlaPass &) {}
void runOnFunction() override;
};
void MaterializeToXlaPass::runOnFunction() {
FuncOp func = getFunction();
MLIRContext *ctx = &getContext();
OwningRewritePatternList patterns;
// TODO(fengliuai): make the size 6 configurable.
patterns.insert<RewriteDequantize>(6, ctx);
applyPatternsGreedily(func, patterns);
}
} // namespace
// Creates an instance of the xla_hlo dialect quantization propagation pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateMaterializeToXlaPass() {
return std::make_unique<MaterializeToXlaPass>();
}
static PassRegistration<MaterializeToXlaPass> pass(
"xla-hlo-materialize-quant",
"Materialize the quantization results by xla primitve ops");
} // namespace xla_hlo
} // namespace mlir

View File

@ -1,7 +0,0 @@
// TODO(fengliuai): automatically generate this file
// TODO(fengliuai): add all the xla_hlo ops
static std::unique_ptr<quant::OpQuantSpec> GetOpQuantSpec(mlir::Operation *op) {
auto spec = absl::make_unique<quant::OpQuantSpec>();
return spec;
}

View File

@ -1,107 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This transformation pass applies quantization propagation on xla_hlo dialect.
#include <iterator>
#include <string>
#include "absl/memory/memory.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_context.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.h"
// NOLINTNEXTLINE
static llvm::cl::opt<bool> disable_per_channel(
"xla-disable-per-channel", llvm::cl::value_desc("bool"),
llvm::cl::desc("Whether disable per-channel quantized weights."),
llvm::cl::init(false));
//===----------------------------------------------------------------------===//
// The quantization propagation Pass.
//
namespace mlir {
namespace xla_hlo {
namespace {
// Applies quantization propagation on the input function. During the
// propagation, two constraints are respected:
// - The quantization types (params) already assigned to ops in the function
// - The quantization specs for the ops
// The propagation result should assign quantization types to all the tensors
// while respecting these two constraints.
struct PropagateQuantPass : public FunctionPass<PropagateQuantPass> {
explicit PropagateQuantPass() = default;
PropagateQuantPass(const PropagateQuantPass &) {}
void runOnFunction() override;
};
#include "tensorflow/compiler/mlir/lite/quantization/xla/op_quant_spec.inc"
void PropagateQuantPass::runOnFunction() {
FuncOp func = getFunction();
// TODO(fengliuai): deprecate this old code generation path.
  // XLA only supports uint8/uint16 quantization for now.
ApplyQuantizationParamsPropagation(func, /*is_signed*/ false,
disable_per_channel, GetOpQuantSpec);
CpuDeviceTarget spec(&getContext());
quant::QuantizeContext ctx(func, spec);
std::vector<quant::QuantizeRegionOp> work_list = ctx.GetAllOps();
bool changed = false;
while (!work_list.empty()) {
quant::QuantizeRegionOp op = work_list.back();
work_list.pop_back();
llvm::SmallVector<Operation *, 4> new_items;
if (failed(ctx.Handle(op, &new_items, &changed))) {
// The IR is still valid, thus we shouldn't fail.
signalPassFailure();
}
for (auto item : new_items) {
if (auto reg = llvm::dyn_cast_or_null<quant::QuantizeRegionOp>(item))
work_list.push_back(reg);
}
}
if (!changed) return;
if (failed(ctx.Finalize())) {
signalPassFailure();
}
}
} // namespace
// Creates an instance of the xla_hlo dialect quantization propagation pass.
std::unique_ptr<OpPassBase<FuncOp>> CreatePropagateQuantPass() {
return std::make_unique<PropagateQuantPass>();
}
static PassRegistration<PropagateQuantPass> pass(
"xla-hlo-propagate-quant", "Propagate quantization information");
} // namespace xla_hlo
} // namespace mlir

View File

@ -1,74 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
namespace mlir {
namespace xla_hlo {
static void RegisterDialects() {
static bool init_once = []() {
mlir::registerDialect<mlir::xla_hlo::XlaHloDialect>();
mlir::registerDialect<mlir::StandardOpsDialect>();
return true;
}();
(void)init_once;
}
// Quantizes the model in the computation.
tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config,
xla::XlaComputation* computation) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> snapshot,
computation->Snapshot());
RegisterDialects();
MLIRContext context;
OwningModuleRef module = ModuleOp::create(UnknownLoc::get(&context));
auto status = xla::ConvertHloToMlirHlo(
module.get(), snapshot->mutable_hlo()->mutable_hlo_module());
if (!status.ok()) {
LOG(ERROR) << "Hlo module import failed: " << status;
return status;
}
PassManager pm(&context);
pm.addPass(createCanonicalizerPass());
pm.addPass(createInlinerPass());
pm.addPass(createSymbolDCEPass());
pm.addNestedPass<FuncOp>(createCSEPass());
mlir::StatusScopedDiagnosticHandler diag_handler(&context);
LogicalResult result = pm.run(module.get());
(void)result;
module->dump();
return tensorflow::Status::OK();
}
} // namespace xla_hlo
} // namespace mlir

View File

@ -1,46 +0,0 @@
// RUN: tf-opt -xla-hlo-cpu-fusion %s | FileCheck %s
// CHECK-LABEL: @mul_add_source
func @mul_add_source(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
%0 = "xla_hlo.multiply"(%arg0, %arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%1 = "xla_hlo.add"(%0, %arg2) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %1 : tensor<4xf32>
// CHECK: %[[region:.*]] = "quant.region"(%arg0, %arg1, %arg2) ( {
// CHECK: ^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors
// CHECK: %[[mul:.*]] = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
// CHECK: %[[add:.*]] = xla_hlo.add %[[mul]], %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
// CHECK: "quant.return"(%[[add]]) : (tensor<4xf32>) -> ()
// CHECK: }) {input_specs = [f32, f32, f32], logical_kernel = "generic.mul_add", output_specs = [f32]} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK: return %[[region]] : tensor<4xf32>
}
// CHECK-LABEL: @mul_add_annotated
func @mul_add_annotated(%arg0: tensor<2x4xf32>, %arg1: tensor<2x4xf32>, %arg2: tensor<2x4xf32>) -> (tensor<2x4xf32>) {
%cst = constant dense<0.0> : tensor<f32>
%cst_0 = constant dense<255.0> : tensor<f32>
%cst_1 = constant dense<8> : tensor<i32>
%cst_2 = constant dense<false> : tensor<i1>
%qin = "xla_hlo.custom_call"(%arg0, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars",
has_side_effect = false, name = "custom-call.1"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
%qw = "xla_hlo.custom_call"(%arg1, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars",
has_side_effect = false, name = "custom-call.2"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
%0 = "xla_hlo.multiply"(%qin, %qw) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
%1 = "xla_hlo.add"(%0, %arg2) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
%r = "xla_hlo.custom_call"(%1, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars",
has_side_effect = false, name = "custom-call.3"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
return %r : tensor<2x4xf32>
// CHECK: %[[region:.*]] = "quant.region"
// CHECK: ^bb0(%arg3: tensor<2x4xf32>, %arg4: tensor<2x4xf32>, %arg5: tensor<2x4xf32>): // no predecessors
// CHECK: %[[mul:.*]] = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<2x4xf32>
// CHECK: %[[add:.*]] = xla_hlo.add %[[mul]], %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<2x4xf32>
// CHECK: "quant.return"(%[[add]]) : (tensor<2x4xf32>) -> ()
// CHECK: }) {input_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>, !quant.uniform<i8:f32, 1.000000e+00:-128>, f32],
// CHECK-SAME: logical_kernel = "generic.mul_add", output_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>]} :
// CHECK-SAME: (tensor<2x4xf32>, tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
// CHECK: %[[r:.*]] = "xla_hlo.custom_call"(%[[region]]
// CHECK: return %[[r]] : tensor<2x4xf32>
}

View File

@ -1,15 +0,0 @@
# RUN: not tfcompile --graph=%s.pbtxt --config=%s.config.pbtxt --experimental_quantize --cpp_class="::test::fadd_quant" 2>&1 | FileCheck %s -dump-input-on-failure
# TODO(fengliuai): update this file with the progress of the implementation
// CHECK: func @main
// CHECK: %cst = constant dense<0.000000e+00> : tensor<f32>
// CHECK: %cst_0 = constant dense<1.270000e+02> : tensor<f32>
// CHECK: %cst_1 = constant dense<8> : tensor<i32>
// CHECK: %cst_2 = constant dense<false> : tensor<i1>
// CHECK: %0 = "xla_hlo.custom_call"(%arg0, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.9"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
// CHECK: %1 = "xla_hlo.custom_call"(%arg1, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.14"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
// CHECK: %2 = xla_hlo.add %0, %1 {name = "add.15"} : tensor<2x4xf32>
// CHECK: %3 = "xla_hlo.custom_call"(%2, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.20"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
// CHECK: %4 = "xla_hlo.tuple"(%3) {name = "tuple.22"} : (tensor<2x4xf32>) -> tuple<tensor<2x4xf32>>
// CHECK: return %4 : tuple<tensor<2x4xf32>>
// CHECK: }

View File

@ -1,26 +0,0 @@
feed {
id { node_name: "input0" }
shape {
dim { size: 2 }
dim { size: 4 }
}
}
feed {
id { node_name: "input1" }
shape {
dim { size: 2 }
dim { size: 4 }
}
}
fetch {
id { node_name: "Add/FakeQuantWithMinMaxVars" }
shape {
dim { size: 2 }
dim { size: 4 }
}
}
conversion_options {
custom_fake_quant_op_calls: true
}

View File

@ -1,218 +0,0 @@
node: {
name: "Add/FakeQuantWithMinMaxVars"
op: "FakeQuantWithMinMaxVars"
input: "Add"
input: "Add/FakeQuantWithMinMaxVars/min"
input: "Add/FakeQuantWithMinMaxVars/max"
attr: {
key: "num_bits"
value: {
i: 8
}
}
attr: {
key: "narrow_range"
value: {
b: false
}
}
}
node: {
name: "Add/FakeQuantWithMinMaxVars/min"
op: "Const"
attr: {
key: "value"
value: {
tensor: {
dtype: DT_FLOAT
tensor_shape: {
}
float_val: 0.0
}
}
}
attr: {
key: "dtype"
value: {
type: DT_FLOAT
}
}
}
node: {
name: "Add/FakeQuantWithMinMaxVars/max"
op: "Const"
attr: {
key: "value"
value: {
tensor: {
dtype: DT_FLOAT
tensor_shape: {
}
float_val: 127.0
}
}
}
attr: {
key: "dtype"
value: {
type: DT_FLOAT
}
}
}
node {
name: "Add"
op: "Add"
input: "input0/FakeQuantWithMinMaxVars"
input: "input1/FakeQuantWithMinMaxVars"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node: {
name: "input0/FakeQuantWithMinMaxVars"
op: "FakeQuantWithMinMaxVars"
input: "input0"
input: "input0/FakeQuantWithMinMaxVars/min"
input: "input0/FakeQuantWithMinMaxVars/max"
attr: {
key: "num_bits"
value: {
i: 8
}
}
attr: {
key: "narrow_range"
value: {
b: false
}
}
}
node: {
name: "input0/FakeQuantWithMinMaxVars/min"
op: "Const"
attr: {
key: "value"
value: {
tensor: {
dtype: DT_FLOAT
tensor_shape: {
}
float_val: 0.0
}
}
}
attr: {
key: "dtype"
value: {
type: DT_FLOAT
}
}
}
node: {
name: "input0/FakeQuantWithMinMaxVars/max"
op: "Const"
attr: {
key: "value"
value: {
tensor: {
dtype: DT_FLOAT
tensor_shape: {
}
float_val: 127.0
}
}
}
attr: {
key: "dtype"
value: {
type: DT_FLOAT
}
}
}
node {
name: "input0"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
}
node: {
name: "input1/FakeQuantWithMinMaxVars"
op: "FakeQuantWithMinMaxVars"
input: "input1"
input: "input1/FakeQuantWithMinMaxVars/min"
input: "input1/FakeQuantWithMinMaxVars/max"
attr: {
key: "num_bits"
value: {
i: 8
}
}
attr: {
key: "narrow_range"
value: {
b: false
}
}
}
node: {
name: "input1/FakeQuantWithMinMaxVars/min"
op: "Const"
attr: {
key: "value"
value: {
tensor: {
dtype: DT_FLOAT
tensor_shape: {
}
float_val: 0.0
}
}
}
attr: {
key: "dtype"
value: {
type: DT_FLOAT
}
}
}
node: {
name: "input1/FakeQuantWithMinMaxVars/max"
op: "Const"
attr: {
key: "value"
value: {
tensor: {
dtype: DT_FLOAT
tensor_shape: {
}
float_val: 127.0
}
}
}
attr: {
key: "dtype"
value: {
type: DT_FLOAT
}
}
}
node {
name: "input1"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
}
versions {
producer: 27
}

View File

@ -1,54 +0,0 @@
// RUN: tf-opt -xla-hlo-materialize-quant %s | FileCheck %s
// CHECK-LABEL: func @quantize_rewrite
func @quantize_rewrite(%arg0: tensor<2x4xf32>) -> tensor<2x4xf32> {
// CHECK: %[[qcst:.*]] = constant dense<{{\[\[}}21004416], [-1056997248]]> : tensor<2x1xi32>
// CHECK-NEXT: %[[dq:.*]] = "xla_hlo.dequantize"(%[[qcst]]) {is_16bits = false, max_range = 0.996078431 : f32, min_range = -1.00392163 : f32,
// CHECK-SAME: mode = "MIN_COMBINED", transpose_output = false} : (tensor<2x1xi32>) -> tensor<2x4xbf16>
// CHECK-NEXT: %[[cast:.*]] = "xla_hlo.convert"(%[[dq]]) : (tensor<2x4xbf16>) -> tensor<2x4xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %[[cast]] : tensor<2x4xf32>
// CHECK-NEXT: return %[[mul]] : tensor<2x4xf32>
%w = constant dense<[[-1.0, -0.5, 0.0, 0.0], [0.5, 1.0, 0.0, 0.0]]> : tensor<2x4xf32>
%q = "quant.qcast"(%w) : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
%dq = "quant.dcast"(%q) : (tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x4xf32>
%mul = xla_hlo.multiply %arg0, %dq : tensor<2x4xf32>
return %mul: tensor<2x4xf32>
}
// CHECK-LABEL: func @quantize_small
func @quantize_small(%arg0: tensor<1x4xf32>) -> tensor<1x4xf32> {
// CHECK: %[[w:.*]] = constant dense<1.000000e+00> : tensor<1x4xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %[[w]] : tensor<1x4xf32>
// CHECK-NEXT: return %[[mul]] : tensor<1x4xf32>
%w = constant dense<1.0> : tensor<1x4xf32>
%q = "quant.qcast"(%w) : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
%dq = "quant.dcast"(%q) : (tensor<1x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<1x4xf32>
%mul = xla_hlo.multiply %arg0, %dq : tensor<1x4xf32>
return %mul: tensor<1x4xf32>
}
// CHECK-LABEL: func @quantize_non_cst
func @quantize_non_cst(%arg0: tensor<2x4xf32>) -> tensor<2x4xf32> {
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %arg0 : tensor<2x4xf32>
// CHECK-NEXT: return %[[mul]] : tensor<2x4xf32>
%q = "quant.qcast"(%arg0) : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
%dq = "quant.dcast"(%q) : (tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x4xf32>
%mul = xla_hlo.multiply %arg0, %dq : tensor<2x4xf32>
return %mul: tensor<2x4xf32>
}
// CHECK-LABEL: func @quantize_non_4x
func @quantize_non_4x(%arg0: tensor<2x5xf32>) -> tensor<2x5xf32> {
// CHECK: %[[w:.*]] = constant dense<1.000000e+00> : tensor<2x5xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %[[w]] : tensor<2x5xf32>
// CHECK-NEXT: return %[[mul]] : tensor<2x5xf32>
%w = constant dense<1.0> : tensor<2x5xf32>
%q = "quant.qcast"(%w) : (tensor<2x5xf32>) -> tensor<2x5x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
%dq = "quant.dcast"(%q) : (tensor<2x5x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x5xf32>
%mul = xla_hlo.multiply %arg0, %dq : tensor<2x5xf32>
return %mul: tensor<2x5xf32>
}

View File

@ -1,69 +0,0 @@
// RUN: tf-opt -xla-hlo-propagate-quant %s | FileCheck %s --dump-input-on-failure
// -----
// CHECK-LABEL: @mul_add_source_no_params
func @mul_add_source_no_params(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
%region = "quant.region"(%arg0, %arg1, %arg2) ( {
^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors
%mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
%add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
"quant.return"(%add) : (tensor<4xf32>) -> ()
}) {input_specs = [f32, f32, f32], logical_kernel = "generic.mul_add", output_specs = [f32]} :
(tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %region : tensor<4xf32>
// CHECK: input_specs = [f32, f32, f32]
// CHECK-SAME: output_specs = [f32]
}
// -----
// CHECK-LABEL: @mul_add_annotated_no_narrow_range
func @mul_add_annotated_no_narrow_range(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
%region = "quant.region"(%arg0, %arg1, %arg2) ( {
^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors
%mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
%add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
"quant.return"(%add) : (tensor<4xf32>) -> ()
}) {input_specs = [!quant.uniform<i8:f32, 1.0:-128>, !quant.uniform<i8:f32, 1.0:-128>, f32],
logical_kernel = "generic.mul_add", output_specs = [!quant.uniform<i8:f32, 1.0:-128>]} :
(tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %region : tensor<4xf32>
// CHECK: input_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>, !quant.uniform<i8:f32, 1.000000e+00:-128>, f32]
// CHECK-SAME: output_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>]
}
// -----
// CHECK-LABEL: @mul_add_annotated
func @mul_add_annotated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
%region = "quant.region"(%arg0, %arg1, %arg2) ( {
^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors
%mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
%add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
"quant.return"(%add) : (tensor<4xf32>) -> ()
}) {input_specs = [!quant.uniform<i8:f32, 1.0:-128>, !quant.uniform<i8<-127:127>:f32, 1.0:-128>, f32],
logical_kernel = "generic.mul_add", output_specs = [!quant.uniform<i8:f32, 1.0:-128>]} :
(tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %region : tensor<4xf32>
// CHECK: input_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>, !quant.uniform<i8<-127:127>:f32, 1.000000e+00:-128>, !quant.uniform<i32:f32, 1.000000e+00>]
// CHECK-SAME: output_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>]
}
// -----
// CHECK-LABEL: @same_scale_1_1
func @same_scale_1_1(%arg0: tensor<1x7x7x64xf32>) -> (tensor<1x3136xf32>) {
%region = "quant.region"(%arg0) ( {
^bb0(%arg1: tensor<1x7x7x64xf32>): // no predecessors
%r = "xla_hlo.reshape"(%arg1) : (tensor<1x7x7x64xf32>) -> (tensor<1x3136xf32>)
"quant.return"(%r) : (tensor<1x3136xf32>) -> ()
}) {input_specs = [!quant.uniform<i8:f32, 1.0>], logical_kernel = "generic.reshape", output_specs = [f32]} : (tensor<1x7x7x64xf32>) -> tensor<1x3136xf32>
return %region : tensor<1x3136xf32>
// CHECK: input_specs = [!quant.uniform<i8:f32, 1.000000e+00>]
// CHECK-SAME: output_specs = [!quant.uniform<i8:f32, 1.000000e+00>]
}

View File

@ -1,25 +0,0 @@
// RUN: tf-opt -xla-hlo-propagate-quant %s | FileCheck %s
// CHECK-LABEL: func @mul
func @mul(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[w:.*]] = constant dense<{{\[\[}}-1.000000e+00, -5.000000e-01], [5.000000e-01, 1.000000e+00]]> : tensor<2x2xf32>
// CHECK-NEXT: %[[q:.*]] = "quant.qcast"(%[[w]]) : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
// CHECK-NEXT: %[[dq:.*]] = "quant.dcast"(%[[q]]) : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x2xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %[[dq]] : tensor<2x2xf32>
// CHECK-NEXT: return %[[mul]] : tensor<2x2xf32>
%w = constant dense<[[-1.0, -0.5], [0.5, 1.0]]> : tensor<2x2xf32>
%mul = xla_hlo.multiply %arg0, %w : tensor<2x2xf32>
return %mul: tensor<2x2xf32>
}
// CHECK-LABEL: func @add
func @add(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[b:.*]] = constant dense<1.000000e+00> : tensor<2xf32>
// CHECK-NEXT: %[[q:.*]] = "quant.qcast"(%[[b]]) : (tensor<2xf32>) -> tensor<2x!quant.uniform<u8:f32, 0.0039215686274509803>>
// CHECK-NEXT: %[[dq:.*]] = "quant.dcast"(%[[q]]) : (tensor<2x!quant.uniform<u8:f32, 0.0039215686274509803>>) -> tensor<2xf32>
// CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg0, %[[dq]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x2xf32>, tensor<2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: return %[[add]] : tensor<2x2xf32>
%b = constant dense<1.0> : tensor<2xf32>
%add = "xla_hlo.add"(%arg0, %b) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x2xf32>, tensor<2xf32>) -> tensor<2x2xf32>
return %add: tensor<2x2xf32>
}

View File

@ -39,7 +39,7 @@ versions {
# CHECK-LABEL: func @main
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<4xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<4xi32>) -> tensor<*xi32>
# CHECK-SAME: control_outputs = ""
# CHECK-SAME inputs = "input0,input1"
# CHECK-SAME: inputs = "input0,input1"
# CHECK-SAME: outputs = "output"
# CHECK-NEXT: %[[OP:[a-z0-9]+]] = "tf.BannaPotatoSaladWithColeslaw"(%[[ARG_0]], %[[ARG_1]]) {T = i32, device = ""} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
# CHECK-NEXT: return %[[OP]] : tensor<*xi32>

View File

@ -12,6 +12,7 @@ glob_lit_tests(
test_file_exts = [
"mlir",
"cc",
"json",
],
)
@ -22,8 +23,10 @@ filegroup(
data = [
":importer_test_legacy_reshape",
":importer_test_min_max",
":test_schema.fbs",
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
"//tensorflow/compiler/mlir/lite:flatbuffer_translate",
"//tensorflow/compiler/mlir/lite:json_to_flatbuffer",
"@llvm-project//llvm:FileCheck",
],
)

View File

@ -0,0 +1,83 @@
// RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --dump-input-on-failure %s
// CHECK: %cst = constant unit
// CHECK: %[[RES0:.*]] = "tfl.conv_2d"(%arg0, %arg1, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 0 : i32, stride_w = 0 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, none) -> tensor<256x32x32x16xf32>
// CHECK: return %[[RES0]] : tensor<256x32x32x16xf32>
{
version: 3,
operator_codes: [
{
builtin_code: "CONV_2D",
}
],
subgraphs: [
{
tensors: [
{
shape: [
256,
32,
32,
3
],
name: "arg0",
quantization: {
}
},
{
shape: [
16,
3,
3,
3
],
name: "arg1",
quantization: {
}
},
{
shape: [
0
],
name: "cst"
},
{
shape: [
256,
32,
32,
16
],
name: "output",
quantization: {
}
},
],
inputs: [
0,
1
],
outputs: [
3
],
operators: [
{
inputs: [
0,
1,
-1
],
outputs: [
3
],
builtin_options_type: "Conv2DOptions",
builtin_options: {
}
}
],
name: "main"
}
],
description: "MLIR Converted."
}

View File

@ -1,15 +1,15 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// Ensure lstm roundtrips exactly
func @main(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32>, %arg3: tensor<4 x f32>, %arg4: tensor<4 x f32>, %arg5: tensor<4 x f32>, %arg6: tensor<4 x f32>, %arg7: tensor<4 x f32>, %arg8: tensor<4 x f32>, %arg9: tensor<4 x f32>, %arg10: tensor<4 x f32>, %arg11: tensor<4 x f32>, %arg12: tensor<4 x f32>, %arg13: tensor<4 x f32>, %arg14: tensor<4 x f32>, %arg15: tensor<4 x f32>, %arg16: tensor<4 x f32>, %arg17: tensor<4 x f32>, %arg18: tensor<4 x f32>, %arg19: tensor<4 x f32>, %arg20: tensor<4 x f32>, %arg21: tensor<4 x f32>) -> tensor<4 x f32> {
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %24 : tensor<4xf32>
func @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4x4xf32>, %arg10: tensor<4x4xf32>, %arg11: tensor<4x4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<1x4xf32>, %arg14: tensor<1x4xf32>, %arg15: tensor<1x4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<1x4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>) -> tensor<1x4xf32> {
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %24 : tensor<1x4xf32>
// CHECK-LABEL: main
// separate lines since there is no region for this op. third_party/tensorflow/compiler/mlir/lite/ir/tfl_ops.td: 3252
// CHECK: %[[RES0:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg22, %arg23, %arg18, %arg19, %arg20, %arg21) ( {
// CHECK: }) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK: }) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
// CHECK: return %[[RES0]]
}

View File

@ -0,0 +1,78 @@
// RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --dump-input-on-failure %s
// This test verifies that, if the flatbuffer omits the last optional input `bias` of the tfl.conv_2d op, the flatbuffer importer automatically adds a `none` value for it.
// CHECK: %cst = constant unit
// CHECK: %[[RES0:.*]] = "tfl.conv_2d"(%arg0, %arg1, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 0 : i32, stride_w = 0 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, none) -> tensor<256x32x32x16xf32>
// CHECK: return %[[RES0]] : tensor<256x32x32x16xf32>
{
version: 3,
operator_codes: [
{
builtin_code: "CONV_2D",
}
],
subgraphs: [
{
tensors: [
{
shape: [
256,
32,
32,
3
],
name: "arg0",
quantization: {
}
},
{
shape: [
16,
3,
3,
3
],
name: "arg1",
quantization: {
}
},
{
shape: [
256,
32,
32,
16
],
name: "output",
quantization: {
}
},
],
inputs: [
0,
1
],
outputs: [
2
],
operators: [
{
inputs: [
0,
1
],
outputs: [
2
],
builtin_options_type: "Conv2DOptions",
builtin_options: {
}
}
],
name: "main"
}
],
description: "MLIR Converted."
}
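A note on the behaviour this new test pins down: tfl.conv_2d takes three operands (input, filter, bias), but the JSON above lists only inputs [ 0, 1 ], so the importer has to supply a placeholder for the missing trailing optional operand; the `constant unit` / `none` CHECK lines verify exactly that. Below is a minimal, hypothetical C++ sketch of such a padding step (the names and the "pad with nullopt" modelling are illustrative, not the actual flatbuffer_importer code, which materializes a NoneType `constant unit` value instead):

#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical helper: pad a TFLite operator's input list up to the arity the
// MLIR op expects, modelling each missing trailing optional operand as
// std::nullopt.
std::vector<std::optional<int>> PadOptionalInputs(
    std::vector<std::optional<int>> inputs, std::size_t expected_arity) {
  while (inputs.size() < expected_arity) inputs.push_back(std::nullopt);
  return inputs;
}

// Example: PadOptionalInputs({0, 1}, 3) yields {0, 1, nullopt} for the
// (input, filter, bias) operand slots of the conv_2d operator above.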
File diff suppressed because it is too large

View File
@ -1,4 +1,4 @@
// RUN: tf-opt %s -inline -mlir-disable-inline-simplify | FileCheck %s --dump-input=fail
// RUN: tf-opt %s -inline="disable-simplify" | FileCheck %s --dump-input=fail
// Inline a function that contains only tfl ops.
func @func_with_tfl_ops(%arg0 : tensor<2xi32>) -> tensor<2xi32> {
View File
@ -1,5 +1,5 @@
// RUN: tf-opt --tfl-legalize-tf-while %s -o - | FileCheck %s --dump-input-on-failure
// RUN: tf-opt --tfl-legalize-tf-while %s -o - --tfl-legalize-tf-while --inline --mlir-disable-inline-simplify | FileCheck %s --dump-input-on-failure --check-prefix=INLINE
// RUN: tf-opt --tfl-legalize-tf-while %s -o - --tfl-legalize-tf-while --inline="disable-simplify" | FileCheck %s --dump-input-on-failure --check-prefix=INLINE
// RUN: tf-opt --tfl-legalize-tf-while %s -o - --tfl-legalize-tf-while --inline | FileCheck %s --dump-input-on-failure --check-prefix=CANON
func @while_main(%arg0: tensor<?x256x256xf32>) -> (tensor<i32>, tensor<256x256xf32>, tensor<?x256x256xf32>) attributes {tf.entry_function = {inputs = "input", outputs = "Identity,Identity_1,Identity_2"}} {
View File
@ -9,6 +9,20 @@ func @add(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
// CHECK: return
}
// CHECK-LABEL: testAddHighDimsHaveSameShape
func @testAddHighDimsHaveSameShape(%arg0: tensor<1x2x3x4x5x6x7x8xi32>, %arg1: tensor<1x2x3x4x5x6x7x8xi32>) -> tensor<1x2x3x4x5x6x7x8xi32> {
// CHECK: tfl.add %arg0, %arg1 {fused_activation_function = "NONE"}
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x3x4x5x6x7x8xi32>, tensor<1x2x3x4x5x6x7x8xi32>) -> tensor<1x2x3x4x5x6x7x8xi32>
return %0 : tensor<1x2x3x4x5x6x7x8xi32>
}
// CHECK-LABEL: testAddTooHighBroadcastableDims
func @testAddTooHighBroadcastableDims(%arg0: tensor<1x2x3x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// expected-error @+1 {{'tfl.add' op failed to verify that operand #0 and operand #1 have the same shape or broadcastable shapes within the rank 4}}
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
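The expected error above reflects the constraint that tfl.add operands must either have identical shapes (at any rank, as in testAddHighDimsHaveSameShape) or be broadcast-compatible with rank at most 4. A standalone C++ sketch of that predicate, inferred from these two tests and the error text rather than copied from the TFLite verifier:

#include <algorithm>
#include <cstddef>
#include <vector>

// Right-align the two shapes and require each dimension pair to match or be 1.
bool BroadcastCompatible(std::vector<long> a, std::vector<long> b) {
  while (a.size() < b.size()) a.insert(a.begin(), 1);
  while (b.size() < a.size()) b.insert(b.begin(), 1);
  for (std::size_t i = 0; i < a.size(); ++i)
    if (a[i] != b[i] && a[i] != 1 && b[i] != 1) return false;
  return true;
}

// Hypothetical check mirroring the verifier message above: identical shapes
// are accepted at any rank; broadcasting is only allowed within rank 4.
bool TflAddOperandsOk(const std::vector<long>& lhs,
                      const std::vector<long>& rhs) {
  if (lhs == rhs) return true;
  if (std::max(lhs.size(), rhs.size()) > 4) return false;
  return BroadcastCompatible(lhs, rhs);
}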
func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
%2 = "tf.LeakyRelu"(%arg0) {alpha = 0.1 : f32} : (tensor<1xf32>) -> tensor<1xf32>
return %2: tensor<1xf32>
@ -38,14 +52,14 @@ func @squeezeAndReshape(%arg0: tensor<1x1x10xf32>, %arg1: tensor<?x10xf32>) -> i
%1 = "tf.Squeeze"(%arg1) : (tensor<?x10xf32>) -> tensor<*xf32>
%2 = "tf.Const"() { value = dense<[2, 5]> : tensor<2xi32> } : () -> tensor<2xi32>
%3 = "tf.Reshape" (%0, %2) : (tensor<1x10xf32>, tensor<2xi32>) -> tensor<2x5xf32>
%4 = "some_op"(%1, %3) : (tensor<*xf32>, tensor<2x5xf32>) -> i32
%4 = "tf.some_op"(%1, %3) : (tensor<*xf32>, tensor<2x5xf32>) -> i32
return %4 : i32
// CHECK-LABEL: squeezeAndReshape
// CHECK: "tfl.squeeze"(%arg0) {squeeze_dims = [0]} : (tensor<1x1x10xf32>) -> tensor<1x10xf32>
// CHECK: %1 = "tfl.squeeze"(%arg1) {squeeze_dims = []} : (tensor<?x10xf32>) -> tensor<*xf32>
// CHECK: %cst = constant dense<[2, 5]> : tensor<2xi32>
// CHECK: %2 = "tfl.reshape"(%0, %cst) : (tensor<1x10xf32>, tensor<2xi32>) -> tensor<2x5xf32>
// CHECK: %3 = "some_op"(%1, %2) : (tensor<*xf32>, tensor<2x5xf32>) -> i32
// CHECK: %3 = "tf.some_op"(%1, %2) : (tensor<*xf32>, tensor<2x5xf32>) -> i32
// CHECK: return
}
@ -1448,7 +1462,7 @@ func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<f32>
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<f32>) -> tensor<3x3xf32>
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// CHECK return [[MUL]] : tensor<3x3xf32>
// CHECK: return [[MUL]] : tensor<3x3xf32>
}
func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> {
@ -1459,5 +1473,5 @@ func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3
// CHECK: [[CST:%.*]] = constant dense<1> : tensor<i32>
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<i32>) -> tensor<3x3xi32>
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32>
// CHECK return [[MUL]] : tensor<3x3xi32>
// CHECK: return [[MUL]] : tensor<3x3xi32>
}

View File
@ -1,4 +1,4 @@
// RUN: tf-opt -tfl-load-recipe %s | FileCheck %s --dump-input-on-failure
// RUN: tf-opt -allow-unregistered-dialect -tfl-load-recipe %s | FileCheck %s --dump-input-on-failure
// CHECK-LABEL: testLstm
func @testLstm(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>, %arg3: tensor<*xf32>, %arg4: tensor<*xf32>, %arg5: tensor<*xf32>, %arg6: tensor<*xf32>, %arg7: tensor<*xf32>, %arg8: tensor<*xf32>, %arg9: tensor<*xf32>, %arg10: tensor<*xf32>, %arg11: tensor<*xf32>, %arg12: tensor<*xf32>, %arg13: tensor<*xf32>, %arg14: tensor<*xf32>, %arg15: tensor<*xf32>, %arg16: tensor<*xf32>, %arg17: tensor<*xf32>, %arg18: tensor<*xf32>, %arg19: tensor<*xf32>, %arg20: tensor<*xf32>, %arg21: tensor<*xf32>, %arg22: tensor<*xf32>, %arg23: tensor<*xf32>) -> tensor<*xf32> {
View File
@ -1,6 +1,6 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -> tensor<4 x f32> {
func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
@ -9,126 +9,126 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 1, 4 ],
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "arg1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "arg2",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 4,
// CHECK-NEXT: name: "arg3",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 5,
// CHECK-NEXT: name: "arg4",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 6,
// CHECK-NEXT: name: "arg5",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 7,
// CHECK-NEXT: name: "arg6",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 8,
// CHECK-NEXT: name: "arg7",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 9,
// CHECK-NEXT: name: "arg8",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 10,
// CHECK-NEXT: name: "arg9",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 11,
// CHECK-NEXT: name: "arg10",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 12,
// CHECK-NEXT: name: "arg11",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 1, 4 ],
// CHECK-NEXT: buffer: 13,
// CHECK-NEXT: name: "arg12",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 1, 4 ],
// CHECK-NEXT: buffer: 14,
// CHECK-NEXT: name: "arg13",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 1, 4 ],
// CHECK-NEXT: buffer: 15,
// CHECK-NEXT: name: "arg14",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 1, 4 ],
// CHECK-NEXT: buffer: 16,
// CHECK-NEXT: name: "arg15",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 17,
// CHECK-NEXT: name: "arg16",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 1, 4 ],
// CHECK-NEXT: buffer: 18,
// CHECK-NEXT: name: "arg17",
// CHECK-NEXT: quantization: {
@ -163,21 +163,21 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 1, 4 ],
// CHECK-NEXT: name: "Const",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: },
// CHECK-NEXT: is_variable: true
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 1, 4 ],
// CHECK-NEXT: name: "Const1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: },
// CHECK-NEXT: is_variable: true
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 1, 4 ],
// CHECK-NEXT: buffer: 25,
// CHECK-NEXT: name: "tfl.lstm",
// CHECK-NEXT: quantization: {
@ -261,9 +261,9 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t
// CHECK-EMPTY:
^bb0(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32>, %arg3: tensor<4 x f32>, %arg4: tensor<4 x f32>, %arg5: tensor<4 x f32>, %arg6: tensor<4 x f32>, %arg7: tensor<4 x f32>, %arg8: tensor<4 x f32>, %arg9: tensor<4 x f32>, %arg10: tensor<4 x f32>, %arg11: tensor<4 x f32>, %arg12: tensor<4 x f32>, %arg13: tensor<4 x f32>, %arg14: tensor<4 x f32>, %arg15: tensor<4 x f32>, %arg16: tensor<4 x f32>, %arg17: tensor<4 x f32>, %arg18: tensor<4 x f32>, %arg19: tensor<4 x f32>, %arg20: tensor<4 x f32>, %arg21: tensor<4 x f32>):
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %24 : tensor<4xf32>
^bb0(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4x4xf32>, %arg10: tensor<4x4xf32>, %arg11: tensor<4x4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<1x4xf32>, %arg14: tensor<1x4xf32>, %arg15: tensor<1x4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<1x4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>):
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %24 : tensor<1x4xf32>
}

View File
@ -675,6 +675,41 @@ func @testLstmWithInvalidNoneType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>
return %0 : tensor<?xf32>
}
// -----
// Test invalid input dimension: the first input operand of the lstm op should be at least a 2-D tensor.
func @testLstmWithInvalidInputDimension(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32>, %arg3: tensor<4 x f32>, %arg4: tensor<4 x f32>, %arg5: tensor<4 x f32>, %arg6: tensor<4 x f32>, %arg7: tensor<4 x f32>, %arg8: tensor<4 x f32>, %arg9: tensor<4 x f32>, %arg10: tensor<4 x f32>, %arg11: tensor<4 x f32>, %arg12: tensor<4 x f32>, %arg13: tensor<4 x f32>, %arg14: tensor<4 x f32>, %arg15: tensor<4 x f32>, %arg16: tensor<4 x f32>, %arg17: tensor<4 x f32>, %arg18: tensor<4 x f32>, %arg19: tensor<4 x f32>, %arg20: tensor<4 x f32>, %arg21: tensor<4 x f32>) -> tensor<4 x f32> {
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
// expected-error @+1 {{'tfl.lstm' op the first input operand should have more than 2 dimensions.}}
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %24 : tensor<4xf32>
}
// -----
// The 'input_to_output_weights' input of the lstm op has a rank that does not match `input`.
func @testLstmWithInvalidInputsRankMatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4x2xf32>, %arg2: tensor<4x2xf32>, %arg3: tensor<4x2xf32>, %arg4: tensor<4x2xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4x4xf32>, %arg10: tensor<4x4xf32>, %arg11: tensor<4x4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<1x4xf32>, %arg14: tensor<1x4xf32>, %arg15: tensor<1x4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<1x4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>) -> tensor<1x4xf32> {
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
// expected-error @+1 {{'tfl.lstm' op inputs don't match with the dimensions.}}
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %24 : tensor<1x4xf32>
}
// -----
// The coefficient inputs of the LSTM op do not match the dimensions of the input operand `input_to_output_weights`.
func @testLstmWithInvalidInputsRankMatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4x4xf32>, %arg10: tensor<4x4xf32>, %arg11: tensor<4x4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<1x4xf32>, %arg14: tensor<1x4xf32>, %arg15: tensor<1x4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<1x4xf32>, %arg18: tensor<3xf32>, %arg19: tensor<3xf32>, %arg20: tensor<3xf32>, %arg21: tensor<3xf32>) -> tensor<1x4xf32> {
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
// expected-error @+1 {{'tfl.lstm' op coefficient inputs have more than 2 dimensions or don't match the dimension with input operand `input_to_output_weights`.}}
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<1x4xf32>
return %24 : tensor<1x4xf32>
}
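Taken together, the three negative tests above exercise shape checks on tfl.lstm: the first input must be at least a 2-D tensor, the input_to_* weights must agree with the input's feature dimension, and the coefficient (bias) inputs must be 1-D with the cell dimension of input_to_output_weights. A hypothetical C++ sketch of those predicates, inferred from the test shapes and error messages rather than lifted from the verifier source:

#include <vector>

struct Shape { std::vector<long> dims; };

// Hypothetical shape check mirroring the three expected errors above.
bool LstmShapesLookValid(const Shape& input,                 // e.g. 1x4
                         const Shape& input_to_output_w,     // e.g. 4x4
                         const std::vector<Shape>& coeffs) { // e.g. four 4-vectors
  // The first input operand must be at least 2-D (batch x input size).
  if (input.dims.size() < 2) return false;
  // input_to_output_weights must be 2-D and match the input's last dimension.
  if (input_to_output_w.dims.size() != 2 ||
      input_to_output_w.dims[1] != input.dims.back()) return false;
  // Coefficient inputs must be 1-D and sized to the weights' cell dimension.
  for (const Shape& c : coeffs)
    if (c.dims.size() != 1 || c.dims[0] != input_to_output_w.dims[0])
      return false;
  return true;
}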
// -----
// test invalid kernel type

View File
@ -439,6 +439,31 @@ func @NotReorderReshapeAddIfNotTailingDim(%arg0: tensor<40x40x1xf32>) -> tensor<
// CHECK: return %[[rs2]]
}
// CHECK-LABEL: @ReorderElementwiseValueOpAndMoveOp
func @ReorderElementwiseValueOpAndMoveOp(%arg0: tensor<40x40x1xf32>) -> tensor<40x40xf32> {
%shape = constant dense<[40, 40]> : tensor<2xi32>
%1 = "tfl.reshape"(%arg0, %shape) : (tensor<40x40x1xf32>, tensor<2xi32>) -> tensor<40x40xf32>
%2 = "tfl.relu"(%1) : (tensor<40x40xf32>) -> tensor<40x40xf32>
return %2 : tensor<40x40xf32>
// CHECK: %[[rs1:.*]] = "tfl.relu"(%arg0
// CHECK: %[[rs2:.*]] = "tfl.reshape"(%[[rs1]]
// CHECK: return %[[rs2]]
}
// CHECK-LABEL: @NotReorderElementwiseValueOpAndMoveOp
func @NotReorderElementwiseValueOpAndMoveOp(%arg0: tensor<40x40x1xf32>) -> (tensor<40x40xf32>, tensor<40x40xf32>) {
%shape = constant dense<[40, 40]> : tensor<2xi32>
%1 = "tfl.reshape"(%arg0, %shape) : (tensor<40x40x1xf32>, tensor<2xi32>) -> tensor<40x40xf32>
%2 = "tfl.relu"(%1) : (tensor<40x40xf32>) -> tensor<40x40xf32>
return %1, %2 : tensor<40x40xf32>, tensor<40x40xf32>
// CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0
// CHECK: %[[rs2:.*]] = "tfl.relu"(%[[rs1]]
// CHECK: return %[[rs1]], %[[rs2]]
}
// CHECK-LABEL: @FuseFullyConnectedRelu
func @FuseFullyConnectedRelu(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> {
%0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32>
@ -450,6 +475,28 @@ func @FuseFullyConnectedRelu(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32
// CHECK: return %[[RES]]
}
// CHECK-LABEL: @FuseFullyConnectedRelu6
func @FuseFullyConnectedRelu6(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> {
%0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32>
%1 = "tfl.relu6"(%0) : (tensor<1x128xf32>) -> tensor<1x128xf32>
return %1 : tensor<1x128xf32>
// CHECK: %[[RES:[0-9].*]] = "tfl.fully_connected"
// CHECK-SAME: fused_activation_function = "RELU6"
// CHECK: return %[[RES]]
}
// CHECK-LABEL: @FuseFullyConnectedRelu1
func @FuseFullyConnectedRelu1(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> {
%0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32>
%1 = "tfl.relu_n1_to_1"(%0) : (tensor<1x128xf32>) -> tensor<1x128xf32>
return %1 : tensor<1x128xf32>
// CHECK: %[[RES:[0-9].*]] = "tfl.fully_connected"
// CHECK-SAME: fused_activation_function = "RELU_N1_TO_1"
// CHECK: return %[[RES]]
}
// CHECK-LABEL: @HardSwishPattern
func @HardSwishPattern(%arg0: tensor<1xf32>) -> tensor<1xf32> {
%three = constant dense<3.> : tensor<f32>

View File
@ -161,7 +161,7 @@ func @_functionalize_if_else_branch_01(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>
}
func @_functionalize_if_then_branch_01(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
%0 = "my_unknown_op.blah"() : () -> tensor<i1>
%0 = "tf.blah"() : () -> tensor<i1>
return %0 : tensor<i1>
}
@ -199,7 +199,7 @@ func @_functionalize_if_else_branch_02(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>
}
func @_functionalize_if_then_branch_02(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
%0 = "my_unknown_op.blah"() : () -> tensor<i1>
%0 = "tf.blah"() : () -> tensor<i1>
return %0 : tensor<i1>
}

View File
@ -29,7 +29,7 @@ limitations under the License.
namespace mlir {
/// Create a pass to convert from the TFExecutor to the TF control dialect.
std::unique_ptr<OpPassBase<FuncOp>>
std::unique_ptr<OperationPass<FuncOp>>
CreateTFExecutorToControlDialectConversion();
} // namespace mlir
@ -134,6 +134,10 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
pass_manager->addPass(
mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());
if (pass_config.shape_inference) {
// Add a shape inference pass to optimize away the unnecessary casts.
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
}
// Legalize while early to allow further constant folding.
// TODO(jpienaar): This may not actually matter as we do canonicalization
// after the legalize below, for now it needs to be below the above passes
@ -160,11 +164,6 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
// constant ops.
pass_manager->addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass());
if (pass_config.shape_inference) {
// Add a shape inference pass to optimize away the unnecessary casts.
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
}
// The below passes only make sense if Builtin TFLite ops are enabled
// for emission.
if (pass_config.emit_builtin_tflite_ops) {
@ -173,7 +172,8 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
pass_manager->addPass(
mlir::TFL::CreatePrepareTFPass(pass_config.unfold_batch_matmul));
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
pass_manager->addPass(mlir::TFL::CreateLegalizeTFPass());
pass_manager->addPass(
mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification));
pass_manager->addPass(mlir::TFL::CreateOptimizePass());
// This pass operates on TensorFlow ops but is triggered after legalization
// so that it can target constants introduced once TensorFlow Identity ops
@ -255,7 +255,8 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
// TFLite dialect passes.
pm.addPass(mlir::TFL::CreatePrepareTFPass(true));
pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
pm.addPass(mlir::TFL::CreateLegalizeTFPass());
pm.addPass(
mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true));
pm.addPass(mlir::TFL::CreateOptimizePass());
pm.addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
@ -268,7 +269,7 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
pm.addPass(mlir::TFL::CreateWhileOutlinePass());
pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
}
// Registers a pass pipeline for the standard TFL passes.
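For readers tracking the pipeline changes above: shape inference now runs early (before while-legalization) instead of after freezing global tensors, legalization is constructed with the runtime-verification flag, and the final pass is CreateRuntimeVerifyPass rather than CreateRuntimeTypeVerifyPass. A condensed sketch of the resulting ordering, using only pass constructors that appear in the hunks above (the wrapper function name is hypothetical and intermediate passes are omitted):

// Condensed, illustrative ordering only; the real pipeline adds more passes
// between these steps and uses the same headers as the file shown above.
void AddCondensedTFLPasses(mlir::OpPassManager& pm, bool shape_inference) {
  if (shape_inference) {
    // Runs early now, so unnecessary casts are optimized away before
    // while-legalization and canonicalization.
    pm.addPass(mlir::TF::CreateTFShapeInferencePass());
  }
  pm.addPass(mlir::TFL::CreatePrepareTFPass(/*unfold_batch_matmul=*/true));
  pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  pm.addPass(
      mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true));
  pm.addPass(mlir::TFL::CreateOptimizePass());
  pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
}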
Some files were not shown because too many files have changed in this diff