Merge branch 'master' into master

Commit 7ba6bca1d9

.bazelrc (12 lines changed)
@@ -73,6 +73,10 @@
 # tensorflow_testing_rbe_linux: RBE options to use RBE with tensorflow-testing project on linux
 # tensorflow_testing_rbe_win: RBE options to use RBE with tensorflow-testing project on windows
 #
+# Embedded Linux options (experimental and only tested with TFLite build yet)
+# elinux: General Embedded Linux options shared by all flavors.
+# elinux_aarch64: Embedded Linux options for aarch64 (ARM64) CPU support.
+# elinux_armhf: Embedded Linux options for armhf (ARMv7) CPU support.
#

@@ -432,6 +436,14 @@ build:tensorflow_testing_rbe_linux --config=rbe_linux
 common:tensorflow_testing_rbe_win --remote_instance_name=projects/tensorflow-testing/instances/windows
 build:tensorflow_testing_rbe_win --config=tensorflow_testing_rbe

+# TFLite build configs for generic embedded Linux
+build:elinux --crosstool_top=@local_config_embedded_arm//:toolchain
+build:elinux --host_crosstool_top=@bazel_tools//tools/cpp:toolchain
+build:elinux_aarch64 --config=elinux
+build:elinux_aarch64 --cpu=aarch64
+build:elinux_armhf --config=elinux
+build:elinux_armhf --cpu=armhf
 # END TF REMOTE BUILD EXECUTION OPTIONS

 # Default options should come above this line
@@ -1 +1 @@
-2.0.0
+3.0.0
.github/ISSUE_TEMPLATE/00-bug-issue.md (34 lines changed)

@@ -10,32 +10,30 @@ labels: 'type:bug'
 we only address code/doc bugs, performance issues, feature requests and
 build/installation issues on GitHub. tag:bug_template</em>

 **System information**
-- Have I written custom code (as opposed to using a stock
-example script provided in TensorFlow):
-- OS Platform and Distribution (e.g.,
-Linux Ubuntu 16.04):
-- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
-the issue happens on mobile device:
-- TensorFlow installed from (source or
-binary): - TensorFlow version (use command below):
-- Python version: - Bazel
-version (if compiling from source):
-- GCC/Compiler version (if compiling from
-source):
-- CUDA/cuDNN version: - GPU model and memory:
+- Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
+- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
+- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
+- TensorFlow installed from (source or binary):
+- TensorFlow version (use command below):
+- Python version:
+- Bazel version (if compiling from source):
+- GCC/Compiler version (if compiling from source):
+- CUDA/cuDNN version:
+- GPU model and memory:

 You can collect some of this information using our environment capture
 [script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
-You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
-tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
-"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
+You can also obtain the TensorFlow version with:
+1. TF 1.0: `python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"`
+2. TF 2.0: `python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`

 **Describe the current behavior**

 **Describe the expected behavior**

 **Standalone code to reproduce the issue**
 Provide a reproducible test case that is the bare minimum necessary to generate
 the problem. If possible, please share a link to Colab/Jupyter/any notebook.
.github/ISSUE_TEMPLATE/80-performance-issue.md (33 lines changed)

@@ -11,32 +11,29 @@ As per our
 we only address code/doc bugs, performance issues, feature requests and
 build/installation issues on GitHub. tag:performance_template</em>
(The same rewrap as in 00-bug-issue.md above: each **System information** item and
each TensorFlow-version command is placed on a single line; this template carries
tag:performance_template rather than tag:bug_template.)
.github/stale.yml (new file, 39 lines)

@@ -0,0 +1,39 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+#
+# THIS IS A GENERATED DOCKERFILE.
+#
+# This file was assembled from multiple pieces, whose use is documented
+# throughout. Please refer to the TensorFlow dockerfiles documentation
+# for more information.
+
+# Number of days of inactivity before an Issue or Pull Request becomes stale
+daysUntilStale: 7
+# Number of days of inactivity before a stale Issue or Pull Request is closed
+daysUntilClose: 7
+# Issues or Pull Requests with these labels will never be considered stale. Set to `[]` to disable
+onlyLabels:
+  - awaitingResponse
+# Comment to post when marking as stale. Set to `false` to disable
+markComment: >
+  This issue has been automatically marked as stale because it has not had
+  recent activity. It will be closed if no further activity occurs. Thank you.
+# Comment to post when removing the stale label. Set to `false` to disable
+unmarkComment: false
+closeComment: >
+  Closing as stale. Please reopen if you'd like to work on this further.
+limitPerRun: 30
+# Limit to only `issues` or `pulls`
+only: issues
configure (2 lines changed)

@@ -4,7 +4,7 @@ set -e
 set -o pipefail

 if [ -z "$PYTHON_BIN_PATH" ]; then
-  PYTHON_BIN_PATH=$(which python || which python3 || true)
+  PYTHON_BIN_PATH=$(which python3 || which python || true)
 fi

 # Set all env variables
@@ -50,7 +50,7 @@ _TF_WORKSPACE_ROOT = ''
 _TF_BAZELRC = ''
 _TF_CURRENT_BAZEL_VERSION = None
 _TF_MIN_BAZEL_VERSION = '2.0.0'
-_TF_MAX_BAZEL_VERSION = '2.0.0'
+_TF_MAX_BAZEL_VERSION = '3.99.0'

 NCCL_LIB_PATHS = [
     'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''

@@ -58,8 +58,6 @@ NCCL_LIB_PATHS = [

 # List of files to configure when building Bazel on Apple platforms.
 APPLE_BAZEL_FILES = [
-    'tensorflow/lite/experimental/delegates/coreml/BUILD',
-    'tensorflow/lite/experimental/delegates/coreml/builders/BUILD',
     'tensorflow/lite/experimental/ios/BUILD',
     'tensorflow/lite/experimental/objc/BUILD',
     'tensorflow/lite/experimental/swift/BUILD',
@@ -214,6 +214,12 @@ config_setting(
     visibility = ["//visibility:public"],
 )

+config_setting(
+    name = "linux_armhf",
+    values = {"cpu": "armhf"},
+    visibility = ["//visibility:public"],
+)
+
 config_setting(
     name = "linux_x86_64",
     values = {"cpu": "k8"},
@@ -703,8 +709,8 @@ tf_cc_shared_object(
         "//tensorflow/c:version_script.lds",
         "//tensorflow/c/eager:c_api",
         "//tensorflow/c/eager:c_api_experimental",
+        "//tensorflow/core:distributed_tensorflow_dependencies",
         "//tensorflow/core:tensorflow",
-        "//tensorflow/core/distributed_runtime/rpc:grpc_session",
     ],
 )
@@ -118,6 +118,12 @@ cc_library(
     visibility = ["//visibility:public"],
 )

+cc_library(
+    name = "c_api_macros",
+    hdrs = ["c_api_macros.h"],
+    visibility = ["//visibility:public"],
+)
+
 tf_cuda_library(
     name = "c_api",
     hdrs = [
@@ -186,10 +186,6 @@ struct TF_Server {

 namespace tensorflow {

-Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
-
-TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
-
 Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
                        TF_Buffer* out);
tensorflow/c/c_api_macros.h (new file, 33 lines)

@@ -0,0 +1,33 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_C_C_API_MACROS_H_
+#define TENSORFLOW_C_C_API_MACROS_H_
+
+#ifdef SWIG
+#define TF_CAPI_EXPORT
+#else
+#if defined(_WIN32)
+#ifdef TF_COMPILE_LIBRARY
+#define TF_CAPI_EXPORT __declspec(dllexport)
+#else
+#define TF_CAPI_EXPORT __declspec(dllimport)
+#endif  // TF_COMPILE_LIBRARY
+#else
+#define TF_CAPI_EXPORT __attribute__((visibility("default")))
+#endif  // _WIN32
+#endif  // SWIG
+
+#endif  // TENSORFLOW_C_C_API_MACROS_H_
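For reference, the macro defined above is the visibility annotation that TensorFlow's C API headers attach to their exported symbols. A minimal sketch of how a header would typically consume it; TF_MyExperimentalApi is a hypothetical name used only for illustration and is not part of this change:

#include "tensorflow/c/c_api_macros.h"

#ifdef __cplusplus
extern "C" {
#endif

// Expands to __attribute__((visibility("default"))) on Linux/macOS, and to
// __declspec(dllexport) or __declspec(dllimport) on Windows depending on
// whether TF_COMPILE_LIBRARY is defined.
TF_CAPI_EXPORT extern void TF_MyExperimentalApi(void);

#ifdef __cplusplus
}  // extern "C"
#endif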
@@ -240,11 +240,6 @@ tf_cuda_cc_test(
         "c_api_remote_test.cc",
     ],
     extra_copts = tfe_xla_copts(),
-    tags = [
-        "guitar",
-        "multi_gpu",
-        "no_oss",
-    ],
     deps = [
         ":c_api",
         ":c_api_experimental",
@@ -1587,6 +1587,7 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
      // require TFE_Op* and just convert it internally a NameAttrValue, so
      // consider adding an overload to the C API to make this case easier.
      TFE_OpSetAttrFunction(op, attr_name, func_op);
+      TFE_DeleteOp(func_op);
    } break;
    case tensorflow::AttrValue::kList:
      TF_FALLTHROUGH_INTENDED;
@@ -129,7 +129,45 @@ void TestRemoteExecute(bool async) {
 TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
 TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }

-void TestRemoteExecuteSilentCopies(bool async, bool remote) {
+string MatMulFunction() {
+  tensorflow::FunctionDef def;
+  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
+      "    signature {"
+      "      name: 'MatMulFunction'"
+      "      input_arg {"
+      "        name: 'a'"
+      "        type: DT_FLOAT"
+      "      }"
+      "      input_arg {"
+      "        name: 'b'"
+      "        type: DT_FLOAT"
+      "      }"
+      "      output_arg {"
+      "        name: 'm'"
+      "        type: DT_FLOAT"
+      "      }"
+      "    }"
+      "    node_def {"
+      "      name: 'matmul'"
+      "      op: 'MatMul'"
+      "      input: 'a'"
+      "      input: 'b'"
+      "      attr {"
+      "        key: 'T'"
+      "        value {"
+      "          type: DT_FLOAT"
+      "        }"
+      "      }"
+      "    }"
+      "    ret {"
+      "      key: 'm'"
+      "      value: 'matmul:product'"
+      "    }",
+      &def));
+  return def.SerializeAsString();
+}
+
+void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) {
   tensorflow::ServerDef server_def = GetServerDef(3);

   // This server def has the task index set to 0.
@@ -169,12 +207,36 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
       TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

-  // Handles are on task0 (local), and task2, but op is on task1.
-  TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task2);
+  TFE_Op* matmul = nullptr;
+  if (func) {
+    string function_def = MatMulFunction();
+    TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
+                              status);
+    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+    matmul = TFE_NewOp(ctx, "MatMulFunction", status);
+    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    TFE_OpAddInput(matmul, h0_task0, status);
+    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    TFE_OpAddInput(matmul, h1_task2, status);
+    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  } else {
+    // Handles are on task0 (local), and task2, but op is on task1.
+    matmul = MatMulOp(ctx, h0_task0, h1_task2);
+  }
   if (remote) {
     TFE_OpSetDevice(matmul, task1_name, status);
+    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  } else if (!async) {
+    // Set the local device to CPU to easily validate mirroring
+    string cpu_device_name;
+    ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
+    TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status);
+    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle);
+    // The input handles should never change since they have been mirrored.
+    ASSERT_FALSE(remote_arg->HasLocalMirror(nullptr));
   }
-  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

   TFE_TensorHandle* retvals[1];
   int num_retvals = 1;
@@ -182,12 +244,10 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

   // TODO(gjn): Add support for waiting on async local mirrors
-  if (!async) {
+  if (!remote && !async) {
     auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle);
-    tensorflow::EagerOperation* op =
-        tensorflow::OperationFromInterface(matmul->operation);
     // The input handles should never change since they have been mirrored.
-    ASSERT_EQ(op->Inputs()[1], remote_arg);
+    ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr));
   }

   auto* retval_task0 = TFE_TensorHandleCopyToDevice(
@@ -217,6 +277,9 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
   TFE_ExecutorWaitForAllPendingNodes(executor, status);
   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   TFE_DeleteExecutor(executor);
+  if (func) {
+    TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
+  }
   TFE_DeleteContext(ctx);

   TF_DeleteStatus(status);
@@ -227,16 +290,22 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
 }

 TEST(CAPI, RemoteExecuteSilentCopies) {
-  TestRemoteExecuteSilentCopies(false, true);
+  TestRemoteExecuteSilentCopies(false, true, false);
 }
 TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
-  TestRemoteExecuteSilentCopies(true, true);
+  TestRemoteExecuteSilentCopies(true, true, false);
+}
+TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) {
+  TestRemoteExecuteSilentCopies(true, true, true);
 }
 TEST(CAPI, RemoteExecuteSilentCopiesLocal) {
-  TestRemoteExecuteSilentCopies(false, false);
+  TestRemoteExecuteSilentCopies(false, false, false);
 }
 TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) {
-  TestRemoteExecuteSilentCopies(true, false);
+  TestRemoteExecuteSilentCopies(true, false, false);
+}
+TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) {
+  TestRemoteExecuteSilentCopies(true, false, true);
 }

 void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
@@ -78,11 +78,18 @@ void BM_Execute(int iters, int async) {
   TFE_DeleteContextOptions(opts);

   TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
-  TFE_Op* matmul = MatMulOp(ctx, m, m);
+  TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   TFE_TensorHandle* retvals[1];
   int num_retvals = 1;
   tensorflow::testing::StartTiming();
   for (int i = 0; i < iters; ++i) {
+    TFE_OpReset(matmul, "MatMul", nullptr, status);
+    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    TFE_OpAddInput(matmul, m, status);
+    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    TFE_OpAddInput(matmul, m, status);
+    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
     TFE_Execute(matmul, &retvals[0], &num_retvals, status);
     CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   }
@@ -113,11 +120,15 @@ void BM_Execute_Identity(int iters, int async) {
   TFE_DeleteContextOptions(opts);

   TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
-  TFE_Op* identity = IdentityOp(ctx, m);
+  TFE_Op* identity = TFE_NewOp(ctx, "Identity", status);
   TFE_TensorHandle* retvals[1];
   int num_retvals = 1;
   tensorflow::testing::StartTiming();
   for (int i = 0; i < iters; ++i) {
+    TFE_OpReset(identity, "Identity", nullptr, status);
+    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    TFE_OpAddInput(identity, m, status);
+    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
     TFE_Execute(identity, &retvals[0], &num_retvals, status);
     CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   }
@@ -405,6 +416,11 @@ void TensorHandleSilentCopy(bool async,
       hcpu, ctx, gpu_device_name.c_str(), status.get());
   ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

+  auto cpu_arg = tensorflow::TensorHandleFromInterface(hcpu->handle);
+  auto gpu_arg = tensorflow::TensorHandleFromInterface(hgpu->handle);
+  auto gpu_device = absl::get<tensorflow::Device*>(gpu_arg->device());
+  ASSERT_FALSE(cpu_arg->HasLocalMirror(gpu_device));
+
   TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
   if (cpu_op) {
     string cpu_device_name;
@@ -420,15 +436,8 @@ void TensorHandleSilentCopy(bool async,
   TFE_Execute(matmul, &retvals[0], &num_retvals, status.get());
   ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

-  // Validate if the input was replaced with a different TensorHandle
-  auto arg0 = tensorflow::TensorHandleFromInterface(hcpu->handle);
-  auto arg1 = tensorflow::TensorHandleFromInterface(hgpu->handle);
-  tensorflow::EagerOperation* op =
-      tensorflow::OperationFromInterface(matmul->operation);
-
-  // The input handles should never change since they have been mirrored.
-  EXPECT_EQ(op->Inputs()[0], arg0);
-  EXPECT_EQ(op->Inputs()[1], arg1);
+  // The CPU handle should have been copied and have a mirror on the GPU
+  ASSERT_TRUE(cpu_arg->HasLocalMirror(gpu_device));

   TFE_DeleteOp(matmul);
   TFE_DeleteTensorHandle(retvals[0]);
@@ -626,17 +635,6 @@ void ExecuteAdd(bool async, bool forward_input, bool tfrt) {
   }

   int num_retvals = 1;
-
-  if (async) {
-    // Enqueue dummy ops so we backlog async execution & actually test async.
-    for (int i = 0; i < 10000; ++i) {
-      TFE_TensorHandle* dummy = nullptr;
-      TFE_Execute(add_op, &dummy, &num_retvals, status);
-      ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-      TFE_DeleteTensorHandle(dummy);
-    }
-  }
-
   TFE_TensorHandle* retval = nullptr;
   TFE_Execute(add_op, &retval, &num_retvals, status);
   EXPECT_EQ(1, num_retvals);
@@ -38,96 +38,159 @@ typedef void (*ExecuteOperation)(TF_AbstractOp* op, int num_inputs,
                                  TF_AbstractTensor* const* inputs,
                                  TF_OutputList* o, TF_ExecutionContext* ctx,
                                  TF_Status* s);

 struct TF_ExecutionContext {
-  explicit TF_ExecutionContext() {}
-  absl::variant<TFE_Context*, TF_GraphContext*> ctx;
-  ExecuteOperation execution_callback;
-};
-
-struct TF_AbstractTensor {
-  absl::variant<TFE_TensorHandle*, TF_GraphTensor*> t;
-};
-
-struct TF_AbstractOp {
-  string op_type;
-  string op_name;
-};
-
-TF_ExecutionContext* TF_NewExecutionContext() {
-  return new TF_ExecutionContext();
-}
+  // Needed to implement our own version of RTTI since dynamic_cast is not
+  // supported in mobile builds.
+  enum ExecutionContextKind { GraphContext, EagerContext };
+  explicit TF_ExecutionContext(ExecutionContextKind kind) : k(kind) {}
+  ExecutionContextKind getKind() const { return k; }
+
+  virtual void ExecuteOperation(TF_AbstractOp* op, int num_inputs,
+                                TF_AbstractTensor* const* inputs,
+                                TF_OutputList* o, TF_Status* s) = 0;
+  virtual TF_AbstractOp* CreateOperation() = 0;
+  virtual void RegisterFunction(TF_AbstractFunction* func, TF_Status* s) = 0;
+  virtual ~TF_ExecutionContext() {}
+
+ private:
+  const ExecutionContextKind k;
+};

 void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete c; }

-TF_AbstractOp* TF_NewAbstractOp() {
-  TF_AbstractOp* op = new TF_AbstractOp;
-  return op;
-}
-
-void TF_DeleteAbstractOp(TF_AbstractOp* op) { delete op; }
-
-TF_AbstractTensor* TF_NewAbstractTensor() {
-  TF_AbstractTensor* t = new TF_AbstractTensor;
-  return t;
-}
-
-void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { delete t; }
-
-struct TF_GraphContext {
-  TF_Graph* graph;
-  // TODO(srbs): Handle captures.
-};
-
-TF_GraphContext* TF_NewGraphContext(TF_Graph* g) {
-  auto ctx = new TF_GraphContext;
-  ctx->graph = g;
-  return ctx;
-}
-
-void TF_DeleteGraphContext(TF_GraphContext* ctx) { delete ctx; }
+template <typename T, typename S>
+T* dynamic_cast_helper(S source) {
+  if (source->getKind() != T::kKind) {
+    return nullptr;
+  }
+  return tensorflow::down_cast<T*>(source);
+}
+
+class TF_GraphContext;
+class TF_EagerContext;

 struct TF_GraphTensor {
   TF_Output output;
   TF_GraphContext* ctx;
 };
-TF_GraphTensor* TF_NewGraphTensor(TF_GraphContext* ctx, TF_Output output,
-                                  TF_Status* s) {
-  TF_GraphTensor* t = new TF_GraphTensor;
-  t->output = output;
-  t->ctx = ctx;
-  return t;
-}
-TF_Output TF_GraphTensorToOutput(const TF_GraphTensor* const t, TF_Status* s) {
-  return t->output;
-}
-void TF_DeleteGraphTensor(TF_GraphTensor* t) { delete t; }
-void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t,
-                                     TF_Status* s) {
-  at->t = t;
-}
-TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
-                                                  TF_Status* s) {
-  if (!absl::holds_alternative<TFE_TensorHandle*>(at->t)) {
-    string msg = absl::StrCat("Not an eager tensor handle.",
-                              reinterpret_cast<uintptr_t>(at));
-    TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
-    return nullptr;
-  }
-  return absl::get<TFE_TensorHandle*>(at->t);
-}
-void TF_AbstractTensorSetGraphTensor(TF_AbstractTensor* at, TF_GraphTensor* t,
-                                     TF_Status* s) {
-  at->t = t;
-}
-TF_GraphTensor* TF_AbstractTensorGetGraphTensor(TF_AbstractTensor* at,
-                                                TF_Status* s) {
-  if (!absl::holds_alternative<TF_GraphTensor*>(at->t)) {
-    string msg = absl::StrCat("Not an graph tensor handle.");
-    TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
-    return nullptr;
-  }
-  return absl::get<TF_GraphTensor*>(at->t);
-}
+
+struct TF_AbstractTensor {
+  absl::variant<TFE_TensorHandle*, TF_GraphTensor*> t;
+
+  ~TF_AbstractTensor() {
+    if (absl::holds_alternative<TFE_TensorHandle*>(t)) {
+      TFE_DeleteTensorHandle(absl::get<TFE_TensorHandle*>(t));
+    } else if (absl::holds_alternative<TF_GraphTensor*>(t)) {
+      delete absl::get<TF_GraphTensor*>(t);
+    }
+  }
+};
+
+struct TF_AbstractOp {
+  // Needed to implement our own version of RTTI since dynamic_cast is not
+  // supported in mobile builds.
+  enum AbstractOpKind { GraphOp, EagerOp };
+  explicit TF_AbstractOp(AbstractOpKind kind) : k(kind) {}
+  AbstractOpKind getKind() const { return k; }
+  virtual void SetOpType(const char* const op_type, TF_Status* s) = 0;
+  virtual void SetOpName(const char* const op_name, TF_Status* s) = 0;
+  virtual void SetAttrType(const char* const attr_name, TF_DataType value,
+                           TF_Status* s) = 0;
+  virtual ~TF_AbstractOp() {}
+
+ private:
+  const AbstractOpKind k;
+};
+
+TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
+  return c->CreateOperation();
+}
+
+void TF_DeleteAbstractOp(TF_AbstractOp* op) { delete op; }
+
+class TF_GraphOp : public TF_AbstractOp {
+ public:
+  explicit TF_GraphOp(TF_Graph* g) : TF_AbstractOp(kKind), g_(g) {}
+  void SetOpType(const char* const op_type, TF_Status* s) override {
+    if (op_) {
+      TF_SetStatus(
+          s, TF_FAILED_PRECONDITION,
+          absl::StrCat("SetOpType called on already built op.").c_str());
+      return;
+    }
+    if (op_name_ != nullptr) {
+      op_.reset(TF_NewOperation(g_, op_type, op_name_));
+      op_name_ = nullptr;
+    } else {
+      op_type_ = op_type;
+    }
+  }
+  void SetOpName(const char* const op_name, TF_Status* s) override {
+    if (op_) {
+      TF_SetStatus(
+          s, TF_FAILED_PRECONDITION,
+          absl::StrCat("SetOpName called on already built op.").c_str());
+      return;
+    }
+    if (op_type_ != nullptr) {
+      op_.reset(TF_NewOperation(g_, op_type_, op_name));
+      op_type_ = nullptr;
+    } else {
+      op_name_ = op_name;
+    }
+  }
+  void SetAttrType(const char* const attr_name, TF_DataType value,
+                   TF_Status* s) override {
+    if (!op_) {
+      TF_SetStatus(
+          s, TF_FAILED_PRECONDITION,
+          "op_type and op_name must be specified before specifying attrs.");
+      return;
+    }
+    TF_SetAttrType(op_.get(), attr_name, value);
+  }
+  ~TF_GraphOp() override {}
+
+  static constexpr AbstractOpKind kKind = GraphOp;
+
+ private:
+  friend class TF_GraphContext;  // For access to op_.
+  TF_Graph* g_;
+  std::unique_ptr<TF_OperationDescription> op_;
+  // Hold `op_type` and `op_name` till both are available since we need both
+  // to build a graph operation.
+  const char* op_type_ = nullptr;
+  const char* op_name_ = nullptr;
+};
+
+class TF_EagerOp : public TF_AbstractOp {
+ public:
+  explicit TF_EagerOp(TFE_Context* ctx) : TF_AbstractOp(kKind), ctx_(ctx) {}
+  void SetOpType(const char* const op_type, TF_Status* s) override {
+    op_ = TFE_NewOp(ctx_, op_type, s);
+  }
+  void SetOpName(const char* const op_name, TF_Status* s) override {
+    // Name is ignored in eager mode.
+  }
+  void SetAttrType(const char* const attr_name, TF_DataType value,
+                   TF_Status* s) override {
+    if (op_ == nullptr) {
+      TF_SetStatus(s, TF_FAILED_PRECONDITION,
+                   "op_type must be specified before specifying attrs.");
+      return;
+    }
+    TFE_OpSetAttrType(op_, attr_name, value);
+  }
+
+  ~TF_EagerOp() override { TFE_DeleteOp(op_); }
+  static constexpr AbstractOpKind kKind = EagerOp;
+
+ private:
+  friend class TF_EagerContext;  // For access to op_.
+  TFE_Op* op_ = nullptr;
+  TFE_Context* ctx_;
+};
+
 bool IsEagerTensor(const TF_AbstractTensor* const t) {
   return absl::holds_alternative<TFE_TensorHandle*>(t->t);
@@ -138,6 +201,221 @@ struct TF_OutputList {
   int expected_num_outputs = -1;
 };

+struct TF_AbstractFunction {
+  TF_Function* func = nullptr;
+
+  ~TF_AbstractFunction() { TF_DeleteFunction(func); }
+};
+
+class TF_EagerContext : public TF_ExecutionContext {
+ public:
+  TF_EagerContext() : TF_ExecutionContext(kKind) {}
+
+  void Build(TFE_ContextOptions* options, TF_Status* status) {
+    eager_ctx_ = TFE_NewContext(options, status);
+  }
+
+  TF_AbstractOp* CreateOperation() override {
+    // TODO(srbs): Should the lifetime of this op be tied to the context.
+    return new TF_EagerOp(eager_ctx_);
+  }
+
+  void ExecuteOperation(TF_AbstractOp* op, int num_inputs,
+                        TF_AbstractTensor* const* inputs, TF_OutputList* o,
+                        TF_Status* s) override {
+    auto* eager_op = dynamic_cast_helper<TF_EagerOp>(op);
+    if (eager_op == nullptr) {
+      TF_SetStatus(s, TF_INVALID_ARGUMENT,
+                   "Unable to cast TF_AbstractOp to TF_EagerOp.");
+      return;
+    }
+    auto* tfe_op = eager_op->op_;
+    if (TF_GetCode(s) != TF_OK) return;
+    for (int i = 0; i < num_inputs; ++i) {
+      if (!IsEagerTensor(inputs[i])) {
+        TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor.");
+        return;
+      }
+      TFE_OpAddInput(tfe_op, absl::get<TFE_TensorHandle*>(inputs[i]->t), s);
+      if (TF_GetCode(s) != TF_OK) return;
+    }
+    if (o->expected_num_outputs == -1) {
+      string msg =
+          "The number of outputs must be provided in eager mode. Use "
+          "TF_OutputListSetNumOutputs.";
+      TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
+      return;
+    }
+    tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals;
+    int num_retvals = o->expected_num_outputs;
+    retvals.resize(num_retvals);
+    TFE_Execute(tfe_op, retvals.data(), &num_retvals, s);
+    if (TF_GetCode(s) != TF_OK) {
+      return;
+    }
+    o->outputs.clear();
+    o->outputs.reserve(num_retvals);
+    for (int i = 0; i < num_retvals; ++i) {
+      auto* t = new TF_AbstractTensor();
+      t->t = retvals[i];
+      o->outputs.push_back(t);
+    }
+  }
+
+  void RegisterFunction(TF_AbstractFunction* func, TF_Status* s) override {
+    TFE_ContextAddFunction(eager_ctx_, func->func, s);
+  }
+
+  ~TF_EagerContext() override { TFE_DeleteContext(eager_ctx_); }
+
+  static constexpr ExecutionContextKind kKind = EagerContext;
+
+ private:
+  friend TFE_Context* TF_ExecutionContextGetTFEContext(
+      TF_ExecutionContext* ctx);
+  TFE_Context* eager_ctx_;
+};
+
+void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { delete t; }
+
+TF_GraphContext* GetGraphContext(TF_AbstractTensor const* t) {
+  return absl::get<TF_GraphTensor*>(t->t)->ctx;
+}
+
+class TF_GraphContext : public TF_ExecutionContext {
+ public:
+  TF_GraphContext()
+      : TF_ExecutionContext(kKind), graph_(new TF_Graph(), TF_DeleteGraph) {}
+
+  TF_AbstractOp* CreateOperation() override {
+    // TODO(srbs): Should the lifetime of this op be tied to the context.
+    return new TF_GraphOp(graph_.get());
+  }
+
+  void ExecuteOperation(TF_AbstractOp* op, int num_inputs,
+                        TF_AbstractTensor* const* inputs, TF_OutputList* o,
+                        TF_Status* s) override {
+    auto* graph_op = dynamic_cast_helper<TF_GraphOp>(op);
+    if (graph_op == nullptr) {
+      TF_SetStatus(s, TF_INVALID_ARGUMENT,
+                   "Unable to cast TF_AbstractOp to TF_GraphOp.");
+      return;
+    }
+    auto* tf_opdesc = graph_op->op_.release();
+    for (int i = 0; i < num_inputs; ++i) {
+      auto* input = inputs[i];
+      if (IsEagerTensor(input)) {
+        TF_SetStatus(s, TF_INVALID_ARGUMENT,
+                     "Capturing eager tensors is not supported yet.");
+        return;
+      } else {
+        if (GetGraphContext(input) != this) {
+          TF_SetStatus(
+              s, TF_INVALID_ARGUMENT,
+              "Capturing tensors from other graphs is not supported yet.");
+          return;
+        }
+        TF_AddInput(tf_opdesc, absl::get<TF_GraphTensor*>(input->t)->output);
+      }
+    }
+    auto* operation = TF_FinishOperation(tf_opdesc, s);
+    // TF_FinishOperation deletes `tf_opdesc` so clear its reference.
+    graph_op->op_ = nullptr;
+    if (TF_GetCode(s) != TF_OK) return;
+    int num_outputs = TF_OperationNumOutputs(operation);
+    o->outputs.clear();
+    o->outputs.reserve(num_outputs);
+    for (int i = 0; i < num_outputs; ++i) {
+      auto* t = new TF_AbstractTensor;
+      TF_GraphTensor* graph_t = new TF_GraphTensor;
+      graph_t->ctx = this;
+      graph_t->output = {operation, i};
+      t->t = graph_t;
+      o->outputs.push_back(t);
+    }
+  }
+
+  TF_Function* ToFunction(const char* fn_name, int num_inputs,
+                          const TF_AbstractTensor* inputs, int num_outputs,
+                          const TF_AbstractTensor* outputs,
+                          TF_Status* status) const {
+    std::vector<TF_Output> graph_inputs;
+    graph_inputs.resize(num_inputs);
+    std::vector<TF_Output> graph_outputs;
+    graph_outputs.resize(num_outputs);
+    for (int i = 0; i < num_inputs; i++) {
+      graph_inputs[i] = absl::get<TF_GraphTensor*>(inputs[i].t)->output;
+    }
+    for (int i = 0; i < num_outputs; i++) {
+      graph_outputs[i] = absl::get<TF_GraphTensor*>(outputs[i].t)->output;
+    }
+
+    return TF_GraphToFunction(graph_.get(), fn_name, 0, -1, nullptr,
+                              graph_inputs.size(), graph_inputs.data(),
+                              graph_outputs.size(), graph_outputs.data(),
+                              nullptr, nullptr, fn_name, status);
+  }
+
+  void RegisterFunction(TF_AbstractFunction* func, TF_Status* s) override {
+    TF_SetStatus(s, TF_UNIMPLEMENTED,
+                 "Registering graph functions has not been implemented yet.");
+  }
+
+  ~TF_GraphContext() override {}
+
+  static constexpr ExecutionContextKind kKind = GraphContext;
+
+ private:
+  std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
+};
+
+struct TF_GraphContextOptions {};
+struct TF_EagerContextOptions {
+  explicit TF_EagerContextOptions(TFE_ContextOptions* options)
+      : options(options) {}
+  TFE_ContextOptions* options;  // Not owned.
+};
+
+struct TF_ExecutionContextOptions {
+  absl::variant<TF_GraphContextOptions*, TF_EagerContextOptions*> options;
+  ~TF_ExecutionContextOptions() {
+    if (absl::holds_alternative<TF_GraphContextOptions*>(options)) {
+      delete absl::get<TF_GraphContextOptions*>(options);
+    } else if (absl::holds_alternative<TF_EagerContextOptions*>(options)) {
+      delete absl::get<TF_EagerContextOptions*>(options);
+    }
+  }
+};
+
+TF_ExecutionContextOptions* TF_NewGraphContextOptions() {
+  auto* options = new TF_ExecutionContextOptions();
+  options->options = new TF_GraphContextOptions();
+  return options;
+}
+
+void TF_DeleteExecutionContextOptions(TF_ExecutionContextOptions* options) {
+  delete options;
+}
+
+TF_ExecutionContextOptions* TF_NewEagerContextOptions(
+    TFE_ContextOptions* tfe_options) {
+  auto* options = new TF_ExecutionContextOptions();
+  options->options = new TF_EagerContextOptions(tfe_options);
+  return options;
+}
+
+TF_ExecutionContext* TF_NewExecutionContext(TF_ExecutionContextOptions* options,
+                                            TF_Status* s) {
+  if (absl::holds_alternative<TF_EagerContextOptions*>(options->options)) {
+    auto* ctx = new TF_EagerContext();
+    ctx->Build(absl::get<TF_EagerContextOptions*>(options->options)->options,
+               s);
+    return ctx;
+  } else {
+    return new TF_GraphContext();
+  }
+}
+
 TF_OutputList* TF_NewOutputList() { return new TF_OutputList; }
 void TF_DeleteOutputList(TF_OutputList* o) { delete o; }
 void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs,
@@ -149,113 +427,74 @@ TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) {
   return o->outputs[i];
 }

-void ExecuteOperationEager(TF_AbstractOp* op, int num_inputs,
-                           TF_AbstractTensor* const* inputs, TF_OutputList* o,
-                           TF_ExecutionContext* ctx, TF_Status* s) {
-  auto* tfe_op =
-      TFE_NewOp(absl::get<TFE_Context*>(ctx->ctx), op->op_type.c_str(), s);
-  if (TF_GetCode(s) != TF_OK) return;
-  for (int i = 0; i < num_inputs; ++i) {
-    if (!IsEagerTensor(inputs[i])) {
-      TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor.");
-      return;
-    }
-    TFE_OpAddInput(tfe_op, absl::get<TFE_TensorHandle*>(inputs[i]->t), s);
-    if (TF_GetCode(s) != TF_OK) return;
-  }
-  if (o->expected_num_outputs == -1) {
-    string msg =
-        "The number of outputs must be provided in eager mode. Use "
-        "TF_OutputListSetNumOutputs.";
-    TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
-    return;
-  }
-  tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals;
-  int num_retvals = o->expected_num_outputs;
-  retvals.resize(num_retvals);
-  TFE_Execute(tfe_op, retvals.data(), &num_retvals, s);
-  TFE_DeleteOp(tfe_op);
-  if (TF_GetCode(s) != TF_OK) {
-    return;
-  }
-  o->outputs.clear();
-  o->outputs.reserve(num_retvals);
-  for (int i = 0; i < num_retvals; ++i) {
-    auto* t = TF_NewAbstractTensor();
-    t->t = retvals[i];
-    o->outputs.push_back(t);
-  }
-}
-
-TF_GraphContext* GetGraphContext(TF_AbstractTensor const* t) {
-  return absl::get<TF_GraphTensor*>(t->t)->ctx;
-}
-
-void ExecuteOperationGraph(TF_AbstractOp* op, int num_inputs,
-                           TF_AbstractTensor* const* inputs, TF_OutputList* o,
-                           TF_ExecutionContext* ctx, TF_Status* s) {
-  TF_GraphContext* graph_ctx = absl::get<TF_GraphContext*>(ctx->ctx);
-  TF_Graph* g = graph_ctx->graph;
-  auto* tf_opdesc =
-      TF_NewOperation(g, op->op_type.c_str(), op->op_name.c_str());
-  for (int i = 0; i < num_inputs; ++i) {
-    auto* input = inputs[i];
-    if (IsEagerTensor(input)) {
-      TF_SetStatus(s, TF_INVALID_ARGUMENT,
-                   "Capturing eager tensors is not supported yet.");
-      return;
-    } else {
-      if (GetGraphContext(input) != graph_ctx) {
-        TF_SetStatus(
-            s, TF_INVALID_ARGUMENT,
-            "Capturing tensors from other graphs is not supported yet.");
-        return;
-      }
-      TF_AddInput(tf_opdesc, absl::get<TF_GraphTensor*>(input->t)->output);
-    }
-  }
-  auto* operation = TF_FinishOperation(tf_opdesc, s);
-  if (TF_GetCode(s) != TF_OK) return;
-  int num_outputs = TF_OperationNumOutputs(operation);
-  o->outputs.clear();
-  o->outputs.reserve(num_outputs);
-  for (int i = 0; i < num_outputs; ++i) {
-    auto* t = TF_NewAbstractTensor();
-    TF_GraphTensor* output_t = TF_NewGraphTensor(graph_ctx, {operation, i}, s);
-    if (TF_GetCode(s) != TF_OK) {
-      return;
-    }
-    t->t = output_t;
-    o->outputs.push_back(t);
-  }
-}
-
-void TF_ExecutionContextSetEagerContext(TF_ExecutionContext* context,
-                                        TFE_Context* eager_context,
-                                        TF_Status* s) {
-  context->ctx = eager_context;
-  context->execution_callback = &ExecuteOperationEager;
-}
-
-void TF_ExecutionContextSetGraphContext(TF_ExecutionContext* context,
-                                        TF_GraphContext* graph_context,
-                                        TF_Status* s) {
-  context->ctx = graph_context;
-  context->execution_callback = &ExecuteOperationGraph;
-}
-
 void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
                             TF_Status* s) {
-  op->op_type = op_type;
+  op->SetOpType(op_type, s);
 }

 void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
                             TF_Status* s) {
-  op->op_name = op_name;
+  op->SetOpName(op_name, s);
+}
+
+void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
+                              TF_DataType value, TF_Status* s) {
+  op->SetAttrType(attr_name, value, s);
 }

 void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
                          TF_AbstractTensor* const* inputs, TF_OutputList* o,
                          TF_ExecutionContext* ctx, TF_Status* s) {
-  ctx->execution_callback(op, num_inputs, inputs, o, ctx, s);
+  ctx->ExecuteOperation(op, num_inputs, inputs, o, s);
+}
+
+TF_AbstractFunction* TF_ExecutionContextToFunction(
+    const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
+    const TF_AbstractTensor* inputs, int num_outputs,
+    const TF_AbstractTensor* outputs, TF_Status* status) {
+  auto* graph_ctx = dynamic_cast_helper<const TF_GraphContext>(fn_body);
+  if (graph_ctx == nullptr) {
+    TF_SetStatus(status, TF_INVALID_ARGUMENT,
+                 "fn_body is not a TF_GraphContext.");
+    return nullptr;
+  }
+  TF_AbstractFunction* func = new TF_AbstractFunction;
+  func->func = graph_ctx->ToFunction(fn_name, num_inputs, inputs, num_outputs,
+                                     outputs, status);
+  return func;
+}
+
+void TF_DeleteAbstractFunction(TF_AbstractFunction* func) { delete func; }
+
+void TF_ExecutionContextRegisterFunction(TF_ExecutionContext* ctx,
+                                         TF_AbstractFunction* func,
+                                         TF_Status* s) {
+  ctx->RegisterFunction(func, s);
+}
+
+// Temporary APIs till we figure out how to create scalar valued Eager
+// tensors and how to get value out of eager abstract tensors.
+TF_AbstractTensor* TF_NewAbstractTensor() {
+  TF_AbstractTensor* t = new TF_AbstractTensor;
+  return t;
+}
+
+void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t,
+                                     TF_Status* s) {
+  at->t = t;
+}
+
+TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
+                                                  TF_Status* s) {
+  if (!absl::holds_alternative<TFE_TensorHandle*>(at->t)) {
+    string msg = absl::StrCat("Not an eager tensor handle.",
+                              reinterpret_cast<uintptr_t>(at));
+    TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
+    return nullptr;
+  }
+  return absl::get<TFE_TensorHandle*>(at->t);
+}
+
+TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext* ctx) {
+  return dynamic_cast_helper<TF_EagerContext>(ctx)->eager_ctx_;
 }
@@ -15,8 +15,8 @@ limitations under the License.
 #ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_
 #define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_

-#include "tensorflow/c/c_api.h"
 #include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/tf_status.h"

 #ifdef __cplusplus
 extern "C" {

@@ -41,32 +41,19 @@ typedef struct TF_AbstractTensor TF_AbstractTensor;
 // could contain the op type and other attributes.
 typedef struct TF_AbstractOp TF_AbstractOp;

-TF_ExecutionContext* TF_NewExecutionContext();
+// `TF_ExecutionContextOptions` define what type of `TF_ExecutionContext` is
+// created. It can be used to pass context specific params.
+typedef struct TF_ExecutionContextOptions TF_ExecutionContextOptions;
+void TF_DeleteExecutionContextOptions(TF_ExecutionContextOptions*);
+
+TF_ExecutionContext* TF_NewExecutionContext(TF_ExecutionContextOptions*,
+                                            TF_Status* s);
 void TF_DeleteExecutionContext(TF_ExecutionContext*);

-TF_AbstractOp* TF_NewAbstractOp();
+TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* ctx);
 void TF_DeleteAbstractOp(TF_AbstractOp*);

-TF_AbstractTensor* TF_NewAbstractTensor();
 void TF_DeleteAbstractTensor(TF_AbstractTensor*);

-// -----------------------------------------------------------------------------
-// APIs for Eager and graph modes
-// -----------------------------------------------------------------------------
-
-// Keeps track of the current graph and other state e.g. captures etc.
-typedef struct TF_GraphContext TF_GraphContext;
-TF_GraphContext* TF_NewGraphContext(TF_Graph*);
-void TF_DeleteGraphContext(TF_GraphContext*);
-
-// `eager_context` must outlive `context`.
-void TF_ExecutionContextSetEagerContext(TF_ExecutionContext* context,
-                                        TFE_Context* eager_context, TF_Status*);
-// `graph_context` must outlive `context`.
-void TF_ExecutionContextSetGraphContext(TF_ExecutionContext* context,
-                                        TF_GraphContext* graph_context,
-                                        TF_Status*);
-
 // TODO(srbs): Add APIs for specifying attrs etc.
 // `op_type` must outlive `op`.
 void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,

@@ -74,25 +61,9 @@ void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
 // `op_name` must outlive `op`.
 void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
                             TF_Status* s);
+// `attr_name` must outlive `op`.
+void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
+                              TF_DataType value, TF_Status* s);

-// Wrapper for TF_Output but contains a pointer to TF_GraphContext as well.
-typedef struct TF_GraphTensor TF_GraphTensor;
-TF_GraphTensor* TF_NewGraphTensor(TF_GraphContext* c, TF_Output t,
-                                  TF_Status* s);
-TF_Output TF_GraphTensorToOutput(const TF_GraphTensor* const t, TF_Status* s);
-void TF_DeleteGraphTensor(TF_GraphTensor* t);
-
-// `t` must outlive `at`.
-void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t,
-                                     TF_Status* s);
-TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
-                                                  TF_Status* s);
-
-// `t` must outlive `at`.
-void TF_AbstractTensorSetGraphTensor(TF_AbstractTensor* at, TF_GraphTensor* t,
-                                     TF_Status* s);
-TF_GraphTensor* TF_AbstractTensorGetGraphTensor(TF_AbstractTensor* at,
-                                                TF_Status* s);
-
 // TF_OutputList just lets us not specify the number of outputs of an operation
 // beforehand. This forces a memory allocation in the runtime, which is bad, but
@ -104,6 +75,17 @@ void TF_OutputListSetNumOutputs(TF_OutputList* o, int, TF_Status*);
|
|||||||
int TF_OutputListNumOutputs(TF_OutputList* o);
|
int TF_OutputListNumOutputs(TF_OutputList* o);
|
||||||
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i);
|
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i);
|
||||||
|
|
||||||
|
// Stores a function representation that can be used for execution or for
|
||||||
|
// setting functional attributes of other composite ops e.g. control flow.
|
||||||
|
typedef struct TF_AbstractFunction TF_AbstractFunction;
|
||||||
|
TF_AbstractFunction* TF_ExecutionContextToFunction(
|
||||||
|
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
|
||||||
|
const TF_AbstractTensor* inputs, int num_outputs,
|
||||||
|
const TF_AbstractTensor* outputs, TF_Status* status);
|
||||||
|
void TF_DeleteAbstractFunction(TF_AbstractFunction*);
|
||||||
|
void TF_ExecutionContextRegisterFunction(TF_ExecutionContext*,
|
||||||
|
TF_AbstractFunction*, TF_Status*);
|
||||||
|
|
||||||
// TF_ExecuteOperation will, if in eager mode, execute, if in graph mode, maybe
|
// TF_ExecuteOperation will, if in eager mode, execute, if in graph mode, maybe
|
||||||
// capture some inputs and then add a node in the graph, and after
|
// capture some inputs and then add a node in the graph, and after
|
||||||
// execution/node creation it'll go and record things that happened in any tape
|
// execution/node creation it'll go and record things that happened in any tape
|
||||||
@ -112,6 +94,23 @@ void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
|
|||||||
TF_AbstractTensor* const* inputs, TF_OutputList* o,
|
TF_AbstractTensor* const* inputs, TF_OutputList* o,
|
||||||
TF_ExecutionContext* ctx, TF_Status* s);
|
TF_ExecutionContext* ctx, TF_Status* s);
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// APIs specific to Eager and graph modes
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
TF_ExecutionContextOptions* TF_NewGraphContextOptions();
|
||||||
|
TF_ExecutionContextOptions* TF_NewEagerContextOptions(TFE_ContextOptions*);
|
||||||
|
|
||||||
|
// Temporary APIs till we figure out how to create scalar valued Eager
|
||||||
|
// tensors and how to get value out of eager abstract tensors.
|
||||||
|
TF_AbstractTensor* TF_NewAbstractTensor();
|
||||||
|
void TF_AbstractTensorSetEagerTensor(
|
||||||
|
TF_AbstractTensor* at, TFE_TensorHandle* t,
|
||||||
|
TF_Status* s); // `at` takes ownership of `t`.
|
||||||
|
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
|
||||||
|
TF_Status* s);
|
||||||
|
TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext*);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
} /* end extern "C" */
|
} /* end extern "C" */
|
||||||
#endif
|
#endif
|
||||||
|
@@ -33,26 +33,25 @@ namespace tensorflow {
 namespace {
 
 TEST(UnifedCAPI, TestBasicEager) {
-  TF_ExecutionContext* ctx = TF_NewExecutionContext();
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   TFE_ContextOptions* opts = TFE_NewContextOptions();
-  TFE_Context* eager_ctx = TFE_NewContext(opts, status.get());
+  TF_ExecutionContextOptions* options = TF_NewEagerContextOptions(opts);
+  TF_ExecutionContext* ctx = TF_NewExecutionContext(options, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
   TFE_DeleteContextOptions(opts);
 
-  // Enter the eager context.
-  TF_ExecutionContextSetEagerContext(ctx, eager_ctx, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
 
   // Build an abstract input tensor.
+  TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx);
   TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f);
   TF_AbstractTensor* at = TF_NewAbstractTensor();
   TF_AbstractTensorSetEagerTensor(at, t, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
 
   // Build an abstract operation.
-  auto* op = TF_NewAbstractOp();
+  auto* op = TF_NewAbstractOp(ctx);
   TF_AbstractOpSetOpType(op, "Add", status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
 
@@ -69,7 +68,6 @@ TEST(UnifedCAPI, TestBasicEager) {
   // Clean up operation and inputs.
   TF_DeleteAbstractOp(op);
   TF_DeleteAbstractTensor(at);
-  TFE_DeleteTensorHandle(t);
 
   // Verify the results.
   ASSERT_EQ(1, TF_OutputListNumOutputs(o));
@@ -83,100 +81,98 @@ TEST(UnifedCAPI, TestBasicEager) {
 
   TF_DeleteTensor(result_tensor);
   TF_DeleteAbstractTensor(result);
-  TFE_DeleteTensorHandle(result_t);
   TF_DeleteOutputList(o);
-  TFE_DeleteContext(eager_ctx);
   TF_DeleteExecutionContext(ctx);
+  TF_DeleteExecutionContextOptions(options);
 }
 
 TEST(UnifedCAPI, TestBasicGraph) {
-  TF_ExecutionContext* ctx = TF_NewExecutionContext();
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
-  // Enter a graph context.
-  TF_Graph* g = TF_NewGraph();
-  TF_GraphContext* graph_context = TF_NewGraphContext(g);
-  TF_ExecutionContextSetGraphContext(ctx, graph_context, status.get());
+  TF_ExecutionContextOptions* options = TF_NewGraphContextOptions();
+  TF_ExecutionContext* graph_ctx =
+      TF_NewExecutionContext(options, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
 
   // Add a placeholder to the graph.
-  auto* placeholder_op = TF_NewOperation(g, "Placeholder", "Placeholder");
-  TF_SetAttrType(placeholder_op, "dtype", TF_FLOAT);
-  auto* operation = TF_FinishOperation(placeholder_op, status.get());
+  auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
+  TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-  TF_Output placeholder_t = {operation, 0};
-  TF_GraphTensor* graph_t =
-      TF_NewGraphTensor(graph_context, placeholder_t, status.get());
+  TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-  TF_AbstractTensor* t = TF_NewAbstractTensor();
-  TF_AbstractTensorSetGraphTensor(t, graph_t, status.get());
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-
-  // Build an abstract operation.
-  auto* op = TF_NewAbstractOp();
-  TF_AbstractOpSetOpType(op, "Add", status.get());
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-  TF_AbstractOpSetOpName(op, "my_add", status.get());
+  TF_AbstractOpSetAttrType(placeholder_op, "dtype", TF_FLOAT, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
 
   // Build inputs and outputs.
-  TF_AbstractTensor* inputs[2] = {t, t};
-  TF_OutputList* o = TF_NewOutputList();
+  TF_OutputList* placeholder_outputs = TF_NewOutputList();
 
   // Execute.
-  TF_ExecuteOperation(op, 2, inputs, o, ctx, status.get());
+  TF_ExecuteOperation(placeholder_op, 0, nullptr, placeholder_outputs,
+                      graph_ctx, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  ASSERT_EQ(1, TF_OutputListNumOutputs(placeholder_outputs));
+  TF_AbstractTensor* placeholder_t = TF_OutputListGet(placeholder_outputs, 0);
+
+  // Delete placeholder op.
+  TF_DeleteAbstractOp(placeholder_op);
+
+  // Build an abstract operation.
+  auto* add_op = TF_NewAbstractOp(graph_ctx);
+  TF_AbstractOpSetOpType(add_op, "Add", status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  TF_AbstractOpSetOpName(add_op, "my_add", status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+  // Build inputs and outputs.
+  TF_AbstractTensor* inputs[2] = {placeholder_t, placeholder_t};
+  TF_OutputList* add_outputs = TF_NewOutputList();
+
+  // Execute.
+  TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  TF_AbstractTensor* output_t = TF_OutputListGet(add_outputs, 0);
 
   // Clean up operation and inputs.
-  TF_DeleteAbstractOp(op);
-  TF_DeleteAbstractTensor(t);
-  TF_DeleteGraphTensor(graph_t);
-
-  TF_AbstractTensor* result = TF_OutputListGet(o, 0);
-  TF_GraphTensor* result_graph_tensor =
-      TF_AbstractTensorGetGraphTensor(result, status.get());
-  TF_DeleteAbstractTensor(result);
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-  TF_Output result_output =
-      TF_GraphTensorToOutput(result_graph_tensor, status.get());
-  TF_DeleteGraphTensor(result_graph_tensor);
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  TF_DeleteAbstractOp(add_op);
 
   string fn_name = "double";
-  TF_Function* f = TF_GraphToFunction(
-      g, fn_name.c_str(), 0, -1, nullptr, 1, &placeholder_t, 1, &result_output,
-      nullptr, nullptr, fn_name.c_str(), status.get());
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  TF_AbstractFunction* func = TF_ExecutionContextToFunction(
+      graph_ctx, fn_name.c_str(), 1, placeholder_t, 1, output_t, status.get());
+  TF_DeleteAbstractTensor(placeholder_t);
+  TF_DeleteAbstractTensor(output_t);
 
-  // Build an eager context to run the function.
+  // Build eager context.
   TFE_ContextOptions* opts = TFE_NewContextOptions();
-  TFE_Context* eager_ctx = TFE_NewContext(opts, status.get());
+  TF_ExecutionContextOptions* eager_ctx_options =
+      TF_NewEagerContextOptions(opts);
+  TF_ExecutionContext* eager_execution_ctx =
+      TF_NewExecutionContext(eager_ctx_options, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
   TFE_DeleteContextOptions(opts);
 
-  // Build the abstract op to run the function.
-  TFE_ContextAddFunction(eager_ctx, f, status.get());
+  TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-  TF_AbstractOp* fn_op = TF_NewAbstractOp();
+
+  // Build the abstract op to run the function.
+  TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx);
   TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
 
   // Build an abstract input tensor.
-  TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f);
   TF_AbstractTensor* input_t = TF_NewAbstractTensor();
+  TFE_Context* eager_ctx =
+      TF_ExecutionContextGetTFEContext(eager_execution_ctx);
+  TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f);
   TF_AbstractTensorSetEagerTensor(input_t, input_eager, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
 
-  // Enter the eager context.
-  TF_ExecutionContextSetEagerContext(ctx, eager_ctx, status.get());
+  TF_OutputListSetNumOutputs(add_outputs, 1, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-  TF_OutputListSetNumOutputs(o, 1, status.get());
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-  TF_ExecuteOperation(fn_op, 1, &input_t, o, ctx, status.get());
+  TF_ExecuteOperation(fn_op, 1, &input_t, add_outputs, eager_execution_ctx,
+                      status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
 
-  ASSERT_EQ(1, TF_OutputListNumOutputs(o));
-  TF_AbstractTensor* final_result = TF_OutputListGet(o, 0);
+  ASSERT_EQ(1, TF_OutputListNumOutputs(add_outputs));
+  TF_AbstractTensor* final_result = TF_OutputListGet(add_outputs, 0);
   TFE_TensorHandle* final =
       TF_AbstractTensorGetEagerTensor(final_result, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
@@ -185,19 +181,201 @@ TEST(UnifedCAPI, TestBasicGraph) {
   float* f_value = static_cast<float*>(TF_TensorData(f_t));
   ASSERT_EQ(*f_value, 4.0);
 
-  TF_DeleteOutputList(o);
+  TF_DeleteOutputList(add_outputs);
+  TF_DeleteOutputList(placeholder_outputs);
   TF_DeleteAbstractOp(fn_op);
   TF_DeleteAbstractTensor(input_t);
-  TFE_DeleteTensorHandle(input_eager);
   TF_DeleteAbstractTensor(final_result);
-  TFE_DeleteTensorHandle(final);
   TF_DeleteTensor(f_t);
-  TF_DeleteFunction(f);
+  TF_DeleteAbstractFunction(func);
 
-  TF_DeleteGraphContext(graph_context);
-  TF_DeleteGraph(g);
-  TFE_DeleteContext(eager_ctx);
+  TF_DeleteExecutionContext(graph_ctx);
+  TF_DeleteExecutionContext(eager_execution_ctx);
+  TF_DeleteExecutionContextOptions(eager_ctx_options);
+  TF_DeleteExecutionContextOptions(options);
+}
+
+TEST(UnifedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TF_ExecutionContextOptions* options = TF_NewEagerContextOptions(opts);
+  TF_ExecutionContext* ctx = TF_NewExecutionContext(options, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  TFE_DeleteContextOptions(opts);
+
+  TF_AbstractFunction* func = TF_ExecutionContextToFunction(
+      ctx, nullptr, 0, nullptr, 0, nullptr, status.get());
+  ASSERT_EQ(nullptr, func);
+  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
+
   TF_DeleteExecutionContext(ctx);
+  TF_DeleteExecutionContextOptions(options);
+}
+
+TEST(UnifedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  TF_ExecutionContextOptions* options = TF_NewGraphContextOptions();
+  TF_ExecutionContext* graph_ctx =
+      TF_NewExecutionContext(options, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+  // Add a placeholder to the graph.
+  auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
+  TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+  // This should fail.
+  TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
+  ASSERT_EQ(TF_FAILED_PRECONDITION, TF_GetCode(status.get()));
+
+  TF_DeleteAbstractOp(placeholder_op);
+  TF_DeleteExecutionContext(graph_ctx);
+  TF_DeleteExecutionContextOptions(options);
+}
+
+TEST(UnifedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  TF_ExecutionContextOptions* options = TF_NewGraphContextOptions();
+  TF_ExecutionContext* graph_ctx =
+      TF_NewExecutionContext(options, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+  // Add a placeholder to the graph.
+  auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
+  TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+  // This should fail.
+  TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
+  ASSERT_EQ(TF_FAILED_PRECONDITION, TF_GetCode(status.get()));
+
+  TF_DeleteAbstractOp(placeholder_op);
+  TF_DeleteExecutionContext(graph_ctx);
+  TF_DeleteExecutionContextOptions(options);
+}
+
+TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) {
+  // Build an Eager context.
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TF_ExecutionContextOptions* options = TF_NewEagerContextOptions(opts);
+  TF_ExecutionContext* ctx = TF_NewExecutionContext(options, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  TFE_DeleteContextOptions(opts);
+
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+  // Build an Eager operation.
+  auto* op = TF_NewAbstractOp(ctx);
+  TF_AbstractOpSetOpType(op, "Add", status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+  // Build an abstract input tensor.
+  TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx);
+  TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f);
+  TF_AbstractTensor* at = TF_NewAbstractTensor();
+  TF_AbstractTensorSetEagerTensor(at, t, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+  // Build inputs and outputs.
+  TF_AbstractTensor* inputs[2] = {at, at};
+  TF_OutputList* o = TF_NewOutputList();
+  TF_OutputListSetNumOutputs(o, 1, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+  // Build a Graph context.
+  TF_ExecutionContextOptions* graph_options = TF_NewGraphContextOptions();
+  TF_ExecutionContext* graph_ctx =
+      TF_NewExecutionContext(graph_options, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+  // Execute eager op using graph context.
+  TF_ExecuteOperation(op, 2, inputs, o, graph_ctx, status.get());
+  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
+
+  // Clean up operation and inputs.
+  TF_DeleteAbstractOp(op);
+  TF_DeleteAbstractTensor(at);
+
+  TF_DeleteOutputList(o);
+  TF_DeleteExecutionContext(ctx);
+  TF_DeleteExecutionContextOptions(options);
+  TF_DeleteExecutionContext(graph_ctx);
+  TF_DeleteExecutionContextOptions(graph_options);
+}
+
+TEST(UnifedCAPI, TestExecutingGraphOpInEagerModeRaises) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  TF_ExecutionContextOptions* options = TF_NewGraphContextOptions();
+  TF_ExecutionContext* graph_ctx =
+      TF_NewExecutionContext(options, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+  // Add a placeholder to the graph.
+  auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
+  TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  TF_AbstractOpSetAttrType(placeholder_op, "dtype", TF_FLOAT, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+  // Build inputs and outputs.
+  TF_OutputList* placeholder_outputs = TF_NewOutputList();
+
+  // Execute.
+  TF_ExecuteOperation(placeholder_op, 0, nullptr, placeholder_outputs,
+                      graph_ctx, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  ASSERT_EQ(1, TF_OutputListNumOutputs(placeholder_outputs));
+  TF_AbstractTensor* placeholder_t = TF_OutputListGet(placeholder_outputs, 0);
+
+  // Delete placeholder op.
+  TF_DeleteAbstractOp(placeholder_op);
+
+  // Build an abstract operation.
+  auto* add_op = TF_NewAbstractOp(graph_ctx);
+  TF_AbstractOpSetOpType(add_op, "Add", status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  TF_AbstractOpSetOpName(add_op, "my_add", status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+  // Build inputs and outputs.
+  TF_AbstractTensor* inputs[2] = {placeholder_t, placeholder_t};
+  TF_OutputList* add_outputs = TF_NewOutputList();
+
+  // Build eager context.
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TF_ExecutionContextOptions* eager_ctx_options =
+      TF_NewEagerContextOptions(opts);
+  TF_ExecutionContext* eager_execution_ctx =
+      TF_NewExecutionContext(eager_ctx_options, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  TFE_DeleteContextOptions(opts);
+
+  // Execute.
+  TF_ExecuteOperation(add_op, 2, inputs, add_outputs, eager_execution_ctx,
+                      status.get());
+  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
+
+  // Clean up operation and inputs.
+  TF_DeleteAbstractTensor(placeholder_t);
+  TF_DeleteAbstractOp(add_op);
+  TF_DeleteOutputList(add_outputs);
+  TF_DeleteOutputList(placeholder_outputs);
+  TF_DeleteExecutionContext(graph_ctx);
+  TF_DeleteExecutionContext(eager_execution_ctx);
+  TF_DeleteExecutionContextOptions(eager_ctx_options);
+  TF_DeleteExecutionContextOptions(options);
 }
 
 }  // namespace
@@ -16,6 +16,7 @@ limitations under the License.
 // A simple logging device to test custom device registration.
 #include <memory>
 
+#include "absl/strings/match.h"
 #include "tensorflow/c/c_api.h"
 #include "tensorflow/c/eager/c_api.h"
 #include "tensorflow/c/eager/c_api_experimental.h"
@@ -25,7 +26,6 @@ limitations under the License.
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/platform/test.h"
 
-
 TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
@@ -176,7 +176,7 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 }
 
-TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
+TEST(CUSTOM_DEVICE, AccessVariableOnCustomDevice) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
@@ -226,16 +226,21 @@ TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
 
   // Read the variable's value.
   op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
-  TFE_OpAddInput(op.get(), var_handle, status.get());
-  TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+  TFE_OpAddInput(op.get(), var_handle, status.get());
+  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+  TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
   executed = false;
   num_retvals = 1;
   TFE_TensorHandle* var_value = nullptr;
   TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
-  EXPECT_FALSE(TF_GetCode(status.get()) == TF_OK)
-      << "Execution should fail because the variable is being used on the "
-         "wrong device.";
+  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+  ASSERT_TRUE(executed);
+  ASSERT_EQ(
+      tensorflow::string(name),
+      tensorflow::string(TFE_TensorHandleDeviceName(var_value, status.get())));
+  TFE_DeleteTensorHandle(var_value);
 
   // Free the backing buffer for the variable.
   op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
   TFE_OpAddInput(op.get(), var_handle, status.get());
@@ -246,6 +251,79 @@ TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 }
 
+TEST(CUSTOM_DEVICE, InputBasedPlacement) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
+      TFE_NewContextOptions(), TFE_DeleteContextOptions);
+  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
+      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
+  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+  const char* custom0 = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
+  const char* custom1 = "/job:localhost/replica:0/task:0/device:CUSTOM:1";
+  bool arrived = false;
+  bool executed = false;
+  RegisterLoggingDevice(context.get(), custom0, &arrived, &executed,
+                        status.get());
+  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+  RegisterLoggingDevice(context.get(), custom1, &arrived, &executed,
+                        status.get());
+  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcpu(
+      TestMatrixTensorHandle(context.get()), TFE_DeleteTensorHandle);
+  ASSERT_FALSE(arrived);
+  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcustom0(
+      TFE_TensorHandleCopyToDevice(hcpu.get(), context.get(), custom0,
+                                   status.get()),
+      TFE_DeleteTensorHandle);
+  ASSERT_TRUE(arrived);
+  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+  arrived = false;
+  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcustom1(
+      TFE_TensorHandleCopyToDevice(hcpu.get(), context.get(), custom1,
+                                   status.get()),
+      TFE_DeleteTensorHandle);
+  ASSERT_TRUE(arrived);
+  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+  // Base case: two CPU inputs executes fine.
+  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> matmul(
+      MatMulOp(context.get(), hcpu.get(), hcpu.get()), TFE_DeleteOp);
+  TFE_TensorHandle* retval;
+  int num_retvals = 1;
+  TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
+  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+  TFE_DeleteTensorHandle(retval);
+
+  // Custom device: inputs in same custom device works.
+  matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcustom0.get()));
+  num_retvals = 1;
+  executed = false;
+  TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
+  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+  ASSERT_TRUE(executed);
+  TFE_DeleteTensorHandle(retval);
+
+  // Custom device: inputs in different custom devices fails.
+  matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcustom1.get()));
+  num_retvals = 1;
+  TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
+  ASSERT_NE(TF_OK, TF_GetCode(status.get()));
+  ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
+  ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom1));
+
+  // Custom device: mix of custom/physical fails.
+  matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get()));
+  num_retvals = 1;
+  TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
+  ASSERT_NE(TF_OK, TF_GetCode(status.get()));
+  ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
+  ASSERT_TRUE(
+      absl::StrContains(TF_Message(status.get()), "[]"));  // kVariantDeviceNull
+}
+
 TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
@@ -42,7 +42,28 @@ class AbstractOperationInterface {
   virtual Status Reset(const char* op, const char* raw_device_name) = 0;
 
   virtual const string& Name() const = 0;
+
+  // Returns the operation's device name.
+  //
+  // The value returned may be different from the one set by SetDeviceName, but
+  // it will be compatible with it: the name will be updated by device placement
+  // logic to refer to the specific device chosen.
+  //
+  // Example: If one calls `op->SetDeviceName("/device:GPU")`, the value
+  // returned by DeviceName should be "/device:GPU:*" until a particular GPU is
+  // chosen for the operation by the device placement logic in the
+  // executor. After that, the value returned by DeviceName will be a full
+  // device name such as "/job:localhost/replica:0/task:0/device:GPU:1".
   virtual const string& DeviceName() const = 0;
+
+  // Sets the operation device name.
+  //
+  // The given `name` must be parseable by DeviceNameUtils::ParseFullName, and
+  // the result will be used as a constraint for device placement. See the
+  // documentation for DeviceName for more details.
+  //
+  // The value will override the previous value - that is, no "merging" of
+  // existing and given constraints will be performed.
   virtual Status SetDeviceName(const char* name) = 0;
 
   virtual Status AddInput(AbstractTensorHandleInterface* input) = 0;
tensorflow/c/experimental/saved_model/README.md (new file, 66 lines)

# Tensorflow C SavedModel API

## Overview

These are the new experimental C SavedModel APIs for loading and running
SavedModels in a TF2-idiomatic fashion. See
[RFC 207](https://github.com/tensorflow/community/pull/207) for additional
context.

The directory structure is as follows:

```none
saved_model/

  public/

  internal/

  core/
```

## saved_model/public

`saved_model/public` is intended to house *only the public headers* of the
SavedModel C API.

These headers:

1. declare opaque C types (like `TF_SavedModel`),

2. declare the functions that operate on these types (like `TF_LoadSavedModel`).

Once they leave experimental, these APIs should be considered stable for use
by external clients. A sketch of how such a client might look is given after
this section.

These headers are in a separate directory to make it obvious to clients which
headers they should depend on, and which headers are implementation details.
Separating these public headers by directory also allows future programmatic
checks to ensure that TF public headers only `#include` other public TF headers.

## saved_model/internal

`saved_model/internal` is the "glue" between the C API and the internal C++
implementation.

Its role is to:

1. implement the C API functions declared in `saved_model/public`

2. define the C API types declared in `saved_model/public`

The files fulfilling 1. are named `*.cc` (e.g. `concrete_function.cc`), while
the files fulfilling 2. are `*type.h` (e.g. `concrete_function_type.h`).

The headers exposing the internal implementation of the opaque C types are only
visible to other implementors of the C API. This is similar to how other
TF C API implementations use `tf_status_internal.h` (to extract the underlying
`tensorflow::Status`). All other targets in this directory are private.

## saved_model/core

`saved_model/core` contains pure C++ "Classes" underlying the C API types
in `saved_model/public/`. These are implementation details subject to change,
and have visibility limited to implementors only.
This is the bottom-most layer of the `C++ -> C -> C++` sandwich.
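To make the layering concrete, here is a rough sketch of what a client of the
public headers could look like. Only `TF_SavedModel` and `TF_LoadSavedModel`
are named in the README; the header path, `TF_ConcreteFunction`,
`TF_GetSavedModelConcreteFunction`, `TF_DeleteSavedModel`, the exact
signatures, and the model path are assumptions used for illustration, not a
description of the final API.

```cpp
// Hypothetical client of the public SavedModel C API (all names other than
// TF_SavedModel / TF_LoadSavedModel are assumptions).
#include <cstdio>

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
#include "tensorflow/c/tf_status.h"

int main() {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);

  // Assumed: load a TF2 SavedModel from a directory into the eager context.
  TF_SavedModel* model = TF_LoadSavedModel("/tmp/my_model", ctx, status);
  if (TF_GetCode(status) == TF_OK) {
    // Assumed: look a function up by its "path" in the saved object graph.
    TF_ConcreteFunction* fn =
        TF_GetSavedModelConcreteFunction(model, "serving_default", status);
    if (TF_GetCode(status) != TF_OK) {
      std::fprintf(stderr, "lookup failed: %s\n", TF_Message(status));
    }
    (void)fn;  // Invocation would go through the eager C API once wired up.
    TF_DeleteSavedModel(model);
  } else {
    std::fprintf(stderr, "load failed: %s\n", TF_Message(status));
  }

  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
  return 0;
}
```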
tensorflow/c/experimental/saved_model/core/BUILD (new file, 46 lines)

# Experimental SavedModel C APIs for TensorFlow. See RFC
# https://github.com/tensorflow/community/pull/207
# Targets in this directory are pure C++ "Classes" underlying the C API types
# under tf/c/experimental/saved_model/public/. They are subject to change and
# have visibility limited to Tensorflow's implementation only.

package(
    default_visibility = [
        "//tensorflow/c/experimental/saved_model/internal:__pkg__",
    ],
    licenses = ["notice"],  # Apache 2.0
)

cc_library(
    name = "concrete_function",
    srcs = [
        "concrete_function.cc",
    ],
    hdrs = [
        "concrete_function.h",
    ],
    deps = [
        ":function_metadata",
        "//tensorflow/c/eager:operation_interface",
        "//tensorflow/c/eager:tensor_handle_interface",
        "//tensorflow/core:protos_all_cc",
    ],
)

cc_library(
    name = "function_metadata",
    hdrs = [
        "function_metadata.h",
    ],
)

cc_library(
    name = "saved_model_api",
    hdrs = [
        "saved_model_api.h",
    ],
    deps = [
        ":concrete_function",
        "//tensorflow/core:lib",
    ],
)
tensorflow/c/experimental/saved_model/core/concrete_function.cc (new file, 32 lines)

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"

#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"

namespace tensorflow {

const std::vector<tensorflow::AbstractTensorHandleInterface*>&
ConcreteFunction::Captures() const {
  return captures_;
}

const FunctionMetadata& ConcreteFunction::GetFunctionMetadata() const {
  return metadata_;
}

}  // namespace tensorflow
tensorflow/c/experimental/saved_model/core/concrete_function.h (new file, 55 lines)

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_

#include <vector>

#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/core/framework/function.pb.h"

namespace tensorflow {

// Note that ConcreteFunctions' lifetimes are effectively bound
// to the SavedModel they are loaded from, since they retain pointers
// to the TensorHandles owned by the SavedModel, and the FunctionDef
// of the SavedModel.
// Note(bmzhao): This class is only TEMPORARILY virtual, as a way to unblock
// TFRT integration with TF Serving. Do not add more virtual implementations of
// this class. Eventually we want to remove this virtual base class indirection
// and have only a single implementation.
class ConcreteFunction {
 public:
  virtual ~ConcreteFunction() = 0;

  // This method returns the "Call" Op used to execute the function.
  virtual AbstractOperationInterface* GetFunctionOp() = 0;

  const std::vector<tensorflow::AbstractTensorHandleInterface*>& Captures()
      const;
  const FunctionMetadata& GetFunctionMetadata() const;

 private:
  FunctionMetadata metadata_;
  std::vector<tensorflow::AbstractTensorHandleInterface*> captures_;
  FunctionDef* function_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_
@@ -1,4 +1,4 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -13,13 +13,15 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_LITE_TOOLS_BENCHMARK_LOGGING_H_
-#define TENSORFLOW_LITE_TOOLS_BENCHMARK_LOGGING_H_
+#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_FUNCTION_METADATA_H_
+#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_FUNCTION_METADATA_H_
 
-// TODO(b/149482807): completely remove this file from the code base.
-#include "tensorflow/lite/tools/logging.h"
+namespace tensorflow {
 
-#define TFLITE_BENCHMARK_CHECK(condition) TFLITE_TOOLS_CHECK(condition)
-#define TFLITE_BENCHMARK_CHECK_EQ(a, b) TFLITE_TOOLS_CHECK(a == b)
+class FunctionMetadata {
+  // TODO(bmzhao): Fill in with fields as necessary
+};
 
-#endif  // TENSORFLOW_LITE_TOOLS_BENCHMARK_LOGGING_H_
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_FUNCTION_METADATA_H_
tensorflow/c/experimental/saved_model/core/saved_model_api.h (new file, 55 lines)

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_API_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_API_H_

#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {

// Note(bmzhao): This class is only TEMPORARILY virtual, as a way to unblock
// TFRT integration with TF Serving. Do not add more virtual implementations of
// this class. Eventually we want to remove this virtual base class indirection
// and have only a single implementation.
class SavedModelAPI {
 public:
  // Retrieve a function from the TF2 SavedModel, using the "path" to a
  // function in a TF2 savedmodel.
  // Note: `function` is a double pointer, so that implementations are
  // able to return a pointer to an internal member.
  virtual Status GetFunction(const std::string& function_path,
                             ConcreteFunction** function) = 0;

  // Retrieve a function from a SavedModel, using the key of the
  // SignatureDef map:
  // https://github.com/tensorflow/tensorflow/blob/69b08900b1e991d84bce31f3b404f5ed768f339f/tensorflow/core/protobuf/meta_graph.proto#L89
  virtual Status GetSignatureDefFunction(const std::string& signature_def_key,
                                         ConcreteFunction** function) = 0;

  virtual const std::vector<ConcreteFunction*>& ListFunctions() = 0;

  virtual ~SavedModelAPI() = default;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_API_H_
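As a reading aid for the interface above, here is a minimal sketch of how a
caller inside TensorFlow might drive it. The way the concrete implementation
is obtained and the function path used are placeholders, since loading is not
part of this header.

```cpp
#include <memory>

#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/core/platform/status.h"

namespace {

// `model` is assumed to come from whatever loader eventually implements
// SavedModelAPI; that loader is not part of the header shown above.
tensorflow::Status InspectModel(tensorflow::SavedModelAPI* model) {
  tensorflow::ConcreteFunction* fn = nullptr;
  // Look a function up by its "path" in the saved object graph
  // (the path string here is illustrative).
  tensorflow::Status s = model->GetFunction("signatures.serving_default", &fn);
  if (!s.ok()) return s;
  // `fn` points at an internal member of the SavedModelAPI implementation,
  // so the caller does not own or delete it.
  const tensorflow::FunctionMetadata& meta = fn->GetFunctionMetadata();
  (void)meta;
  for (tensorflow::ConcreteFunction* f : model->ListFunctions()) {
    (void)f->Captures();  // e.g. inspect the captured tensor handles
  }
  return tensorflow::Status::OK();
}

}  // namespace
```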
tensorflow/c/experimental/saved_model/internal/BUILD (new file, 157 lines)

# Experimental Implementation of SavedModel C APIs for TensorFlow. See RFC
# https://github.com/tensorflow/community/pull/207
# External clients should not worry about this directory; all contents are implementation details.
# Code in this directory is intended to form the glue between the C API and the internal C++
# implementation by
# 1. mapping C API calls onto corresponding methods of C++ objects
# 2. mapping opaque C types onto C++ classes

# Note(bmzhao): The *.cc files in this directory form the direct implementation of the
# C API functions exposed in tf/c/experimental/saved_model/public/.

# Note(bmzhao): All *type.h files in this directory are the internal definitions of
# the opaque C types. These headers should only be visible to internal tensorflow
# implementors.

package(
    licenses = ["notice"],  # Apache 2.0
)

cc_library(
    name = "conversion_macros",
    hdrs = [
        "conversion_macros.h",
    ],
)

cc_library(
    name = "concrete_function",
    srcs = [
        "concrete_function.cc",
    ],
    hdrs = [
        "//tensorflow/c/experimental/saved_model/public:concrete_function.h",
    ],
    # TODO(bmzhao): Remove this as we refactor C API to granular targets,
    # so that we can depend on c/eager/c_api_unified_experimental.h.
    features = ["-layering_check"],
    visibility = [
        "//tensorflow/c/experimental/saved_model/public:__pkg__",
    ],
    deps = [
        ":concrete_function_type",
        ":function_metadata",
        ":function_metadata_type",
        "//tensorflow/c:c_api_macros",
        "//tensorflow/c/eager:c_api",
        "//tensorflow/c/eager:c_api_internal",
        "//tensorflow/c/experimental/saved_model/core:concrete_function",
        "//tensorflow/c/experimental/saved_model/core:function_metadata",
    ],
)

cc_library(
    name = "concrete_function_list",
    srcs = [
        "concrete_function_list.cc",
    ],
    hdrs = [
        "//tensorflow/c/experimental/saved_model/public:concrete_function_list.h",
    ],
    visibility = [
        "//tensorflow/c/experimental/saved_model/public:__pkg__",
    ],
    deps = [
        ":concrete_function",
        ":concrete_function_list_type",
        ":concrete_function_type",
        "//tensorflow/c:c_api_macros",
        "//tensorflow/c/experimental/saved_model/core:concrete_function",
    ],
)

cc_library(
    name = "concrete_function_list_type",
    hdrs = [
        "concrete_function_list_type.h",
    ],
    deps = [
        ":conversion_macros",
        "//tensorflow/c/experimental/saved_model/core:concrete_function",
    ],
)

cc_library(
    name = "concrete_function_type",
    hdrs = [
        "concrete_function_type.h",
    ],
    deps = [
        ":conversion_macros",
        "//tensorflow/c/experimental/saved_model/core:concrete_function",
    ],
)

cc_library(
    name = "function_metadata",
    srcs = [
        "function_metadata.cc",
    ],
    hdrs = [
        "//tensorflow/c/experimental/saved_model/public:function_metadata.h",
    ],
    visibility = [
        "//tensorflow/c/experimental/saved_model/public:__pkg__",
    ],
    deps = [
        ":function_metadata_type",
        "//tensorflow/c:c_api_macros",
        "//tensorflow/c/experimental/saved_model/core:function_metadata",
    ],
)

cc_library(
    name = "function_metadata_type",
    hdrs = [
        "function_metadata_type.h",
    ],
    deps = [
        ":conversion_macros",
        "//tensorflow/c/experimental/saved_model/core:function_metadata",
    ],
)

cc_library(
    name = "saved_model_api",
    srcs = [
        "saved_model_api.cc",
    ],
    hdrs = [
        "//tensorflow/c/experimental/saved_model/public:saved_model_api.h",
    ],
    visibility = [
        "//tensorflow/c/experimental/saved_model/public:__pkg__",
    ],
    deps = [
        ":concrete_function",
        ":concrete_function_list",
        ":concrete_function_list_type",
        ":concrete_function_type",
        ":saved_model_api_type",
        "//tensorflow/c:c_api_macros",
        "//tensorflow/c:tf_status",
        "//tensorflow/c:tf_status_internal",
        "//tensorflow/c/experimental/saved_model/core:saved_model_api",
        "//tensorflow/core:lib",
    ],
)

cc_library(
    name = "saved_model_api_type",
    hdrs = [
        "saved_model_api_type.h",
    ],
    deps = [
        "//tensorflow/c/experimental/saved_model/core:saved_model_api",
    ],
)
@@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"

#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
#include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h"

extern "C" {

TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(TF_ConcreteFunction* func) {
  return tensorflow::wrap(&tensorflow::unwrap(func)->GetFunctionMetadata());
}

TF_OutputList* TF_ConcreteFunctionGetCaptures(TF_ConcreteFunction* func) {
  // TODO(bmzhao): Refactor TF_OutputList struct definition into a separate
  // internal header, and implement this function.
  return nullptr;
}

TFE_Op* TF_ConcreteFunctionGetOperation(TF_ConcreteFunction* func) {
  return new TFE_Op{tensorflow::unwrap(func)->GetFunctionOp()};
}

}  // end extern "C"
@@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0; see the full license header above. */

#include <stddef.h>

#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"

extern "C" {

size_t TF_ConcreteFunctionListNumOutputs(TF_ConcreteFunctionList* list) {
  return tensorflow::unwrap(list)->size();
}

TF_ConcreteFunction* TF_ConcreteFunctionListGet(TF_ConcreteFunctionList* list,
                                                int i) {
  return tensorflow::wrap((*tensorflow::unwrap(list))[i]);
}

}  // end extern "C"
@@ -0,0 +1,36 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0; see the full license header above. */

#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_

#include <vector>

#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/internal/conversion_macros.h"

// Internal structures used by the SavedModel C API. These are likely to change
// and should not be depended on.

typedef struct TF_ConcreteFunctionList TF_ConcreteFunctionList;

namespace tensorflow {

DEFINE_CONVERSION_FUNCTIONS(std::vector<tensorflow::ConcreteFunction*>,
                            TF_ConcreteFunctionList)

}  // namespace tensorflow

#endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_
@@ -0,0 +1,36 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0; see the full license header above. */

#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_

#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/internal/conversion_macros.h"

// Internal structures used by the SavedModel C API. These are likely to change
// and should not be depended on.

// It doesn't make sense to wrap tensorflow::ConcreteFunction* in a separate
// struct, since the lifetime of the struct and the raw pointer it wraps would
// be different. Therefore TF_ConcreteFunction* = tensorflow::ConcreteFunction*.
typedef struct TF_ConcreteFunction TF_ConcreteFunction;

namespace tensorflow {

DEFINE_CONVERSION_FUNCTIONS(tensorflow::ConcreteFunction, TF_ConcreteFunction)

}  // namespace tensorflow

#endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_
@@ -0,0 +1,28 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0; see the full license header above. */

#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONVERSION_MACROS_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONVERSION_MACROS_H_

#define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper)             \
  inline cpp_impl *unwrap(wrapper *w) {                            \
    return reinterpret_cast<cpp_impl *>(w);                        \
  }                                                                \
                                                                   \
  inline wrapper *wrap(const cpp_impl *i) {                        \
    return reinterpret_cast<wrapper *>(const_cast<cpp_impl *>(i)); \
  }

#endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONVERSION_MACROS_H_
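As a quick illustration of how the *_type.h headers above use this macro, here is a small, self-contained C++ sketch; the Widget/TF_Widget names are invented purely for the example, and the macro is copied verbatim so the snippet compiles on its own:

#include <cassert>

// Copy of the macro defined above, so this sketch is self-contained.
#define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper)             \
  inline cpp_impl *unwrap(wrapper *w) {                            \
    return reinterpret_cast<cpp_impl *>(w);                        \
  }                                                                \
                                                                   \
  inline wrapper *wrap(const cpp_impl *i) {                        \
    return reinterpret_cast<wrapper *>(const_cast<cpp_impl *>(i)); \
  }

// Hypothetical C++ implementation type and its opaque C handle.
namespace demo {
struct Widget {
  int value = 42;
};
DEFINE_CONVERSION_FUNCTIONS(demo::Widget, struct TF_Widget)
}  // namespace demo

typedef struct TF_Widget TF_Widget;

int main() {
  demo::Widget w;
  TF_Widget *handle = demo::wrap(&w);          // hand out as an opaque pointer
  demo::Widget *again = demo::unwrap(handle);  // recover the C++ object
  assert(again->value == 42);
  return 0;
}

The design choice this enables: the C API never exposes C++ types by value, only opaque pointers that are reinterpret-cast back to the underlying C++ objects inside the implementation files.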
@@ -0,0 +1,20 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0; see the full license header above. */

#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"

#include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h"

// TODO(bmzhao): Add getter functions here as necessary.
@@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0; see the full license header above. */

#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_

#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/internal/conversion_macros.h"

typedef struct TF_FunctionMetadata TF_FunctionMetadata;

namespace tensorflow {

DEFINE_CONVERSION_FUNCTIONS(tensorflow::FunctionMetadata, TF_FunctionMetadata)

}  // namespace tensorflow

#endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_
@@ -0,0 +1,67 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0; see the full license header above. */

#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"

#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/platform/status.h"

extern "C" {

TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx,
                                 const char* const* tags, int tags_len,
                                 TF_Status* status) {
  // TODO(bmzhao): Add a virtual "LoadSavedModel" method to
  // AbstractContextInterface, and call it here.
  return nullptr;
}

void TF_DeleteSavedModel(TF_SavedModel* model) { delete model; }

TF_ConcreteFunction* TF_GetSavedModelFunction(TF_SavedModel* model,
                                              char* function_path,
                                              TF_Status* status) {
  tensorflow::ConcreteFunction* result = nullptr;
  tensorflow::Status get_function_status =
      model->saved_model->GetFunction(function_path, &result);
  status->status.Update(get_function_status);
  if (!get_function_status.ok()) {
    return nullptr;
  }
  return tensorflow::wrap(result);
}

TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
    TF_SavedModel* model, char* signature_def_key, TF_Status* status) {
  tensorflow::ConcreteFunction* result = nullptr;
  tensorflow::Status get_function_status =
      model->saved_model->GetSignatureDefFunction(signature_def_key, &result);
  status->status.Update(get_function_status);
  if (!get_function_status.ok()) {
    return nullptr;
  }
  return tensorflow::wrap(result);
}

TF_ConcreteFunctionList* TF_ListSavedModelFunctions(TF_SavedModel* model) {
  return tensorflow::wrap(&model->saved_model->ListFunctions());
}

}  // end extern "C"
@@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0; see the full license header above. */

#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_

#include <memory>

#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"

// Internal structures used by the SavedModel C API. These are likely to change
// and should not be depended on.

struct TF_SavedModel {
  std::unique_ptr<tensorflow::SavedModelAPI> saved_model;
};

#endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_

63  tensorflow/c/experimental/saved_model/public/BUILD  (new file)
@@ -0,0 +1,63 @@
# Experimental SavedModel C APIs for TensorFlow.
# See RFC https://github.com/tensorflow/community/pull/207
# All headers are on the public surface of Tensorflow's C API.
# Once moved out of experimental, these will be stable.
# The idea behind a separate public/ directory is to make apparent
# which headers are part of TF's public interface (and which headers
# are implementation details). This structure allows us to also perform future
# programmatic checks that all "public" headers only include other "public"
# headers.

package(
    # This is intentionally public
    default_visibility = [
        "//visibility:public",
    ],
    licenses = ["notice"],  # Apache 2.0
)

# TODO(bmzhao): Remove these exports_files and rules, swap with cc_public_library instead.
# cc_public_library would allow us to separate the header dep graph from header+srcs dep graph.
exports_files(
    [
        "concrete_function.h",
        "concrete_function_list.h",
        "function_metadata.h",
        "saved_model_api.h",
    ],
    visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"],
)

# The purpose of this header is to provide insulation against
# future changes where we rename/move a public header, without
# forcing all clients to change their "#includes".
cc_library(
    name = "c_saved_model_api",
    hdrs = ["c_saved_model_api.h"],
    deps = [
        ":concrete_function",
        ":concrete_function_list",
        ":function_metadata",
        ":saved_model_api",
    ],
)

alias(
    name = "concrete_function",
    actual = "//tensorflow/c/experimental/saved_model/internal:concrete_function",
)

alias(
    name = "concrete_function_list",
    actual = "//tensorflow/c/experimental/saved_model/internal:concrete_function_list",
)

alias(
    name = "function_metadata",
    actual = "//tensorflow/c/experimental/saved_model/internal:function_metadata",
)

alias(
    name = "saved_model_api",
    actual = "//tensorflow/c/experimental/saved_model/internal:saved_model_api",
)
@@ -0,0 +1,26 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0; see the full license header above. */

#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_

// IWYU pragma: begin_exports
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
// IWYU pragma: end_exports

#endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_
@@ -0,0 +1,53 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0; see the full license header above. */

#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_

#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"

#ifdef __cplusplus
extern "C" {
#endif  // __cplusplus

// An opaque type that corresponds to a Function loaded from a SavedModel.
// TODO(bmzhao): Work together w/srbs@ to make sure this composes w/the
// C++ Unified Eager/Graph API's AbstractFunction
typedef struct TF_ConcreteFunction TF_ConcreteFunction;

// Returns FunctionMetadata associated with `func`. Metadata's lifetime is
// bound to `func`, which is bound to the TF_SavedModel it was loaded from.
TF_CAPI_EXPORT extern TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(
    TF_ConcreteFunction* func);

// Returns a list of TensorHandles implicitly captured by this function.
TF_CAPI_EXPORT extern TF_OutputList* TF_ConcreteFunctionGetCaptures(
    TF_ConcreteFunction* func);

// Returns a TFE_Op suitable for executing this function.
TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionGetOperation(
    TF_ConcreteFunction* func);

// Deletes `func`.
TF_CAPI_EXPORT extern void TF_DeleteConcreteFunction(TF_ConcreteFunction* func);

#ifdef __cplusplus
}  // end extern "C"
#endif  // __cplusplus

#endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_
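A brief, hypothetical C++ sketch of how a client might use only the declarations in this header once the implementations above are filled in; the helper name is made up for illustration:

#include <cstdio>

#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"

// Hypothetical helper: fetch the eager op for `func` so the caller can attach
// inputs and execute it with the TFE_* APIs. Returns nullptr if `func` is null.
TFE_Op* PrepareFunctionOp(TF_ConcreteFunction* func) {
  if (func == nullptr) {
    std::fprintf(stderr, "no concrete function supplied\n");
    return nullptr;
  }
  // Metadata is owned by the function (and, transitively, by the SavedModel),
  // so the caller must not free it.
  TF_FunctionMetadata* metadata = TF_ConcreteFunctionGetMetadata(func);
  (void)metadata;  // No getters are exposed yet (see the function_metadata.h TODO).
  return TF_ConcreteFunctionGetOperation(func);
}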
@@ -0,0 +1,35 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0; see the full license header above. */

#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_

#include <stddef.h>

#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"

// An opaque type that acts like a list of TF_ConcreteFunction pointers.
typedef struct TF_ConcreteFunctionList TF_ConcreteFunctionList;

// Returns the size of `list`.
TF_CAPI_EXPORT size_t
TF_ConcreteFunctionListSize(TF_ConcreteFunctionList* list);

// Returns the `i`th TF_ConcreteFunction in the list.
TF_CAPI_EXPORT TF_ConcreteFunction* TF_ConcreteFunctionListGet(
    TF_ConcreteFunctionList* list, int i);

#endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
@@ -0,0 +1,35 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0; see the full license header above. */

#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_FUNCTION_METADATA_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_FUNCTION_METADATA_H_

#include "tensorflow/c/c_api_macros.h"

#ifdef __cplusplus
extern "C" {
#endif  // __cplusplus

// An opaque type used to store any metadata associated with a function.
typedef struct TF_FunctionMetadata TF_FunctionMetadata;

// TODO(bmzhao): Add getters for fields as we determine what metadata
// we want to expose.

#ifdef __cplusplus
}  // end extern "C"
#endif  // __cplusplus

#endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_FUNCTION_METADATA_H_
@@ -0,0 +1,96 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0; see the full license header above. */

#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SAVED_MODEL_API_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SAVED_MODEL_API_H_

#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
#include "tensorflow/c/tf_status.h"

#ifdef __cplusplus
extern "C" {
#endif  // __cplusplus

// An opaque type representing a Tensorflow "SavedModel"
// (https://www.tensorflow.org/guide/saved_model) that we always pass by pointer
// to achieve ABI stability.
typedef struct TF_SavedModel TF_SavedModel;

// Load a SavedModel from `dirname`.
//
// Params:
//  dirname - A directory filepath that the SavedModel is at.
//  ctx - A TFE_Context containing optional load/TF runtime options.
//        `ctx` must outlive the returned TF_SavedModel pointer.
//  tags - Pointer to char* array of SavedModel tags. Optional if the SavedModel
//         contains a single Metagraph, as for those exported from
//         `tf.saved_model.save`.
//  tags_len - number of elements in the `tags` array.
//  status - Set to OK on success and an appropriate error on failure.
// Returns:
//  If status is not OK, returns nullptr. Otherwise, returns a newly created
//  TF_SavedModel instance. It must be deleted by calling TF_DeleteSavedModel.
TF_CAPI_EXPORT extern TF_SavedModel* TF_LoadSavedModel(const char* dirname,
                                                       TFE_Context* ctx,
                                                       const char* const* tags,
                                                       int tags_len,
                                                       TF_Status* status);

// Deletes a TF_SavedModel, and frees any resources owned by it.
TF_CAPI_EXPORT extern void TF_DeleteSavedModel(TF_SavedModel* model);

// Retrieve a function from the TF2 SavedModel via function path.
//
// Params:
//  model - The TF2 SavedModel to load a function from.
//  function_path - A string containing the path from the root saved python
//                  object to a tf.function method.
//                  TODO(bmzhao): Add a detailed example of this with a
//                  python tf.module before moving this out of experimental.
//  status - Set to OK on success and an appropriate error on failure.
// Returns:
//  If status is not OK, returns nullptr. Otherwise, returns a
//  TF_ConcreteFunction instance. The lifetime of this instance is
//  "conceptually" bound to `model`. Once `model` is deleted, all
//  `TF_ConcreteFunctions` retrieved from it are invalid, and have been deleted.
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelFunction(
    TF_SavedModel* model, char* function_path, TF_Status* status);

// Retrieve a function from the TF SavedModel via a SignatureDef key.
//
// Params:
//  model - The SavedModel to load a function from.
//  signature_def_key - The string key of the SignatureDef map of a SavedModel:
//                      https://github.com/tensorflow/tensorflow/blob/69b08900b1e991d84bce31f3b404f5ed768f339f/tensorflow/core/protobuf/meta_graph.proto#L89
//  status - Set to OK on success and an appropriate error on failure.
// Returns:
//  If status is not OK, returns nullptr. Otherwise, returns a
//  TF_ConcreteFunction instance. Once `model` is deleted, all
//  `TF_ConcreteFunctions` retrieved from it are invalid, and have been deleted.
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
    TF_SavedModel* model, char* signature_def_key, TF_Status* status);

// Returns a list of all ConcreteFunctions stored in this SavedModel.
// The lifetime of the returned list is bound to `model`.
TF_CAPI_EXPORT extern TF_ConcreteFunctionList* TF_ListSavedModelFunctions(
    TF_SavedModel* model);

#ifdef __cplusplus
}  // end extern "C"
#endif  // __cplusplus

#endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SAVED_MODEL_API_H_
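To tie the pieces of this public surface together, here is a hedged end-to-end C++ sketch of the intended call sequence. The model directory, tag, and signature key are placeholders, and TF_LoadSavedModel above is still a stub, so this illustrates target usage rather than something that works today:

#include <cstdio>

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/experimental/saved_model/public/c_saved_model_api.h"

int main() {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);
  if (TF_GetCode(status) != TF_OK) {
    std::fprintf(stderr, "context error: %s\n", TF_Message(status));
    TF_DeleteStatus(status);
    return 1;
  }

  // Placeholder path and tag; a real caller would point at an exported model.
  const char* tags[] = {"serve"};
  TF_SavedModel* model =
      TF_LoadSavedModel("/tmp/my_saved_model", ctx, tags, 1, status);
  if (TF_GetCode(status) == TF_OK && model != nullptr) {
    // Look up a function by SignatureDef key (placeholder key).
    char key[] = "serving_default";
    TF_ConcreteFunction* fn =
        TF_GetSavedModelSignatureDefFunction(model, key, status);
    if (TF_GetCode(status) == TF_OK && fn != nullptr) {
      TFE_Op* op = TF_ConcreteFunctionGetOperation(fn);
      (void)op;  // Inputs would be attached and the op run via TFE_Execute.
    }
    TF_ConcreteFunctionList* all = TF_ListSavedModelFunctions(model);
    std::printf("model exposes %zu concrete functions\n",
                TF_ConcreteFunctionListSize(all));
    TF_DeleteSavedModel(model);
  }

  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
  return 0;
}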
@@ -119,6 +119,9 @@ inline Tensor& TensorFromInterface(AbstractTensorInterface* tensor) {
   return down_cast<TensorInterface*>(tensor)->Tensor();
 }
 
+Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
+
+TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
@@ -156,6 +156,7 @@ cc_library(
         ":array_grad",
         ":data_flow_grad",
         ":image_grad",
+        ":manip_grad",
         ":math_grad",
         ":nn_grad",
     ],
@@ -494,6 +495,32 @@ tf_cc_test(
     ],
 )
 
+cc_library(
+    name = "manip_grad",
+    srcs = ["gradients/manip_grad.cc"],
+    deps = [
+        ":cc_ops",
+        ":grad_op_registry",
+        ":gradients",
+    ],
+    alwayslink = 1,
+)
+
+tf_cc_test(
+    name = "gradients_manip_grad_test",
+    srcs = ["gradients/manip_grad_test.cc"],
+    deps = [
+        ":array_ops",
+        ":cc_ops",
+        ":gradient_checker",
+        ":manip_grad",
+        ":testutil",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
 # Generates separate libraries for array_ops and math_ops to reduce the dependency count of targets that depend on only these
 tf_gen_op_wrappers_cc(
     name = "math_ops",

40  tensorflow/cc/gradients/manip_grad.cc  (new file)
@@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0; see the full license header above. */

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/manip_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"

namespace tensorflow {
namespace ops {
namespace {

Status RollGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto shift = op.input(1);
  auto axis = op.input(2);
  auto grad_op = Roll(scope, grad_inputs[0], Neg(scope, shift), axis);
  grad_outputs->push_back(grad_op);
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Roll", RollGrad);

}  // namespace
}  // namespace ops
}  // namespace tensorflow
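In words: Roll only permutes elements, so the gradient of a roll by `shift` along `axis` is the incoming gradient rolled back by `-shift` along the same axis, while the `shift` and `axis` inputs receive no gradient. That is exactly what the RollGrad function above registers for the "Roll" op, and what the test below checks numerically.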

51  tensorflow/cc/gradients/manip_grad_test.cc  (new file)
@@ -0,0 +1,51 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0; see the full license header above. */

#include "tensorflow/cc/framework/gradient_checker.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/manip_ops.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace tensorflow {
namespace {

using ops::Placeholder;
using ops::Roll;

class ManipGradTest : public ::testing::Test {
 protected:
  ManipGradTest() : scope_(Scope::NewRootScope()) {}

  void RunTest(const Output& x, const TensorShape& x_shape, const Output& y,
               const TensorShape& y_shape) {
    TF_ASSERT_OK(scope_.status());
    float max_error;
    TF_ASSERT_OK((ComputeGradientError<float, float, float>(
        scope_, {x}, {x_shape}, {y}, {y_shape}, &max_error)));
    EXPECT_LT(max_error, 1e-4);
  }

  Scope scope_;
};

TEST_F(ManipGradTest, RollGrad) {
  TensorShape shape({5, 4, 3});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  auto y = Roll(scope_, x, {2, 1}, {0, 1});
  RunTest(x, shape, y, shape);
}

}  // namespace
}  // namespace tensorflow
@@ -358,13 +358,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
       ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_,
       resources_, constants_, /*lazy=*/false, &client, &variables, &kernel,
       &executable);
-  if (!s.ok() && (platform_info_.device_type().type_string() == DEVICE_CPU ||
-                  platform_info_.device_type().type_string() == DEVICE_GPU)) {
-    // Suggest auto jit if the failure was with GPU or CPU.
-    errors::AppendToMessage(&s,
-                            xla::status_macros::kPossibleAutoJitAlternative);
-  }
-
   OP_REQUIRES_OK(ctx, s);
 }
 
@@ -1891,6 +1891,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
       "DynamicStitch",
       "Einsum",
       "EmptyTensorList",
+      "EnsureShape",
       "ExtractImagePatches",
       "Igamma",
       "IgammaGradA",
|
|||||||
attrs.set_on_host(true);
|
attrs.set_on_host(true);
|
||||||
TF_RETURN_IF_ERROR(ctx->allocate_temp(
|
TF_RETURN_IF_ERROR(ctx->allocate_temp(
|
||||||
device_tensor.dtype(), device_tensor.shape(), &host_tensor, attrs));
|
device_tensor.dtype(), device_tensor.shape(), &host_tensor, attrs));
|
||||||
Notification n;
|
Status status = ctx->op_device_context()->CopyDeviceTensorToCPUSync(
|
||||||
Status status;
|
|
||||||
ctx->op_device_context()->CopyDeviceTensorToCPU(
|
|
||||||
&device_tensor, "ConstantArgument",
|
&device_tensor, "ConstantArgument",
|
||||||
reinterpret_cast<Device*>(ctx->device()), &host_tensor,
|
reinterpret_cast<Device*>(ctx->device()), &host_tensor);
|
||||||
[&](Status s) {
|
|
||||||
status = s;
|
|
||||||
n.Notify();
|
|
||||||
});
|
|
||||||
n.WaitForNotification();
|
|
||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
LOG(ERROR) << "Copying tensor of shape "
|
LOG(ERROR) << "Copying tensor of shape "
|
||||||
<< device_tensor.shape().DebugString() << " from "
|
<< device_tensor.shape().DebugString() << " from "
|
||||||
|
@@ -488,15 +488,8 @@ Status XlaDevice::MakeTensorFromProto(XlaDeviceContext* device_context,
     mutex_lock lock(mu_);
     Allocator* allocator = GetAllocatorLocked(alloc_attrs);
     Tensor copy(allocator, parsed.dtype(), parsed.shape());
-    Notification n;
-    device_context->CopyCPUTensorToDevice(
-        &parsed, this, &copy,
-        [&n, &status](const Status& s) {
-          status = s;
-          n.Notify();
-        },
-        true /*sync_dst_compute*/);
-    n.WaitForNotification();
+    TF_RETURN_IF_ERROR(
+        device_context->CopyCPUTensorToDeviceSync(&parsed, this, &copy));
     *tensor = copy;
   }
   VLOG(2) << "Allocated tensor at " << DMAHelper::base(tensor);
@@ -69,6 +69,7 @@ absl::optional<AllocatorStats> XlaDeviceAllocator::GetStats() {
   tf_stats.bytes_reserved = se_stats->bytes_reserved;
   tf_stats.peak_bytes_reserved = se_stats->peak_bytes_reserved;
   tf_stats.bytes_reservable_limit = se_stats->bytes_reservable_limit;
+  tf_stats.largest_free_block_bytes = se_stats->largest_free_block_bytes;
   return tf_stats;
 }
 
@@ -479,6 +479,15 @@ Status XlaComputationLaunchContext::PopulateOutputs(
           input_output_alias, output_num, ctx, i, shape, &output,
           definition_event, stream, use_multiple_streams_));
     } else {
+      auto program_shape =
+          kernel->computation->GetProgramShape().ValueOrDie();
+      if (program_shape.result().IsTuple() &&
+          program_shape.result().tuple_shapes(output_num).IsTuple()) {
+        return errors::Unimplemented(
+            "Support for TensorList or Stack crossing the XLA/TF boundary "
+            "is not implemented");
+      }
+
       se::DeviceMemoryBase buffer = output.buffer({output_num});
       Tensor output_tensor = GetOrCreateTensorForOutput(
           output_num, ctx, missing_ctx_input_prefix, input_output_alias,
@@ -40,11 +40,11 @@ cc_library(
     srcs = ["tf_mlir_opt_main.cc"],
     deps = [
         ":init_mlir",
-        ":passes",
        "//tensorflow/core:lib",
         "//tensorflow/core/platform:logging",
         "@llvm-project//llvm:support",
         "@llvm-project//mlir:AllPassesAndDialects",
+        "@llvm-project//mlir:IR",
         "@llvm-project//mlir:MlirOptLib",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:Support",
@@ -55,6 +55,7 @@ cc_library(
 cc_library(
     name = "passes",
     visibility = [
+        ":__subpackages__",
         "//tensorflow/python:__subpackages__",
     ],
     deps = [
@@ -76,24 +77,6 @@ cc_library(
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_test_passes",
         "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
         "//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
-        "//tensorflow/compiler/mlir/xla:buffer_assignment",
-        "//tensorflow/compiler/mlir/xla:hlo",
-        "//tensorflow/compiler/mlir/xla:hlo_legalize_to_lhlo",
-        "//tensorflow/compiler/mlir/xla:lhlo",
-        "//tensorflow/compiler/mlir/xla:lhlo_copy_removal",
-        "//tensorflow/compiler/mlir/xla:lhlo_fuse_linalg",
-        "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine",
-        "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_gpu",
-        "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_parallel_loops",
-        "//tensorflow/compiler/mlir/xla:xla_dialect_registration",
-        "//tensorflow/compiler/mlir/xla:xla_legalize_control_flow",
-        "//tensorflow/compiler/mlir/xla:xla_legalize_tf",
-        "//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
-        "//tensorflow/compiler/mlir/xla:xla_legalize_to_linalg",
-        "//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
-        "//tensorflow/compiler/mlir/xla:xla_lower",
-        "//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts",
-        "//tensorflow/compiler/mlir/xla:xla_test_passes",
     ],
 )
 
@@ -141,11 +124,14 @@ cc_library(
 tf_cc_binary(
     name = "tf-opt",
     deps = [
+        ":passes",
         ":tf_mlir_opt_main",
         "//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
         "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass_registration",
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
         "//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
+        "//tensorflow/compiler/mlir/tfjs:tensorflow_js_dialect_registration",
+        "//tensorflow/compiler/mlir/xla:all_xla_passes_for_testing",
     ],
 )
 
@@ -26,7 +26,7 @@ _ALWAYS_EXCLUDE = [
     "**/* */**",
 ]
 
-def _run_lit_test(name, data, size, tags, driver, features):
+def _run_lit_test(name, data, size, tags, driver, features, exec_properties):
     """Runs lit on all tests it can find in `data` under tensorflow/compiler/mlir.
 
     Note that, due to Bazel's hermetic builds, lit only sees the tests that
@@ -64,6 +64,7 @@ def _run_lit_test(name, data, size, tags, driver, features):
         ],
         size = size,
         main = "lit.py",
+        exec_properties = exec_properties,
     )
 
 def glob_lit_tests(
@@ -76,7 +77,8 @@ def glob_lit_tests(
         default_tags = _default_tags,
         tags_override = {},
         driver = _default_driver,
-        features = []):
+        features = [],
+        exec_properties = {}):
     """Creates all plausible Lit tests (and their inputs) under this directory.
 
     Args:
@@ -92,6 +94,7 @@ def glob_lit_tests(
         Note: use of a custom driver is not currently supported
         and specifying a default driver will abort the tests.
       features: [str], list of extra features to enable.
+      exec_properties: a dictionary of properties to pass on.
     """
 
     # Ignore some patterns by default for tests and input data.
@@ -115,6 +118,7 @@ def glob_lit_tests(
             tags = default_tags + tags_override.pop(curr_test, []),
             driver = driver,
             features = features,
+            exec_properties = exec_properties,
         )
 
 def lit_test(
@@ -123,7 +127,8 @@ def lit_test(
         size = _default_size,
         tags = _default_tags,
         driver = _default_driver,
-        features = []):
+        features = [],
+        exec_properties = {}):
     """Runs test files under lit.
 
     Args:
@@ -136,4 +141,4 @@ def lit_test(
       and specifying a default driver will abort the tests.
       features: [str], list of extra features to enable.
     """
-    _run_lit_test(name + ".test", data + [name], size, tags, driver, features)
+    _run_lit_test(name + ".test", data + [name], size, tags, driver, features, exec_properties)
@@ -512,7 +512,7 @@ cc_library(
     ],
     deps = [
         ":tensorflow_lite",
-        "//tensorflow/compiler/mlir/tensorflow",
+        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/core/platform:errors",
         "//tensorflow/core/platform:status",
@@ -562,19 +562,16 @@ cc_library(
 )
 
 cc_library(
-    name = "flatbuffer_translate_lib",
+    name = "flatbuffer_export",
     srcs = [
         "flatbuffer_export.cc",
-        "flatbuffer_import.cc",
-        "utils/convert_type.cc",
     ],
     hdrs = [
         "flatbuffer_export.h",
         "flatbuffer_export_flags.h",
-        "flatbuffer_import.h",
-        "utils/convert_type.h",
     ],
     deps = [
+        ":convert_type",
         ":flatbuffer_tflite_operator_lib",
         ":stateful_ops_utils",
         ":tensorflow_lite",
@@ -592,14 +589,12 @@ cc_library(
         "//tensorflow/core/platform:errors",
         "//tensorflow/core/platform:logging",
         "//tensorflow/core/platform:status",
-        "//tensorflow/lite:framework",
         "//tensorflow/lite:schema_fbs_version",
         "//tensorflow/lite:string_util",
         "//tensorflow/lite/delegates/flex:whitelisted_flex_ops_lib",
         "//tensorflow/lite/kernels/internal:kernel_utils",
         "//tensorflow/lite/schema:schema_fbs",
         "//tensorflow/lite/tools/versioning",
-        "@com_google_absl//absl/base",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
@@ -614,6 +609,78 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "flatbuffer_import",
+    srcs = [
+        "flatbuffer_import.cc",
+    ],
+    hdrs = [
+        "flatbuffer_import.h",
+    ],
+    deps = [
+        ":convert_type",
+        ":flatbuffer_tflite_operator_lib",
+        ":tensorflow_lite",
+        ":tensorflow_lite_dialect_registration",
+        "//tensorflow/compiler/mlir/tensorflow:mangling_util",
+        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/platform:errors",
+        "//tensorflow/core/platform:status",
+        "//tensorflow/lite:framework",
+        "//tensorflow/lite/schema:schema_fbs",
+        "@com_google_absl//absl/base",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/strings",
+        "@llvm-project//llvm:support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:QuantOps",
+        "@llvm-project//mlir:StandardOps",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:Translation",
+    ],
+)
+
+cc_library(
+    name = "convert_type",
+    srcs = [
+        "utils/convert_type.cc",
+    ],
+    hdrs = [
+        "utils/convert_type.h",
+    ],
+    deps = [
+        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/platform:errors",
+        "//tensorflow/lite/schema:schema_fbs",
+        "@llvm-project//mlir:IR",
+    ],
+)
+
+cc_library(
+    name = "flatbuffer_translate_lib",
+    hdrs = [
+        "flatbuffer_export.h",
+        "flatbuffer_export_flags.h",
+        "flatbuffer_import.h",
+        "utils/convert_type.h",
+    ],
+    deps = [
+        ":flatbuffer_export",
+        ":flatbuffer_import",
+        "//tensorflow/compiler/mlir:op_or_arg_name_mapper",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/lite/schema:schema_fbs",
+        "@com_google_absl//absl/strings",
+        "@llvm-project//mlir:IR",
+    ],
+)
+
 cc_library(
     name = "flatbuffer_translate_registeration",
     srcs = [
@@ -496,7 +496,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
     auto &value = op.getOperand(i);
     // Skip from from first variadic operands for now. Else getOperand index
     // used below doesn't match.
-    if (value.isVariadic()) break;
+    if (value.isVariableLength()) break;
     if (!value.name.empty())
       verify_ctx.addSubst(value.name, formatv("op->getOperand({0})", i));
   }
@@ -504,7 +504,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
     auto &value = op.getResult(i);
     // Skip from from first variadic results for now. Else getResult index
     // used below doesn't match.
-    if (value.isVariadic()) break;
+    if (value.isVariableLength()) break;
     if (!value.name.empty())
       verify_ctx.addSubst(value.name, formatv("op->getResult({0})", i));
   }
@@ -9,7 +9,9 @@ cc_library(
     name = "cost_estimators",
     textual_hdrs = [
         "estimator.h",
+        "cpu_estimators.h",
        "gpu_estimators.h",
         "hardware.h",
+        "arithmetic_count_util.h",
     ],
 )
@@ -0,0 +1,45 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ARITHMETIC_COUNT_UTIL_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ARITHMETIC_COUNT_UTIL_H_
+
+// For add/mul/div/sub and other broadcastable ops.
+class ArithmeticCountUtilHelper {
+ public:
+  static bool GetArithmeticCountForBroadcastableOp(mlir::Operation* op,
+                                                   int64_t* count) {
+    auto output = op->getResult(0);
+    auto output_type = output.getType().dyn_cast_or_null<RankedTensorType>();
+    if (!output_type || !output_type.hasStaticShape()) return false;
+
+    *count = output_type.getNumElements();
+    return true;
+  }
+
+  static bool GetInputTensorTotalSize(mlir::Operation* op, int64_t* count) {
+    int64_t total_count = 0;
+    for (auto input : op->getOperands()) {
+      auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
+      if (!input_type || !input_type.hasStaticShape()) {
+        return false;
+      }
+      total_count += input_type.getNumElements();
+    }
+    *count = total_count;
+    return true;
+  }
+};
+
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ARITHMETIC_COUNT_UTIL_H_
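The helper above only counts statically shaped tensors: broadcastable ops are charged per output element, copy-style ops per input element. Below is a minimal standalone sketch of that counting rule on plain shape vectors; the shapes and the NumElements helper are illustrative assumptions, not part of this commit.

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Product of dimensions, i.e. the number of elements of a static shape.
int64_t NumElements(const std::vector<int64_t>& shape) {
  return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

int main() {
  // Broadcastable binary op: the count is the number of output elements.
  const std::vector<int64_t> output_shape = {1, 224, 224, 3};
  std::cout << "arithmetic count: " << NumElements(output_shape) << "\n";  // 150528

  // Copy-style op (concat/pack/reshape): total size of all inputs.
  const std::vector<std::vector<int64_t>> inputs = {{1, 8, 8, 16}, {1, 8, 8, 16}};
  int64_t total = 0;
  for (const auto& s : inputs) total += NumElements(s);
  std::cout << "input tensor total size: " << total << "\n";  // 2048
  return 0;
}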
@@ -0,0 +1,103 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_CPU_ESTIMATORS_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_CPU_ESTIMATORS_H_
+
+// CPU
+constexpr float kCPUArithmeticUnitCost = 1.0;
+
+// This basically assumes pure load/store. This is just fake data.
+constexpr float kCPUCopyUnitCost = 0.5;
+constexpr float kCPUDefaultCost = 3.0f;
+
+// Default values.
+constexpr float kCPUDefaultFixedValuedCost = 10000.0;
+
+// tfl.add
+template <>
+class TFLiteCostEstimator<AddOp, hardware::CPU> {
+ public:
+  static double GetCost(mlir::Operation* op) {
+    int64_t count;
+    if (ArithmeticCountUtilHelper::GetArithmeticCountForBroadcastableOp(op,
+                                                                        &count))
+      return kCPUArithmeticUnitCost * count;
+    return kCPUDefaultFixedValuedCost;
+  }
+
+  static bool IsSupported(mlir::Operation* op) { return true; }
+};
+
+// tfl.mul
+template <>
+class TFLiteCostEstimator<MulOp, hardware::CPU> {
+ public:
+  static double GetCost(mlir::Operation* op) {
+    int64_t count;
+    if (ArithmeticCountUtilHelper::GetArithmeticCountForBroadcastableOp(op,
+                                                                        &count))
+      return kCPUArithmeticUnitCost * count;
+    return kCPUDefaultFixedValuedCost;
+  }
+
+  static bool IsSupported(mlir::Operation* op) { return true; }
+};
+
+// tfl.concatenation
+template <>
+class TFLiteCostEstimator<ConcatenationOp, hardware::CPU> {
+ public:
+  static double GetCost(mlir::Operation* op) {
+    int64_t count;
+    if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count))
+      return kCPUCopyUnitCost * count;
+    return kCPUDefaultFixedValuedCost;
+  }
+
+  // TODO(renjieliu): We probably need to check for dynamic weights.
+  static bool IsSupported(mlir::Operation* op) { return true; }
+};
+
+// tfl.pack
+template <>
+class TFLiteCostEstimator<PackOp, hardware::CPU> {
+ public:
+  static double GetCost(mlir::Operation* op) {
+    int64_t count;
+    if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count))
+      return kCPUCopyUnitCost * count;
+    return kCPUDefaultFixedValuedCost;
+  }
+
+  // TODO(renjieliu): We probably need to check for dynamic weights.
+  static bool IsSupported(mlir::Operation* op) { return true; }
+};
+
+// tfl.reshape
+template <>
+class TFLiteCostEstimator<ReshapeOp, hardware::CPU> {
+ public:
+  static double GetCost(mlir::Operation* op) {
+    int64_t count;
+    if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count))
+      return kCPUCopyUnitCost * count;
+    return kCPUDefaultFixedValuedCost;
+  }
+
+  static bool IsSupported(mlir::Operation* op) { return true; }
+};
+
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_CPU_ESTIMATORS_H_
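Each estimator above is a full specialization of TFLiteCostEstimator on an (op, hardware) pair, so the cost lookup is resolved at compile time. A minimal self-contained sketch of that dispatch pattern follows, with dummy tag types standing in for the real TFL ops and hardware targets (only the unit-cost constant is taken from this commit).

#include <iostream>

struct AddOpTag {};
struct CPU {};

// Primary template: unknown (op, hardware) pairs fall back to a large cost.
template <typename Op, typename Hardware>
class CostEstimator {
 public:
  static double GetCost(long long /*element_count*/) { return 10000.0; }
  static bool IsSupported() { return false; }
};

// Full specialization, mirroring TFLiteCostEstimator<AddOp, hardware::CPU>.
template <>
class CostEstimator<AddOpTag, CPU> {
 public:
  static double GetCost(long long element_count) {
    constexpr double kCPUArithmeticUnitCost = 1.0;  // value from this diff
    return kCPUArithmeticUnitCost * element_count;
  }
  static bool IsSupported() { return true; }
};

int main() {
  std::cout << CostEstimator<AddOpTag, CPU>::GetCost(150528) << "\n";
  return 0;
}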
@@ -16,6 +16,16 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_
 #define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_
 
+// GPU
+constexpr float kGPUArithmeticUnitCost = 0.2;
+
+// The copy can be non-consectutive copy. This is just fake data.
+constexpr float kGPUCopyUnitCost = 0.2;
+constexpr float kGPUDefaultCost = 1.0f;
+
+// Default values.
+constexpr float kGPUDefaultFixedValuedCost = 10000.0;
+
 // tfl.abs
 template <>
 class TFLiteCostEstimator<AbsOp, hardware::GPU> {
@@ -34,9 +44,11 @@ template <>
 class TFLiteCostEstimator<AddOp, hardware::GPU> {
  public:
   static double GetCost(mlir::Operation* op) {
-    llvm::errs() << "No defined cost function for op: "
-                 << op->getName().getStringRef().str();
-    return 0.0;
+    int64_t count;
+    if (ArithmeticCountUtilHelper::GetArithmeticCountForBroadcastableOp(op,
+                                                                        &count))
+      return kGPUArithmeticUnitCost * count;
+    return kGPUDefaultFixedValuedCost;
   }
 
   static bool IsSupported(mlir::Operation* op) { return true; }
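A quick back-of-the-envelope check of the cost model these constants define; the tensor sizes below are assumptions for illustration only, the constants come from the hunk above.

#include <cstdint>
#include <cstdio>

int main() {
  constexpr float kGPUArithmeticUnitCost = 0.2f;       // from this diff
  constexpr float kGPUCopyUnitCost = 0.2f;             // from this diff
  constexpr float kGPUDefaultFixedValuedCost = 10000.0f;

  // A statically shaped tfl.add producing a 1x224x224x3 tensor:
  const int64_t output_elements = 1 * 224 * 224 * 3;   // 150528
  std::printf("add cost:  %.1f\n", kGPUArithmeticUnitCost * output_elements);  // 30105.6

  // A copy-style op (e.g. concatenation) whose inputs total 10000 elements:
  std::printf("copy cost: %.1f\n", kGPUCopyUnitCost * 10000);

  // Ops without a static shape fall back to the fixed default cost.
  std::printf("fallback:  %.1f\n", kGPUDefaultFixedValuedCost);
  return 0;
}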
@@ -60,9 +72,10 @@ template <>
 class TFLiteCostEstimator<ConcatenationOp, hardware::GPU> {
  public:
   static double GetCost(mlir::Operation* op) {
-    llvm::errs() << "No defined cost function for op: "
-                 << op->getName().getStringRef().str();
-    return 0.0;
+    int64_t count;
+    if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count))
+      return kGPUCopyUnitCost * count;
+    return kGPUDefaultFixedValuedCost;
   }
 
   // TODO(renjieliu): We probably need to check for dynamic weights.
@@ -227,6 +240,33 @@ class TFLiteCostEstimator<MaximumOp, hardware::GPU> {
   static bool IsSupported(mlir::Operation* op) { return true; }
 };
 
+// tfl.max_unpooling_2d
+template <>
+class TFLiteCostEstimator<MaxUnpooling2DOp, hardware::GPU> {
+ public:
+  static double GetCost(mlir::Operation* op) {
+    llvm::errs() << "No defined cost function for op: "
+                 << op->getName().getStringRef().str();
+    return 0.0;
+  }
+
+  static bool IsSupported(mlir::Operation* op) { return true; }
+};
+
+// tfl.mean
+template <>
+class TFLiteCostEstimator<MeanOp, hardware::GPU> {
+ public:
+  static double GetCost(mlir::Operation* op) {
+    llvm::errs() << "No defined cost function for op: "
+                 << op->getName().getStringRef().str();
+    return 0.0;
+  }
+
+  // TODO(renjieiu): check for constraints.
+  static bool IsSupported(mlir::Operation* op) { return true; }
+};
+
 // tfl.minimum
 template <>
 class TFLiteCostEstimator<MinimumOp, hardware::GPU> {
@@ -245,9 +285,11 @@ template <>
 class TFLiteCostEstimator<MulOp, hardware::GPU> {
  public:
   static double GetCost(mlir::Operation* op) {
-    llvm::errs() << "No defined cost function for op: "
-                 << op->getName().getStringRef().str();
-    return 0.0;
+    int64_t count;
+    if (ArithmeticCountUtilHelper::GetArithmeticCountForBroadcastableOp(op,
+                                                                        &count))
+      return kGPUArithmeticUnitCost * count;
+    return kGPUDefaultFixedValuedCost;
   }
 
   static bool IsSupported(mlir::Operation* op) { return true; }
@@ -321,6 +363,33 @@ class TFLiteCostEstimator<Relu6Op, hardware::GPU> {
 // tfl.reshape
 template <>
 class TFLiteCostEstimator<ReshapeOp, hardware::GPU> {
+ public:
+  static double GetCost(mlir::Operation* op) {
+    int64_t count;
+    if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count))
+      return kGPUCopyUnitCost * count;
+    return kGPUDefaultFixedValuedCost;
+  }
+
+  static bool IsSupported(mlir::Operation* op) { return true; }
+};
+
+// tfl.rsqrt
+template <>
+class TFLiteCostEstimator<RsqrtOp, hardware::GPU> {
+ public:
+  static double GetCost(mlir::Operation* op) {
+    llvm::errs() << "No defined cost function for op: "
+                 << op->getName().getStringRef().str();
+    return 0.0;
+  }
+
+  static bool IsSupported(mlir::Operation* op) { return true; }
+};
+
+// tfl.sin
+template <>
+class TFLiteCostEstimator<SinOp, hardware::GPU> {
  public:
   static double GetCost(mlir::Operation* op) {
     llvm::errs() << "No defined cost function for op: "
@@ -357,6 +426,58 @@ class TFLiteCostEstimator<SoftmaxOp, hardware::GPU> {
   static bool IsSupported(mlir::Operation* op) { return true; }
 };
 
+// tfl.space_to_depth
+template <>
+class TFLiteCostEstimator<SpaceToDepthOp, hardware::GPU> {
+ public:
+  static double GetCost(mlir::Operation* op) {
+    llvm::errs() << "No defined cost function for op: "
+                 << op->getName().getStringRef().str();
+    return 0.0;
+  }
+
+  static bool IsSupported(mlir::Operation* op) { return true; }
+};
+
+// tfl.sqrt
+template <>
+class TFLiteCostEstimator<SqrtOp, hardware::GPU> {
+ public:
+  static double GetCost(mlir::Operation* op) {
+    llvm::errs() << "No defined cost function for op: "
+                 << op->getName().getStringRef().str();
+    return 0.0;
+  }
+
+  static bool IsSupported(mlir::Operation* op) { return true; }
+};
+
+// tfl.square
+template <>
+class TFLiteCostEstimator<SquareOp, hardware::GPU> {
+ public:
+  static double GetCost(mlir::Operation* op) {
+    llvm::errs() << "No defined cost function for op: "
+                 << op->getName().getStringRef().str();
+    return 0.0;
+  }
+
+  static bool IsSupported(mlir::Operation* op) { return true; }
+};
+
+// tfl.squared_difference
+template <>
+class TFLiteCostEstimator<SquaredDifferenceOp, hardware::GPU> {
+ public:
+  static double GetCost(mlir::Operation* op) {
+    llvm::errs() << "No defined cost function for op: "
+                 << op->getName().getStringRef().str();
+    return 0.0;
+  }
+
+  static bool IsSupported(mlir::Operation* op) { return true; }
+};
+
 // tfl.strided_slice
 template <>
 class TFLiteCostEstimator<StridedSliceOp, hardware::GPU> {
@@ -370,6 +491,19 @@ class TFLiteCostEstimator<StridedSliceOp, hardware::GPU> {
   static bool IsSupported(mlir::Operation* op) { return true; }
 };
 
+// tfl.tanh
+template <>
+class TFLiteCostEstimator<TanhOp, hardware::GPU> {
+ public:
+  static double GetCost(mlir::Operation* op) {
+    llvm::errs() << "No defined cost function for op: "
+                 << op->getName().getStringRef().str();
+    return 0.0;
+  }
+
+  static bool IsSupported(mlir::Operation* op) { return true; }
+};
+
 // tfl.transpose
 template <>
 class TFLiteCostEstimator<TransposeOp, hardware::GPU> {
@@ -383,5 +517,18 @@ class TFLiteCostEstimator<TransposeOp, hardware::GPU> {
   static bool IsSupported(mlir::Operation* op) { return true; }
 };
 
+// tfl.transpose_conv
+template <>
+class TFLiteCostEstimator<TransposeConvOp, hardware::GPU> {
+ public:
+  static double GetCost(mlir::Operation* op) {
+    llvm::errs() << "No defined cost function for op: "
+                 << op->getName().getStringRef().str();
+    return 0.0;
+  }
+
+  static bool IsSupported(mlir::Operation* op) { return true; }
+};
+
 #endif  // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_
 
@@ -59,13 +59,11 @@ limitations under the License.
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
-#include "mlir/Support/Functional.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Translation.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
-#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -676,8 +674,8 @@ template <typename ContainerType>
 mlir::NamedAttribute BuildTFEntryFunctionAttribute(
     const tflite::SubGraphT& subgraph, Builder* builder, const std::string name,
     const ContainerType indices) {
-  llvm::SmallVector<std::string, 8> tensor_names = mlir::functional::map(
-      [&](int i) { return subgraph.tensors.at(i)->name; }, indices);
+  auto tensor_names = llvm::map_range(
+      indices, [&](int i) { return subgraph.tensors.at(i)->name; });
   return builder->getNamedAttr(
       name, builder->getStringAttr(llvm::join(tensor_names, ",")));
 }
@@ -25,7 +25,7 @@ limitations under the License.
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
-#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/lite/kernels/internal/kernel_utils.h"
@@ -28,7 +28,6 @@ limitations under the License.
 #include "mlir/Interfaces/DerivedAttributeOpInterface.h"  // from @llvm-project
 #include "mlir/Interfaces/LoopLikeInterface.h"  // from @llvm-project
 #include "mlir/Interfaces/SideEffects.h"  // from @llvm-project
-#include "mlir/Support/Functional.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
 #include "tensorflow/lite/schema/schema_generated.h"
@@ -54,6 +53,8 @@ class TensorFlowLiteDialect : public Dialect {
 #define GET_OP_CLASSES
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc"
 // Include all specializes estimators below this line
+#include "tensorflow/compiler/mlir/lite/experimental/estimators/arithmetic_count_util.h"
+#include "tensorflow/compiler/mlir/lite/experimental/estimators/cpu_estimators.h"
 #include "tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h"
 
 }  // end namespace TFL
@@ -450,7 +450,7 @@ retained with length 1.
 }
 
 def TFL_TransposeConvOp:
-  TFL_Op<"transpose_conv", [NoSideEffect]> {
+  TFL_Op<"transpose_conv", [NoSideEffect, TFL_GpuTargetOp]> {
   let summary = "Transpose convolution operator";
 
   let description = [{
@@ -1658,7 +1658,7 @@ def TFL_MaxPoolingWithArgMax2DOp :
 }
 
 def TFL_MaxUnpooling2DOp :
-  Op<TFL_Dialect, "max_unpooling_2d", [NoSideEffect]> {
+  Op<TFL_Dialect, "max_unpooling_2d", [NoSideEffect, TFL_GpuTargetOp]> {
   let summary = "Max Unpool 2D";
 
   let description = [{
@@ -1711,7 +1711,7 @@ def TFL_MaximumOp : TFL_Op<"maximum", [
   let hasOptions = 0;
 }
 
-def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect]> {
+def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect, TFL_GpuTargetOp]> {
   let summary = "Mean operator";
 
   let description = [{
@@ -2116,7 +2116,9 @@ def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape,
   let builders = [TFL_BroadcastableBinaryBuilder];
 }
 
-def TFL_PReluOp : TFL_Op<"prelu", [NoSideEffect, TFL_GpuTargetOp]> {
+def TFL_PReluOp : TFL_Op<"prelu", [NoSideEffect,
+                                   TFL_GpuTargetOp,
+                                   SameOperandsAndResultsScale]> {
   let summary = "Parameterized Relu operator";
 
   let description = [{
@@ -2165,6 +2167,17 @@ def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect,
   let arguments = (ins TFL_TensorOf<[F32, QUI8, I8]>:$x);
 
   let results = (outs TFL_TensorOf<[F32, QUI8, I8]>:$y);
+
+  // This builder doesn't work with quantized type, so it can only be used by
+  // non-quantization tablegen patterns. Currently, it is used by the
+  // elementwise-move reordering pattern in the optimize_patterns.td
+  let builders = [OpBuilder<
+    "Builder *, OperationState &state, Value input",
+    [{
+      state.addOperands({input});
+      state.addTypes(input.getType());
+    }]>
+  ];
 }
 
 def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect,
@@ -2181,6 +2194,17 @@ def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect,
   let arguments = (ins TFL_TensorOf<[F32, QUI8, I8]>:$x);
 
   let results = (outs TFL_TensorOf<[F32, QUI8, I8]>:$y);
+
+  // This builder doesn't work with quantized type, so it can only be used by
+  // non-quantization tablegen patterns. Currently, it is used by the
+  // elementwise-move reordering pattern in the optimize_patterns.td
+  let builders = [OpBuilder<
+    "Builder *, OperationState &state, Value input",
+    [{
+      state.addOperands({input});
+      state.addTypes(input.getType());
+    }]>
+  ];
 }
 
 def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [NoSideEffect,
@@ -2196,6 +2220,17 @@ def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [NoSideEffect,
   let arguments = (ins TFL_TensorOf<[F32, QUI8, I8]>:$x);
 
   let results = (outs TFL_TensorOf<[F32, QUI8, I8]>:$y);
+
+  // This builder doesn't work with quantized type, so it can only be used by
+  // non-quantization tablegen patterns. Currently, it is used by the
+  // elementwise-move reordering pattern in the optimize_patterns.td
+  let builders = [OpBuilder<
+    "Builder *, OperationState &state, Value input",
+    [{
+      state.addOperands({input});
+      state.addTypes(input.getType());
+    }]>
+  ];
 }
 
 def TFL_ReshapeOp: TFL_Op<"reshape", [
@@ -2247,7 +2282,10 @@ slice `i`, with the first `seq_lengths[i]` slices along dimension
   let hasOptions = 1;
 }
 
-def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
+def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect,
+                                  SameOperandsAndResultType,
+                                  NoQuantizableResult,
+                                  TFL_GpuTargetOp]> {
   let summary = "Reciprocal of square root operator";
 
   let description = [{
@@ -2395,7 +2433,10 @@ def TFL_SelectV2Op : TFL_Op<"select_v2", [NoSideEffect]> {
 }
 
 def TFL_SinOp: TFL_Op<"sin", [
-    NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
+    NoSideEffect,
+    SameOperandsAndResultType,
+    NoQuantizableResult,
+    TFL_GpuTargetOp]> {
   let summary = "Sine operator";
 
   let description = [{
@@ -2437,7 +2478,10 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [
 }
 
 def TFL_SqrtOp: TFL_Op<"sqrt", [
-    NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
+    NoSideEffect,
+    SameOperandsAndResultType,
+    NoQuantizableResult,
+    TFL_GpuTargetOp]> {
   let summary = "Square root operator";
 
   let description = [{
@@ -2452,7 +2496,10 @@ def TFL_SqrtOp: TFL_Op<"sqrt", [
 }
 
 def TFL_SquareOp: TFL_Op<"square", [
-    NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
+    NoSideEffect,
+    SameOperandsAndResultType,
+    NoQuantizableResult,
+    TFL_GpuTargetOp]> {
   let summary = "Square operator";
 
   let description = [{
@@ -2496,7 +2543,10 @@ def TFL_SubOp : TFL_Op<"sub", [ResultsBroadcastableShape, NoSideEffect]> {
 // TODO(jpienaar): Expand the kernel implementation to support all types besides
 // I32 and F32.
 def TFL_SquaredDifferenceOp : TFL_Op<"squared_difference", [
-    ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
+    ResultsBroadcastableShape,
+    NoSideEffect,
+    NoQuantizableResult,
+    TFL_GpuTargetOp]> {
   let summary = "Squared difference operator";
 
   let description = [{
@@ -2523,7 +2573,8 @@ def TFL_TanhOp: TFL_Op<"tanh", [
     // zero_point = central_value
     // scale = 1. / (central_value - min_value)
     FixedResultScale<Int8UniformQuantizedType<0, 78125, -7>>,
-    FixedResultScale<UInt8UniformQuantizedType<128, 78125, -7>>]> {
+    FixedResultScale<UInt8UniformQuantizedType<128, 78125, -7>>,
+    TFL_GpuTargetOp]> {
   let summary = "Hyperbolic tangent operator";
 
   let description = [{
@@ -2533,6 +2584,17 @@ def TFL_TanhOp: TFL_Op<"tanh", [
   let arguments = (ins TFL_TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$x);
 
   let results = (outs TFL_TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$y);
+
+  // This builder doesn't work with quantized type, so it can only be used by
+  // non-quantization tablegen patterns. Currently, it is used by the
+  // elementwise-move reordering pattern in the optimize_patterns.td
+  let builders = [OpBuilder<
+    "Builder *, OperationState &state, Value input",
+    [{
+      state.addOperands({input});
+      state.addTypes(input.getType());
+    }]>
+  ];
 }
 
 def TFL_TileOp: TFL_Op<"tile", [NoSideEffect, SameOperandsAndResultsScale,
@@ -2718,7 +2780,8 @@ def TFL_SpaceToDepthOp: TFL_Op<"space_to_depth", [
     NoSideEffect,
     SameOperandsAndResultsScale,
     PredOpTrait<"input and output must have same element type",
-                TCresVTEtIsSameAsOp<0, 0>>
+                TCresVTEtIsSameAsOp<0, 0>>,
+    TFL_GpuTargetOp
   ]> {
   let summary = "SpaceToDepth operator";
 
@@ -2981,14 +3044,13 @@ def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [
   }];
 
   let arguments = (ins
-    // TODO: add uint8 support when ready.
-    TFL_TensorOf<[F32, I32, I64]>:$input,
+    TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8]>:$input,
     TFL_TensorOf<[I32, I64]>:$pad,
     TFL_MirrorPaddingAttr:$mode
   );
 
   let results = (outs
-    TFL_TensorOf<[F32, I32, I64]>:$output
+    TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8]>:$output
   );
 
   let hasOptions = 1;
@@ -1,4 +1,4 @@
-load("//tensorflow:tensorflow.bzl", "tf_native_cc_binary")
+load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_native_cc_binary")
 load(
     "//tensorflow/core/platform:build_config.bzl",
     "tf_proto_library",
@@ -115,11 +115,22 @@ tf_native_cc_binary(
     ],
 )
 
+cc_library(
+    name = "numerical_utils",
+    srcs = ["numerical_utils.cc"],
+    hdrs = ["numerical_utils.h"],
+    deps = [
+        "@com_google_absl//absl/types:optional",
+    ],
+)
+
 cc_library(
     name = "device_target",
     srcs = ["device_target.cc"],
     hdrs = ["device_target.h"],
     deps = [
+        ":numerical_utils",
+        "@com_google_absl//absl/types:optional",
         "@llvm-project//llvm:support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:QuantOps",
@@ -142,3 +153,13 @@ cc_library(
         "@llvm-project//mlir:Support",
     ],
 )
+
+tf_cc_test(
+    name = "numerical_utils_test",
+    srcs = ["numerical_utils_test.cc"],
+    deps = [
+        ":numerical_utils",
+        "@com_google_absl//absl/types:optional",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
@@ -15,12 +15,18 @@ limitations under the License.
 
 #include "tensorflow/compiler/mlir/lite/quantization/device_target.h"
 
+#include <algorithm>
+
+#include "absl/types/optional.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
+#include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/quantization/numerical_utils.h"
 
 namespace mlir {
 namespace quant {
@@ -39,7 +45,7 @@ DeviceTarget::DeviceTarget(MLIRContext* ctx) : ctx_(ctx) {
   assert(qi8n_ == qi8n_);
 }
 
-Optional<KernelSpec> DeviceTarget::Get(QuantizeRegionOp op) const {
+Optional<KernelSpec> DeviceTarget::GetKernelSpec(QuantizeRegionOp op) const {
   auto kernel_specs_it = specs_.find(op.logical_kernel());
   if (kernel_specs_it == specs_.end()) return llvm::None;
 
@@ -50,9 +56,15 @@ Optional<KernelSpec> DeviceTarget::Get(QuantizeRegionOp op) const {
   return kernel_specs_it->getValue().Find(signature);
 }
 
+ScaleDecomposeFn DeviceTarget::GetDecomposeFn(QuantizeRegionOp op) const {
+  auto kernel_specs_it = specs_.find(op.logical_kernel());
+  if (kernel_specs_it == specs_.end()) return ScaleDecomposeFn(nullptr);
+  return kernel_specs_it->second.GetDecomposeFn();
+}
+
 LogicalResult DeviceTarget::RegisterKernel(
     llvm::StringRef kernel, const KernelSpecs::Signature& signature,
-    const ScaleFn& fn) {
+    const ScaleFn& fn, const ScaleDecomposeFn& dfn) {
   return specs_[kernel].Add(signature, {ScaleConstraintType::CustomScale, fn});
 }
 
@@ -78,5 +90,49 @@ void DeviceTarget::AppendToSignature(ArrayAttr specs_attr,
   }
 }
 
+LogicalResult DeviceTarget::DecomposeMultiplyAccumulateScale(
+    Operation* op, quant::QuantizedMultipliers* input_multipliers,
+    quant::QuantizedMultipliers* output_multipliers,
+    quant::QuantizedRanges* output_ranges) {
+  auto rop = llvm::dyn_cast<quant::QuantizeRegionOp>(op);
+  if (!rop) return failure();
+
+  llvm::SmallVector<Type, 4> input_specs, out_specs;
+  for (auto spec : rop.input_specs()) {
+    input_specs.push_back(spec.cast<TypeAttr>().getValue());
+  }
+  for (auto spec : rop.output_specs()) {
+    out_specs.push_back(spec.cast<TypeAttr>().getValue());
+  }
+
+  auto in_spec = input_specs[0].dyn_cast<quant::UniformQuantizedType>();
+  // TODO(fengliuai): handles the PerAxis QuantizedType.
+  auto w_spec = input_specs[1].dyn_cast<quant::UniformQuantizedType>();
+  auto b_spec = input_specs[2].dyn_cast<quant::UniformQuantizedType>();
+  auto o_spec = out_specs[0].dyn_cast<quant::UniformQuantizedType>();
+  if (!in_spec || !w_spec || !b_spec || !o_spec) return failure();
+
+  double scale_product = in_spec.getScale() * w_spec.getScale();
+  if (fabs(scale_product - b_spec.getScale()) < 1e-6) return failure();
+
+  // input multipliers
+  input_multipliers->append(3, kUnitQuantizedMultiplier);
+
+  // output multipliers
+  double real_multiplier = o_spec.getScale() / scale_product;
+  output_multipliers->push_back(quant::QuantizeMultiplier(real_multiplier));
+
+  // output ranges
+  auto min = rop.getAttrOfType<FloatAttr>("min");
+  auto max = rop.getAttrOfType<FloatAttr>("max");
+  output_ranges->push_back(quant::CalculateQuantizedRange(
+      o_spec.getScale(), o_spec.getZeroPoint(),
+      (min ? absl::optional<double>(min.getValueAsDouble()) : absl::nullopt),
+      (max ? absl::optional<double>(max.getValueAsDouble()) : absl::nullopt),
+      o_spec.getStorageTypeMin(), o_spec.getStorageTypeMax()));
+
+  return success();
+}
+
 }  // namespace quant
 }  // namespace mlir
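A hedged numeric illustration of the decomposition added above: with made-up uniform scales for the input, weight and output ports, the effective output multiplier follows the same expression used in DecomposeMultiplyAccumulateScale (the scale values here are illustrative only).

#include <cstdio>

int main() {
  // Example uniform-quantization scales (assumptions for illustration).
  const double input_scale = 0.5;
  const double weight_scale = 0.25;
  const double output_scale = 0.05;

  // The bias port is expected to be quantized with input_scale * weight_scale.
  const double scale_product = input_scale * weight_scale;      // 0.125

  // Mirrors "o_spec.getScale() / scale_product" in the function above.
  const double real_multiplier = output_scale / scale_product;  // 0.4
  std::printf("real multiplier = %f\n", real_multiplier);
  return 0;
}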
@@ -17,13 +17,13 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_DEVICE_TARGET_H_
 
 #include <functional>
-#include <ostream>
 
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringRef.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
@@ -33,6 +33,7 @@ limitations under the License.
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/quantization/numerical_utils.h"
 
 namespace mlir {
 namespace quant {
@@ -40,9 +41,17 @@ namespace quant {
 class QuantizeContext;
 
 using AdjacentOperations = llvm::SmallVectorImpl<Operation*>;
+using QuantizedMultipliers = llvm::SmallVector<QuantizedMultiplier, 4>;
+using QuantizedRanges = llvm::SmallVector<QuantizedRange, 4>;
 using ScaleFn = std::function<LogicalResult(QuantizeContext*, Operation*,
                                             AdjacentOperations*, bool*)>;
 
+using ScaleDecomposeFn =
+    std::function<LogicalResult(Operation*, QuantizedMultipliers*,
+                                QuantizedMultipliers*, QuantizedRanges*)>;
+
+static const QuantizedMultiplier kUnitQuantizedMultiplier{1, 0};
+
 enum class ScaleConstraintType {
   OutputInputSameScale,
   OutputInputFreeScale,
@@ -73,12 +82,25 @@ class KernelSpecs {
     }
   }
 
+  ScaleDecomposeFn GetDecomposeFn() const { return decompose_fn_; }
+
   // Adds the kernel signature with the kernel specification.
   LogicalResult Add(const Signature& signature, const KernelSpec& spec) {
     if (all_signatures_.insert({signature, spec}).second) return success();
     return failure();
   }
 
+  KernelSpecs& WithSignature(const KernelSpecs::Signature& signature,
+                             const ScaleFn& fn) {
+    Add(signature, {ScaleConstraintType::CustomScale, fn});
+    return *this;
+  }
+
+  KernelSpecs& WithImpl(const ScaleDecomposeFn& dfn) {
+    decompose_fn_ = dfn;
+    return *this;
+  }
+
  private:
   // The signature is pattern match based.
   struct SignatureInfo : public llvm::DenseMapInfo<Signature> {
@@ -101,6 +123,10 @@ class KernelSpecs {
   // Maps the signature to the kernel spec. Note that the matching is
   // pattern match based.
   llvm::DenseMap<Signature, KernelSpec, SignatureInfo> all_signatures_;
+
+  // A method to compute the effective multipliers. This is independent on the
+  // bits of the ports, thus all the signature shares the same here.
+  ScaleDecomposeFn decompose_fn_;
 };
 
 class DeviceTarget {
@@ -108,19 +134,26 @@ class DeviceTarget {
   explicit DeviceTarget(MLIRContext* ctx);
 
   // Retrieves the kernel spec for the quant region op.
-  Optional<KernelSpec> Get(quant::QuantizeRegionOp op) const;
+  Optional<KernelSpec> GetKernelSpec(quant::QuantizeRegionOp op) const;
+
+  // Retrieves the scale decomposition function for the quant region op.
+  ScaleDecomposeFn GetDecomposeFn(quant::QuantizeRegionOp op) const;
 
  protected:
   // Adds the kernel spec with the custom scale function for the kernel.
   LogicalResult RegisterKernel(llvm::StringRef kernel,
                                const KernelSpecs::Signature& signature,
-                               const ScaleFn& fn);
+                               const ScaleFn& fn, const ScaleDecomposeFn& dfn);
 
   // Adds the kernel spec with the scale constraint type for the kernel.
   LogicalResult RegisterKernel(llvm::StringRef kernel,
                                const KernelSpecs::Signature& signature,
                                const ScaleConstraintType constraint);
 
+  // Adds the kernel with the name. Retrun an existing one if it has been
+  // added before.
+  KernelSpecs& RegisterKernel(llvm::StringRef kernel) { return specs_[kernel]; }
+
   // converts specification to signature:
   //  - UniformedQuantizedType -> AnyQuantizedType
   //  - AnyQuantizedType (int) -> AnyQuantizedType
@@ -128,6 +161,13 @@ class DeviceTarget {
   void AppendToSignature(ArrayAttr specs_attr,
                          KernelSpecs::Signature* signature) const;
 
+  // For "mulmat->add" type of kernels, convert the scales of all the ports to
+  // multipliers.
+  static LogicalResult DecomposeMultiplyAccumulateScale(
+      Operation* op, quant::QuantizedMultipliers* input_multipliers,
+      quant::QuantizedMultipliers* output_multipliers,
+      quant::QuantizedRanges* output_ranges);
+
   // A set of parameters are required to build the signatures.
   FloatType f32_;
   IntegerType i8_;
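The new WithSignature/WithImpl helpers return *this, so together with the KernelSpecs&-returning RegisterKernel(name) overload a device target can register a kernel in one chained statement. A standalone sketch of that fluent style with dummy types follows; the kernel name, signature string and std::function signatures are placeholders, not taken from this commit.

#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <utility>

// Dummy stand-in for KernelSpecs, just to show the chaining enabled by
// returning *this from the With* helpers.
struct Specs {
  Specs& WithSignature(const std::string& signature, std::function<void()> fn) {
    scale_fns.emplace(signature, std::move(fn));
    return *this;
  }
  Specs& WithImpl(std::function<void()> decompose) {
    decompose_fn = std::move(decompose);
    return *this;
  }
  std::map<std::string, std::function<void()>> scale_fns;
  std::function<void()> decompose_fn;
};

int main() {
  std::map<std::string, Specs> registry;
  // One kernel, one per-signature scale rule, one shared decompose step,
  // all in a single fluent statement.
  registry["mul_add"]
      .WithSignature("i8,i8,i32->i8", [] { /* custom scale rule */ })
      .WithImpl([] { /* shared decompose step */ });
  std::cout << "registered kernels: " << registry.size() << "\n";
  return 0;
}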
@@ -33,7 +33,6 @@ limitations under the License.
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
-#include "mlir/Support/Functional.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/quantization/quantization_info.pb.h"
 #include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h"
@@ -0,0 +1,82 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/mlir/lite/quantization/numerical_utils.h"
+
+#include <assert.h>
+
+#include <algorithm>
+#include <cmath>
+#include <limits>
+
+#include "absl/types/optional.h"
+
+namespace mlir {
+namespace quant {
+
+// This method is adopted from TFLite:
+// ["tensorflow/lite/kernels/internal/quantization_util.cc"]
+QuantizedMultiplier QuantizeMultiplier(double double_multiplier) {
+  if (double_multiplier < 1e-6) {
+    return {0, 0};
+  }
+
+  int32_t shift;
+  const double q = frexp(double_multiplier, &shift);
+  auto q_fixed = static_cast<int64_t>(round(q * (1ll << 31)));
+  assert(q_fixed <= (1ll << 31));
+  if (q_fixed == (1ll << 31)) {
+    q_fixed /= 2;
+    ++shift;
+  }
+  assert(q_fixed <= std::numeric_limits<int32_t>::max());
+  // A shift amount smaller than -31 would cause all bits to be shifted out
+  // and thus all results would be zero. We implement that instead with
+  // q_fixed==0, so as to avoid hitting issues with right-shift
+  // operations with shift amounts greater than 31. Note that this happens
+  // roughly when abs(double_multiplier) < 2^-31 and the present handling means
+  // that we're effectively flushing tiny double_multiplier's to zero.
+  // We could conceivably handle values in the range (roughly) [32, 63]
+  // as 'denormals' i.e. (shift==0, q_fixed < 2^30). In that point of view
+  // the present handling is just doing 'flush denormals to zero'. We could
+  // reconsider and actually generate nonzero denormals if a need arises.
+  if (shift < -31) {
+    shift = 0;
+    q_fixed = 0;
+  }
+  return {static_cast<int32_t>(q_fixed), shift};
+}
+
+QuantizedRange CalculateQuantizedRange(double scale, int32_t zero_point,
+                                       absl::optional<double> rmin,
+                                       absl::optional<double> rmax,
+                                       int32_t qmin, int32_t qmax) {
+  auto quantize = [scale, zero_point](float f) {
+    return zero_point + static_cast<int32_t>(std::round(f / scale));
+  };
+
+  if (rmin.has_value() && rmax.has_value()) {
+    return {std::max(qmin, quantize(rmin.value())),
+            std::min(qmax, quantize(rmax.value()))};
+  } else if (rmin.has_value()) {
+    return {std::max(qmin, quantize(rmin.value())), qmax};
+  } else if (rmax.has_value()) {
+    return {qmin, std::min(qmax, quantize(rmax.value()))};
+  } else {
+    return {qmin, qmax};
+  }
+}
+
+}  // namespace quant
+}  // namespace mlir
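A hedged usage sketch of the two helpers defined above, assuming it is compiled inside the TensorFlow tree against the header added by this commit; the numeric comments show the exact decomposition for 0.5 and the clamped range for scale 0.5 with zero point 0.

#include <cstdint>
#include <cstdio>
#include "absl/types/optional.h"
#include "tensorflow/compiler/mlir/lite/quantization/numerical_utils.h"

int main() {
  // 0.5 decomposes exactly: frexp(0.5) = 0.5 * 2^0, so the fixed-point
  // multiplier is round(0.5 * 2^31) = 1073741824 with exponent 0, and
  // 1073741824 * 2^(0 - 31) reconstructs 0.5.
  auto m = mlir::quant::QuantizeMultiplier(0.5);
  std::printf("multiplier=%d exponent=%d\n", m.first, m.second);

  // With scale 0.5 and zero point 0, a real lower bound of -1.0 quantizes to
  // -2, so the effective range is clamped from [-128, 127] to [-2, 127].
  auto r = mlir::quant::CalculateQuantizedRange(
      /*scale=*/0.5, /*zero_point=*/0,
      /*rmin=*/absl::optional<double>(-1.0), /*rmax=*/absl::nullopt,
      /*qmin=*/-128, /*qmax=*/127);
  std::printf("range=[%d, %d]\n", r.first, r.second);
  return 0;
}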
45
tensorflow/compiler/mlir/lite/quantization/numerical_utils.h
Normal file
45
tensorflow/compiler/mlir/lite/quantization/numerical_utils.h
Normal file
@ -0,0 +1,45 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_NUMERICAL_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_NUMERICAL_UTILS_H_

#include <cstdint>
#include <utility>

#include "absl/types/optional.h"

namespace mlir {
namespace quant {

using QuantizedMultiplier = std::pair<int32_t, int32_t>;
using QuantizedRange = std::pair<int32_t, int32_t>;

// Decomposes a double-precision multiplier into an integer multiplier and an
// exponent: double_multiplier = int_multiplier * 2 ^ (-31 + exponent).
// int_multiplier will be in the range [2^30, 2^31).
QuantizedMultiplier QuantizeMultiplier(double double_multiplier);

// Calculates the effective quantized value range for the given scale and zero
// point. The range is the intersection of [rmin, rmax] (mapped to the
// quantized domain) and [qmin, qmax].
QuantizedRange CalculateQuantizedRange(double scale, int32_t zero_point,
                                       absl::optional<double> rmin,
                                       absl::optional<double> rmax,
                                       int32_t qmin, int32_t qmax);

}  // namespace quant
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_NUMERICAL_UTILS_H_
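As a usage illustration (hypothetical, not part of the commit; the scale and range values are made up), the declaration above can clamp a ReLU6-style real range [0, 6] into an int8 window:

#include <cstdint>
#include <cstdio>

#include "absl/types/optional.h"
#include "tensorflow/compiler/mlir/lite/quantization/numerical_utils.h"

int main() {
  // With scale = 6/255 and zero_point = -128, the real value 0.0 quantizes to
  // -128 and 6.0 quantizes to 127, so the effective range stays [-128, 127].
  const double scale = 6.0 / 255.0;
  const int32_t zero_point = -128;
  auto range = mlir::quant::CalculateQuantizedRange(
      scale, zero_point, /*rmin=*/0.0, /*rmax=*/6.0, /*qmin=*/-128,
      /*qmax=*/127);
  std::printf("[%d, %d]\n", static_cast<int>(range.first),
              static_cast<int>(range.second));
  return 0;
}

A tighter rmin/rmax would instead shrink the bounds inside [-128, 127], which is exactly what the ActivationRange unit test below exercises with different zero points.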
@@ -0,0 +1,114 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/quantization/numerical_utils.h"

#include <cmath>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/types/optional.h"

namespace mlir {
namespace quant {

namespace {

double ComposeScale(const QuantizedMultiplier& input) {
  return input.first * exp2(-31 + input.second);
}

TEST(NumericalUtils, QuantizeMultiplier) {
  // Decompose multiplier larger than 1.
  ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e6)), 1.0e6);
  ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e3)), 1.0e3);
  ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(10.)), 10.);
  ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(5.)), 5.);
  ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(2.)), 2.);

  // Decompose multiplier between 1.0 and 1e-6.
  ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(0.0)), 0.0);
  ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0)), 1.0);
  ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-1)), 1.0e-1);
  ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-2)), 1.0e-2);
  ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-3)), 1.0e-3);
  ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-4)), 1.0e-4);
  ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-5)), 1.0e-5);
  ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-6)), 1.0e-6);

  // When scale is smaller than 1.0e-6, it is decomposed to {0, 0}.
  ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-7)), 0.0);
  ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-8)), 0.0);
}

TEST(NumericalUtils, ActivationRange) {
  // zero point = 0
  auto a =
      CalculateQuantizedRange(1e-6, 0, absl::nullopt, absl::nullopt, -128, 127);
  ASSERT_EQ(a.first, -128);
  ASSERT_EQ(a.second, 127);

  auto b = CalculateQuantizedRange(1e-6, 0, 0.0, absl::nullopt, -128, 127);
  ASSERT_EQ(b.first, 0);
  ASSERT_EQ(b.second, 127);

  auto c = CalculateQuantizedRange(1e-6, 0, -1.0, 1.0, -128, 127);
  ASSERT_EQ(c.first, -128);
  ASSERT_EQ(c.second, 127);

  auto d = CalculateQuantizedRange(1e-6, 0, 0.0, 6.0, -128, 127);
  ASSERT_EQ(d.first, 0);
  ASSERT_EQ(d.second, 127);

  // zero point = 100
  auto e = CalculateQuantizedRange(1e-6, 100, absl::nullopt, absl::nullopt,
                                   -128, 127);
  ASSERT_EQ(e.first, -128);
  ASSERT_EQ(e.second, 127);

  auto f = CalculateQuantizedRange(1e-6, 100, 0.0, absl::nullopt, -128, 127);
  ASSERT_EQ(f.first, 100);
  ASSERT_EQ(f.second, 127);

  auto g = CalculateQuantizedRange(1e-6, 100, -1.0, 1.0, -128, 127);
  ASSERT_EQ(g.first, -128);
  ASSERT_EQ(g.second, 127);

  auto h = CalculateQuantizedRange(1e-6, 100, 0.0, 6.0, -128, 127);
  ASSERT_EQ(h.first, 100);
  ASSERT_EQ(h.second, 127);

  // zero point = -100
  auto i = CalculateQuantizedRange(1e-6, -100, absl::nullopt, absl::nullopt,
                                   -128, 127);
  ASSERT_EQ(i.first, -128);
  ASSERT_EQ(i.second, 127);

  auto j = CalculateQuantizedRange(1e-6, -100, 0.0, absl::nullopt, -128, 127);
  ASSERT_EQ(j.first, -100);
  ASSERT_EQ(j.second, 127);

  auto k = CalculateQuantizedRange(1e-6, -100, -1.0, 1.0, -128, 127);
  ASSERT_EQ(k.first, -128);
  ASSERT_EQ(k.second, 127);

  auto l = CalculateQuantizedRange(1e-6, -100, 0.0, 6.0, -128, 127);
  ASSERT_EQ(l.first, -100);
  ASSERT_EQ(l.second, 127);
}

}  // namespace
}  // namespace quant
}  // namespace mlir
@@ -67,7 +67,7 @@ std::vector<quant::QuantizeRegionOp> QuantizeContext::GetAllOps() {
 LogicalResult QuantizeContext::Handle(
     quant::QuantizeRegionOp op, llvm::SmallVectorImpl<Operation *> *new_items,
     bool *changed) {
-  auto spec = target_spec_.Get(op);
+  auto spec = target_spec_.GetKernelSpec(op);
   if (!spec.hasValue()) {
     op.emitWarning(
         "Couldn't find kernel from the registeration for quantization.");
@@ -146,7 +146,7 @@ void LegalizeTFToQuant::runOnFunction() {
   auto func = getFunction();
   auto *ctx = func.getContext();
   patterns.insert<PreparePerTensorFakeQuant, PreparePerChannelFakeQuant>(ctx);
-  applyPatternsGreedily(func, patterns);
+  applyPatternsAndFoldGreedily(func, patterns);
 }
 }  // namespace

@@ -23,10 +23,10 @@ filegroup(
     data = [
         ":importer_test_legacy_reshape",
         ":importer_test_min_max",
+        ":test_schema.fbs",
         "//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
         "//tensorflow/compiler/mlir/lite:flatbuffer_translate",
         "//tensorflow/compiler/mlir/lite:json_to_flatbuffer",
-        "//tensorflow/lite/schema:schema.fbs",
         "@llvm-project//llvm:FileCheck",
     ],
 )
@@ -1,4 +1,4 @@
-// RUN: json_to_flatbuffer %p/../../../../../lite/schema/schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --dump-input-on-failure %s
+// RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --dump-input-on-failure %s

 // CHECK: %cst = constant unit
 // CHECK: %[[RES0:.*]] = "tfl.conv_2d"(%arg0, %arg1, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 0 : i32, stride_w = 0 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, none) -> tensor<256x32x32x16xf32>
@@ -1,4 +1,4 @@
-// RUN: json_to_flatbuffer %p/../../../../../lite/schema/schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --dump-input-on-failure %s
+// RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --dump-input-on-failure %s

 // This test is to test that if the flatbuffer omits the last optional input `bias` of tfl.conv_2d op, the flatbuffer_importer will automatically adds `none` value to tfl.conv_2d.

tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/test_schema.fbs (new file, 1092 lines)
File diff suppressed because it is too large
@@ -52,14 +52,14 @@ func @squeezeAndReshape(%arg0: tensor<1x1x10xf32>, %arg1: tensor<?x10xf32>) -> i
   %1 = "tf.Squeeze"(%arg1) : (tensor<?x10xf32>) -> tensor<*xf32>
   %2 = "tf.Const"() { value = dense<[2, 5]> : tensor<2xi32> } : () -> tensor<2xi32>
   %3 = "tf.Reshape" (%0, %2) : (tensor<1x10xf32>, tensor<2xi32>) -> tensor<2x5xf32>
-  %4 = "some_op"(%1, %3) : (tensor<*xf32>, tensor<2x5xf32>) -> i32
+  %4 = "tf.some_op"(%1, %3) : (tensor<*xf32>, tensor<2x5xf32>) -> i32
   return %4 : i32
   // CHECK-LABEL: squeezeAndReshape
   // CHECK: "tfl.squeeze"(%arg0) {squeeze_dims = [0]} : (tensor<1x1x10xf32>) -> tensor<1x10xf32>
   // CHECK: %1 = "tfl.squeeze"(%arg1) {squeeze_dims = []} : (tensor<?x10xf32>) -> tensor<*xf32>
   // CHECK: %cst = constant dense<[2, 5]> : tensor<2xi32>
   // CHECK: %2 = "tfl.reshape"(%0, %cst) : (tensor<1x10xf32>, tensor<2xi32>) -> tensor<2x5xf32>
-  // CHECK: %3 = "some_op"(%1, %2) : (tensor<*xf32>, tensor<2x5xf32>) -> i32
+  // CHECK: %3 = "tf.some_op"(%1, %2) : (tensor<*xf32>, tensor<2x5xf32>) -> i32
   // CHECK: return
 }

@@ -1,4 +1,4 @@
-// RUN: tf-opt -tfl-load-recipe %s | FileCheck %s --dump-input-on-failure
+// RUN: tf-opt -allow-unregistered-dialect -tfl-load-recipe %s | FileCheck %s --dump-input-on-failure

 // CHECK-LABEL: testLstm
 func @testLstm(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>, %arg3: tensor<*xf32>, %arg4: tensor<*xf32>, %arg5: tensor<*xf32>, %arg6: tensor<*xf32>, %arg7: tensor<*xf32>, %arg8: tensor<*xf32>, %arg9: tensor<*xf32>, %arg10: tensor<*xf32>, %arg11: tensor<*xf32>, %arg12: tensor<*xf32>, %arg13: tensor<*xf32>, %arg14: tensor<*xf32>, %arg15: tensor<*xf32>, %arg16: tensor<*xf32>, %arg17: tensor<*xf32>, %arg18: tensor<*xf32>, %arg19: tensor<*xf32>, %arg20: tensor<*xf32>, %arg21: tensor<*xf32>, %arg22: tensor<*xf32>, %arg23: tensor<*xf32>) -> tensor<*xf32> {
@@ -439,6 +439,31 @@ func @NotReorderReshapeAddIfNotTailingDim(%arg0: tensor<40x40x1xf32>) -> tensor<
   // CHECK: return %[[rs2]]
 }

+// CHECK-LABEL: @ReorderElementwiseValueOpAndMoveOp
+func @ReorderElementwiseValueOpAndMoveOp(%arg0: tensor<40x40x1xf32>) -> tensor<40x40xf32> {
+  %shape = constant dense<[40, 40]> : tensor<2xi32>
+  %1 = "tfl.reshape"(%arg0, %shape) : (tensor<40x40x1xf32>, tensor<2xi32>) -> tensor<40x40xf32>
+  %2 = "tfl.relu"(%1) : (tensor<40x40xf32>) -> tensor<40x40xf32>
+  return %2 : tensor<40x40xf32>
+
+  // CHECK: %[[rs1:.*]] = "tfl.relu"(%arg0
+  // CHECK: %[[rs2:.*]] = "tfl.reshape"(%[[rs1]]
+  // CHECK: return %[[rs2]]
+}
+
+// CHECK-LABEL: @NotReorderElementwiseValueOpAndMoveOp
+func @NotReorderElementwiseValueOpAndMoveOp(%arg0: tensor<40x40x1xf32>) -> (tensor<40x40xf32>, tensor<40x40xf32>) {
+  %shape = constant dense<[40, 40]> : tensor<2xi32>
+  %1 = "tfl.reshape"(%arg0, %shape) : (tensor<40x40x1xf32>, tensor<2xi32>) -> tensor<40x40xf32>
+  %2 = "tfl.relu"(%1) : (tensor<40x40xf32>) -> tensor<40x40xf32>
+  return %1, %2 : tensor<40x40xf32>, tensor<40x40xf32>
+
+  // CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0
+  // CHECK: %[[rs2:.*]] = "tfl.relu"(%[[rs1]]
+  // CHECK: return %[[rs1]], %[[rs2]]
+}
+
 // CHECK-LABEL: @FuseFullyConnectedRelu
 func @FuseFullyConnectedRelu(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> {
   %0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32>
@@ -450,6 +475,28 @@ func @FuseFullyConnectedRelu(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32
   // CHECK: return %[[RES]]
 }

+// CHECK-LABEL: @FuseFullyConnectedRelu6
+func @FuseFullyConnectedRelu6(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> {
+  %0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32>
+  %1 = "tfl.relu6"(%0) : (tensor<1x128xf32>) -> tensor<1x128xf32>
+  return %1 : tensor<1x128xf32>
+
+  // CHECK: %[[RES:[0-9].*]] = "tfl.fully_connected"
+  // CHECK-SAME: fused_activation_function = "RELU6"
+  // CHECK: return %[[RES]]
+}
+
+// CHECK-LABEL: @FuseFullyConnectedRelu1
+func @FuseFullyConnectedRelu1(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> {
+  %0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32>
+  %1 = "tfl.relu_n1_to_1"(%0) : (tensor<1x128xf32>) -> tensor<1x128xf32>
+  return %1 : tensor<1x128xf32>
+
+  // CHECK: %[[RES:[0-9].*]] = "tfl.fully_connected"
+  // CHECK-SAME: fused_activation_function = "RELU_N1_TO_1"
+  // CHECK: return %[[RES]]
+}
+
 // CHECK-LABEL: @HardSwishPattern
 func @HardSwishPattern(%arg0: tensor<1xf32>) -> tensor<1xf32> {
   %three = constant dense<3.> : tensor<f32>
@@ -161,7 +161,7 @@ func @_functionalize_if_else_branch_01(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>
 }

 func @_functionalize_if_then_branch_01(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
-  %0 = "my_unknown_op.blah"() : () -> tensor<i1>
+  %0 = "tf.blah"() : () -> tensor<i1>
   return %0 : tensor<i1>
 }

@@ -199,7 +199,7 @@ func @_functionalize_if_else_branch_02(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>
 }

 func @_functionalize_if_then_branch_02(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
-  %0 = "my_unknown_op.blah"() : () -> tensor<i1>
+  %0 = "tf.blah"() : () -> tensor<i1>
   return %0 : tensor<i1>
 }

@@ -22,6 +22,7 @@ limitations under the License.
 #include "llvm/Support/InitLLVM.h"
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/ToolOutputFile.h"
+#include "mlir/IR/AsmState.h"  // from @llvm-project
 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
 #include "mlir/IR/Function.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
@@ -128,6 +129,7 @@ int main(int argc, char **argv) {
   // We need to disable duplicated ones to provide a cleaner command-line option
   // interface. That also means we need to relay the value set in one option to
   // all its aliases.
+  mlir::registerAsmPrinterCLOptions();
   llvm::cl::ParseCommandLineOptions(
       argc, argv, "TF GraphDef to TFLite FlatBuffer converter\n");

@@ -20,7 +20,6 @@ limitations under the License.
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Support/Functional.h"
 #include "mlir/Support/LLVM.h"
 #include "absl/memory/memory.h"
 #include "llvm/ADT/STLExtras.h"
@@ -30,7 +30,7 @@ void IdentifyDilatedConvPass::runOnFunction() {
   patterns.insert<ConvertTFDilatedConvOp<TF::Conv2DOp>,
                   ConvertTFDilatedConvOp<TF::DepthwiseConv2dNativeOp>>(
       &getContext());
-  applyPatternsGreedily(func, patterns);
+  applyPatternsAndFoldGreedily(func, patterns);
 }
 }  // namespace

@@ -38,7 +38,6 @@ limitations under the License.
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
-#include "mlir/Support/Functional.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
@@ -36,7 +36,6 @@ limitations under the License.
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
-#include "mlir/Support/Functional.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
@@ -288,12 +287,10 @@ LogicalResult ConvertTFSplitOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tf_split_op = cast<TF::SplitOp>(op);

-  auto output_types = functional::map([](Value v) { return v.getType(); },
-                                      tf_split_op.output());
   // Number of splits cannot be negative.
   auto num_split = rewriter.getI32IntegerAttr(tf_split_op.num_split());

-  rewriter.replaceOpWithNewOp<TFL::SplitOp>(op, output_types,
+  rewriter.replaceOpWithNewOp<TFL::SplitOp>(op, tf_split_op.output().getTypes(),
                                             tf_split_op.split_dim(),
                                             tf_split_op.value(), num_split);
   return success();
@@ -303,14 +300,12 @@ LogicalResult ConvertTFSplitVOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tf_splitv_op = cast<TF::SplitVOp>(op);

-  auto output_types = functional::map([](Value v) { return v.getType(); },
-                                      tf_splitv_op.output());
   // Number of splits cannot be negative.
   auto num_split = rewriter.getI32IntegerAttr(tf_splitv_op.num_split());

   rewriter.replaceOpWithNewOp<TFL::SplitVOp>(
-      op, output_types, tf_splitv_op.value(), tf_splitv_op.size_splits(),
-      tf_splitv_op.split_dim(), num_split);
+      op, tf_splitv_op.output().getTypes(), tf_splitv_op.value(),
+      tf_splitv_op.size_splits(), tf_splitv_op.split_dim(), num_split);
   return success();
 }

@@ -402,13 +397,12 @@ LogicalResult ConvertTFUnpackOp::matchAndRewrite(
   auto tf_unpack_op = cast<TF::UnpackOp>(op);

   auto input = tf_unpack_op.value();
-  auto output_types = functional::map([](Value v) { return v.getType(); },
-                                      tf_unpack_op.output());
   auto num = rewriter.getI32IntegerAttr(tf_unpack_op.num());
   // Axis can be negative.
   auto axis = rewriter.getI32IntegerAttr(tf_unpack_op.axis().getSExtValue());

-  rewriter.replaceOpWithNewOp<UnpackOp>(op, output_types, input, num, axis);
+  rewriter.replaceOpWithNewOp<UnpackOp>(op, tf_unpack_op.output().getTypes(),
+                                        input, num, axis);
   return success();
 }

@@ -49,7 +49,6 @@ limitations under the License.
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
-#include "mlir/Support/Functional.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
@@ -37,7 +37,6 @@ limitations under the License.
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
-#include "mlir/Support/Functional.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
@@ -52,6 +51,9 @@ namespace TFL {
 //===----------------------------------------------------------------------===//
 // The actual Optimize Pass.
 namespace {
+constexpr char kRelu[] = "RELU";
+constexpr char kRelu6[] = "RELU6";
+constexpr char kRelu1[] = "RELU_N1_TO_1";
+
 bool L2NormalizeReduceAxis(Value sq_op, DenseElementsAttr axis) {
   if (sq_op.getType().cast<ShapedType>().getRank() - 1 ==
@@ -301,10 +303,11 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern<TFL::AddOp> {
 };

 // TODO(b/136285429): Move to tablegen when variadic is supported.
-struct FuseFullyConnectedAndRelu : public OpRewritePattern<TFL::ReluOp> {
-  using OpRewritePattern<TFL::ReluOp>::OpRewritePattern;
+template <typename ReluXOp, char const *Act>
+struct FuseFullyConnectedAndReluX : public OpRewritePattern<ReluXOp> {
+  using OpRewritePattern<ReluXOp>::OpRewritePattern;

-  LogicalResult matchAndRewrite(TFL::ReluOp relu_op,
+  LogicalResult matchAndRewrite(ReluXOp relu_op,
                                 PatternRewriter &rewriter) const override {
     Operation *input = relu_op.getOperand().getDefiningOp();
     if (!isa_and_nonnull<FullyConnectedOp>(input)) return failure();
@@ -312,7 +315,7 @@ struct FuseFullyConnectedAndRelu : public OpRewritePattern<TFL::ReluOp> {
     if (fully_connected_op.fused_activation_function() != "NONE")
       return failure();

-    auto new_activation_func = rewriter.getStringAttr("RELU");
+    auto new_activation_func = rewriter.getStringAttr(Act);
     auto new_weights_format =
         rewriter.getStringAttr(fully_connected_op.weights_format());
     auto new_keep_num_dims =
@@ -709,9 +712,12 @@ void Optimize::runOnFunction() {
   // we explore these potentially first and then fuse the binary ops with the
   // following ops in a second pattern match.
   TFL::populateWithGenerated(ctx, &patterns);
-  patterns.insert<FuseFullyConnectedAndAdd, FuseFullyConnectedAndRelu,
+  patterns.insert<FuseFullyConnectedAndAdd,
+                  FuseFullyConnectedAndReluX<TFL::ReluOp, kRelu>,
+                  FuseFullyConnectedAndReluX<TFL::Relu6Op, kRelu6>,
+                  FuseFullyConnectedAndReluX<TFL::Relu1Op, kRelu1>,
                   FuseFullyConnectedAndMul>(ctx);
-  applyPatternsGreedily(func, patterns);
+  applyPatternsAndFoldGreedily(func, patterns);

   // Fuse the binary ops with the following ops.
   patterns.insert<
@@ -719,7 +725,7 @@ void Optimize::runOnFunction() {
       FuseBinaryOpToFollowingFullyConnected, FuseConv2DAndMulWithQDQs,
       FuseDepthwiseConv2DAndMulWithQDQs, ConvertTrivialTransposeOpToReshapeOp>(
       ctx);
-  applyPatternsGreedily(func, patterns);
+  applyPatternsAndFoldGreedily(func, patterns);
 }

 }  // namespace
@@ -187,7 +187,7 @@ void OptimizeFunctionalOpsPass::runOnOperation() {
   patterns.insert<FoldIfOp>(&getContext(), &inlined_funcs);

   ModuleOp module = getOperation();
-  applyPatternsGreedily(module, patterns);
+  applyPatternsAndFoldGreedily(module, patterns);

   // Erase inlined functions that don't have any references.
   //
@@ -378,6 +378,19 @@ foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
                      (IsTailOfShape $rhs, $input)]>;
 }

+// Reorder the element-wise value operations and the element move operations,
+// such that the value operation happens before move operation.
+foreach ValueOp = [TFL_CeilOp, TFL_ExpOp, TFL_FloorOp, TFL_NegOp,
+                   TFL_ReluOp, TFL_Relu1Op, TFL_Relu6Op, TFL_RoundOp,
+                   TFL_TanhOp, TFL_SqrtOp, TFL_SquareOp] in {
+  foreach MoveOp = [TFL_DepthToSpaceOp, TFL_ExpandDimsOp, TFL_SqueezeOp,
+                    TFL_ReshapeOp, TFL_TransposeOp] in {
+    def : Pat<(ValueOp:$value (MoveOp:$move $input, $move_def)),
+              (MoveOp (ValueOp $input), $move_def),
+              [(HasOneUse $move)]>;
+  }
+}
+
 // Returns shape of a ranked tensor.
 // if called without a ranked tensor it will fail.
 def GetShape: NativeCodeCall<"GetShape($0)">;
@@ -394,8 +407,9 @@ def : Pat<(TFL_ExpandDimsOp:$expand_dims_op $input, $dim),
                      (ConstantOp (GetShape $expand_dims_op))),
           [(AnyStaticShapeTensor $expand_dims_op)]>;

-class ValueEquals<string val> : Constraint<CPred<
+class FloatValueEquals<string val> : Constraint<CPred<
   "$0.cast<DenseElementsAttr>().getNumElements() == 1 &&"
+  "$0.isa<DenseFPElementsAttr>() &&"
   "*$0.cast<DenseElementsAttr>().getValues<float>().begin() == " # val>>;

 // ReLU patterns
@@ -403,13 +417,13 @@ def : Pat<(TFL_MinimumOp (TFL_MaximumOp $input,
                                          (ConstantOp $NegOne)),
                          (ConstantOp $One)),
           (TFL_Relu1Op $input),
-          [(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>;
+          [(FloatValueEquals<"-1"> $NegOne), (FloatValueEquals<"1"> $One)]>;

 def : Pat<(TFL_MaximumOp (TFL_MinimumOp $input,
                                          (ConstantOp $One)),
                          (ConstantOp $NegOne)),
           (TFL_Relu1Op $input),
-          [(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>;
+          [(FloatValueEquals<"-1"> $NegOne), (FloatValueEquals<"1"> $One)]>;

 def : Pat<(TFL_MaximumOp (TFL_MulOp:$mul_out $input1,
                           (ConstantOp F32ElementsAttr:$alpha), TFL_AF_None),
@@ -125,7 +125,7 @@ void PostQuantizePass::runOnFunction() {
   auto func = getFunction();
   auto* ctx = func.getContext();
   TFL::populateWithGenerated(ctx, &patterns);
-  applyPatternsGreedily(func, patterns);
+  applyPatternsAndFoldGreedily(func, patterns);

   if (!emit_quant_adaptor_ops_) {
     RemoveQuantizationAdaptorOps(getFunction());
@@ -267,7 +267,7 @@ void PrepareQuantizePass::runOnFunction() {
     // Currently, only activation stats are imported, so narrow_range = false.
     patterns.insert<PrepareQuantStats>(8, false, false, ctx);
   }
-  applyPatternsGreedily(func, patterns);
+  applyPatternsAndFoldGreedily(func, patterns);

   SanityCheckAndAdjustment(func);

@@ -46,7 +46,6 @@ limitations under the License.
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
-#include "mlir/Support/Functional.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
@@ -322,9 +321,10 @@ class ConvertTFConv2D : public ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp> {

     // Create tensor type for the transpose result.
     auto filter_type = filter.getType().cast<RankedTensorType>();
-    auto result_shape = functional::map(
-        [filter_type](int64_t dim) { return filter_type.getDimSize(dim); },
-        perm);
+    auto result_shape =
+        llvm::to_vector<4>(llvm::map_range(perm, [filter_type](int64_t dim) {
+          return filter_type.getDimSize(dim);
+        }));
     auto elem_type = filter_type.getElementType();
     auto result_type = RankedTensorType::get(result_shape, elem_type);

@@ -619,8 +619,8 @@ void PrepareTFPass::runOnFunction() {

   // This pattern was intented to uses TFL QDQs to preserve the quantization
   // parameters from the TF Quant ops, thus this pattern should run with the
-  // first `applyPatternsGreedily` method, which would otherwise removes the
-  // TF FakeQuant ops by the constant folding.
+  // first `applyPatternsAndFoldGreedily` method, which would otherwise removes
+  // the TF FakeQuant ops by the constant folding.
   patterns.insert<PreparePerTensorFakeQuant, PreparePerChannelFakeQuant>(ctx);

   // This pattern will try to identify and optimize for dilated convolution.
@@ -634,7 +634,7 @@ void PrepareTFPass::runOnFunction() {
   // This will allow optimizing any TF_Mul->TF_Conv in the graph
   // and any expanded from FusedBatchNorm. We need to do this
   // before converting TF_Conv to TFL_Conv
-  applyPatternsGreedily(func, patterns);
+  applyPatternsAndFoldGreedily(func, patterns);

   // Load the generated pattern again, so new quantization pass-through
   // will be applied.
@@ -646,7 +646,7 @@ void PrepareTFPass::runOnFunction() {
   }
   patterns.insert<TF::ConvertTFEinsumOp, ConvertTFConv2D,
                   ConvertTFDepthwiseConv2dNative, ConvertTFStridedSlice>(ctx);
-  applyPatternsGreedily(func, patterns);
+  applyPatternsAndFoldGreedily(func, patterns);
 }

 }  // namespace
@@ -29,7 +29,6 @@ limitations under the License.
 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
-#include "mlir/Support/Functional.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
@@ -88,7 +87,7 @@ void QuantizePass::runOnFunction() {
   TFL::populateWithGenerated(ctx, &patterns);
   patterns.insert<TFLFullQuantization>(
       ctx, enable_numeric_verify, error_tolerance, enable_single_layer_verify);
-  applyPatternsGreedily(func, patterns);
+  applyPatternsAndFoldGreedily(func, patterns);
 }
 }  // namespace

@@ -94,9 +94,10 @@ Value Transpose(OpBuilder* builder, Value value_to_transpose,

   // Create tensor type for the transpose result.
   auto transpose_type = original_type;
-  auto transpose_shape = functional::map(
-      [transpose_type](int32_t dim) { return transpose_type.getDimSize(dim); },
-      perm);
+  auto transpose_shape =
+      llvm::to_vector<8>(llvm::map_range(perm, [transpose_type](int32_t dim) {
+        return transpose_type.getDimSize(dim);
+      }));
   auto elem_type = transpose_type.getElementType();
   auto result_type = RankedTensorType::get(transpose_shape, elem_type);

@@ -127,6 +127,7 @@ Status MlirFunctionOptimizationPass::Run(
   GraphImportConfig import_config;
   import_config.graph_as_function = true;
   import_config.control_outputs = *control_ret_node_names;
+  import_config.upgrade_legacy = true;
   TF_ASSIGN_OR_RETURN(auto module_ref,
                       ConvertGraphToMlir(**graph, debug_info, *flib_def,
                                          import_config, &context));
@@ -149,7 +150,6 @@ Status MlirFunctionOptimizationPass::Run(
   }

   GraphExportConfig export_config;
-  export_config.graph_as_function = true;
   absl::flat_hash_set<Node*> control_ret_nodes;
   TF_RETURN_WITH_CONTEXT_IF_ERROR(
       ConvertMlirToGraph(*module_ref, export_config, graph, flib_def,
@@ -71,7 +71,8 @@ tool_dirs = config.mlir_tf_tools_dirs + [
 tool_names = [
     'mlir-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate',
     'flatbuffer_to_string', 'flatbuffer_translate', 'tf-mlir-translate',
-    'mlir-tflite-runner', 'tfcompile', 'json_to_flatbuffer'
+    'mlir-tflite-runner', 'tfcompile', 'json_to_flatbuffer', 'xla-gpu-opt',
+    'xla-opt'
 ]
 tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
 llvm_config.add_tool_substitutions(tools, tool_dirs)
@@ -45,7 +45,8 @@ mlir_tf_tools_dirs = [
     'tensorflow/compiler/mlir/lite',
     'tensorflow/compiler/mlir/tensorflow',
     'tensorflow/compiler/mlir/xla',
-    'tensorflow/compiler/aot'
+    'tensorflow/compiler/aot',
+    'tensorflow/compiler/xla/service/mlir_gpu',
 ]
 config.mlir_tf_tools_dirs = [
     os.path.join(real_test_srcdir, os.environ['TEST_WORKSPACE'], s)
@@ -1292,6 +1292,45 @@ tf_cc_test(
     ],
 )

+cc_library(
+    name = "dump_graph",
+    srcs = ["utils/dump_graph.cc"],
+    hdrs = ["utils/dump_graph.h"],
+    deps = [
+        ":convert_graphdef",
+        ":error_util",
+        ":tensorflow",
+        ":tensorflow_dialect_registration",
+        ":tensorflow_passes",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:graph",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/platform:logging",
+        "@llvm-project//llvm:support",
+        "@llvm-project//mlir:AllPassesAndDialects",
+        "@llvm-project//mlir:Analysis",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+    ],
+)
+
+tf_cc_test(
+    name = "dump_graph_test",
+    size = "small",
+    srcs = ["utils/dump_graph_test.cc"],
+    deps = [
+        ":dump_graph",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:graph",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core/platform:test",
+        "@llvm-project//llvm:support",
+        "@llvm-project//mlir:IR",
+    ],
+)
+
 cc_library(
     name = "bridge_logger",
     srcs = ["utils/bridge_logger.cc"],
@@ -40,7 +40,6 @@ limitations under the License.
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
-#include "mlir/Support/STLExtras.h"  // from @llvm-project
 #include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/core/platform/logging.h"
@@ -90,7 +89,7 @@ struct TFInlinerInterface : public DialectInlinerInterface {
   // are perfectly forwarded to the block's terminator.
   bool BlockWrapsSingleOp(Block* block) {
     auto body = block->without_terminator();
-    if (!has_single_element(body)) return false;
+    if (!hasSingleElement(body)) return false;

     Operation& wrapped_op = *body.begin();
     Operation* terminator = block->getTerminator();
Some files were not shown because too many files have changed in this diff.