Merge branch 'master' into master

Commit 7ba6bca1d9 by Daniel Situnayake, 2020-04-20 10:33:20 -07:00 (committed by Daniel Situnayake)
GPG Key ID: C26BBAA056CC4A6C (no known key found for this signature in database)
1083 changed files with 38870 additions and 11309 deletions

@@ -73,6 +73,10 @@
# tensorflow_testing_rbe_linux: RBE options to use RBE with tensorflow-testing project on linux
# tensorflow_testing_rbe_win: RBE options to use RBE with tensorflow-testing project on windows
#
# Embedded Linux options (experimental and only tested with TFLite build yet)
# elinux: General Embedded Linux options shared by all flavors.
# elinux_aarch64: Embedded Linux options for aarch64 (ARM64) CPU support.
# elinux_armhf: Embedded Linux options for armhf (ARMv7) CPU support.
@@ -432,6 +436,14 @@ build:tensorflow_testing_rbe_linux --config=rbe_linux
common:tensorflow_testing_rbe_win --remote_instance_name=projects/tensorflow-testing/instances/windows
build:tensorflow_testing_rbe_win --config=tensorflow_testing_rbe
# TFLite build configs for generic embedded Linux
build:elinux --crosstool_top=@local_config_embedded_arm//:toolchain
build:elinux --host_crosstool_top=@bazel_tools//tools/cpp:toolchain
build:elinux_aarch64 --config=elinux
build:elinux_aarch64 --cpu=aarch64
build:elinux_armhf --config=elinux
build:elinux_armhf --cpu=armhf
# END TF REMOTE BUILD EXECUTION OPTIONS
# Default options should come above this line
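
With the configs above in place, a cross-compiled TFLite build can be selected entirely from the command line. A minimal sketch, assuming a typical TFLite shared-library target (the exact label and flags are illustrative, not part of this change):

# Build TFLite for a 64-bit ARM embedded Linux target.
bazel build --config=elinux_aarch64 -c opt //tensorflow/lite:libtensorflowlite.so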

@@ -1 +1 @@
2.0.0
3.0.0

@@ -11,25 +11,23 @@ we only address code/doc bugs, performance issues, feature requests and
build/installation issues on GitHub. tag:bug_template</em>
**System information**
- Have I written custom code (as opposed to using a stock
example script provided in TensorFlow):
- OS Platform and Distribution (e.g.,
Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
the issue happens on mobile device:
- TensorFlow installed from (source or
binary): - TensorFlow version (use command below):
- Python version: - Bazel
version (if compiling from source):
- GCC/Compiler version (if compiling from
source):
- CUDA/cuDNN version: - GPU model and memory:
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
- TensorFlow installed from (source or binary):
- TensorFlow version (use command below):
- Python version:
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version:
- GPU model and memory:
You can collect some of this information using our environment capture
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
You can also obtain the TensorFlow version with:
1. TF 1.0: `python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"`
2. TF 2.0: `python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
**Describe the current behavior**

@@ -12,25 +12,22 @@ we only address code/doc bugs, performance issues, feature requests and
build/installation issues on GitHub. tag:performance_template</em>
**System information**
- Have I written custom code (as opposed to using a stock
example script provided in TensorFlow):
- OS Platform and Distribution (e.g.,
Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
the issue happens on mobile device:
- TensorFlow installed from (source or
binary): - TensorFlow version (use command below):
- Python version: - Bazel
version (if compiling from source):
- GCC/Compiler version (if compiling from
source):
- CUDA/cuDNN version: - GPU model and memory:
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
- TensorFlow installed from (source or binary):
- TensorFlow version (use command below):
- Python version:
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version:
- GPU model and memory:
You can collect some of this information using our environment capture
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
You can also obtain the TensorFlow version with:
1. TF 1.0: `python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"`
2. TF 2.0: `python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
**Describe the current behavior**

.github/stale.yml (new file)
@@ -0,0 +1,39 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
#
# THIS IS A GENERATED DOCKERFILE.
#
# This file was assembled from multiple pieces, whose use is documented
# throughout. Please refer to the TensorFlow dockerfiles documentation
# for more information.
# Number of days of inactivity before an Issue or Pull Request becomes stale
daysUntilStale: 7
# Number of days of inactivity before a stale Issue or Pull Request is closed
daysUntilClose: 7
# Only Issues or Pull Requests with all of these labels are checked for staleness. Set to `[]` to disable
onlyLabels:
- awaitingResponse
# Comment to post when marking as stale. Set to `false` to disable
markComment: >
This issue has been automatically marked as stale because it has not had
recent activity. It will be closed if no further activity occurs. Thank you.
# Comment to post when removing the stale label. Set to `false` to disable
unmarkComment: false
closeComment: >
Closing as stale. Please reopen if you'd like to work on this further.
limitPerRun: 30
# Limit to only `issues` or `pulls`
only: issues

configure
@@ -4,7 +4,7 @@ set -e
set -o pipefail
if [ -z "$PYTHON_BIN_PATH" ]; then
PYTHON_BIN_PATH=$(which python || which python3 || true)
PYTHON_BIN_PATH=$(which python3 || which python || true)
fi
# Set all env variables

@@ -50,7 +50,7 @@ _TF_WORKSPACE_ROOT = ''
_TF_BAZELRC = ''
_TF_CURRENT_BAZEL_VERSION = None
_TF_MIN_BAZEL_VERSION = '2.0.0'
_TF_MAX_BAZEL_VERSION = '2.0.0'
_TF_MAX_BAZEL_VERSION = '3.99.0'
NCCL_LIB_PATHS = [
'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''
@@ -58,8 +58,6 @@ NCCL_LIB_PATHS = [
# List of files to configure when building Bazel on Apple platforms.
APPLE_BAZEL_FILES = [
'tensorflow/lite/experimental/delegates/coreml/BUILD',
'tensorflow/lite/experimental/delegates/coreml/builders/BUILD',
'tensorflow/lite/experimental/ios/BUILD',
'tensorflow/lite/experimental/objc/BUILD',
'tensorflow/lite/experimental/swift/BUILD',

@@ -214,6 +214,12 @@ config_setting(
visibility = ["//visibility:public"],
)
config_setting(
name = "linux_armhf",
values = {"cpu": "armhf"},
visibility = ["//visibility:public"],
)
config_setting(
name = "linux_x86_64",
values = {"cpu": "k8"},
@@ -703,8 +709,8 @@ tf_cc_shared_object(
"//tensorflow/c:version_script.lds",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:distributed_tensorflow_dependencies",
"//tensorflow/core:tensorflow",
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
],
)

@@ -118,6 +118,12 @@ cc_library(
visibility = ["//visibility:public"],
)
cc_library(
name = "c_api_macros",
hdrs = ["c_api_macros.h"],
visibility = ["//visibility:public"],
)
tf_cuda_library(
name = "c_api",
hdrs = [

@@ -186,10 +186,6 @@ struct TF_Server {
namespace tensorflow {
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
TF_Buffer* out);

@@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_C_API_MACROS_H_
#define TENSORFLOW_C_C_API_MACROS_H_
#ifdef SWIG
#define TF_CAPI_EXPORT
#else
#if defined(_WIN32)
#ifdef TF_COMPILE_LIBRARY
#define TF_CAPI_EXPORT __declspec(dllexport)
#else
#define TF_CAPI_EXPORT __declspec(dllimport)
#endif // TF_COMPILE_LIBRARY
#else
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
#endif // _WIN32
#endif // SWIG
#endif // TENSORFLOW_C_C_API_MACROS_H_
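
As a sketch of how this new header is typically consumed by C API declarations (the function below is hypothetical, not part of this diff):

#include "tensorflow/c/c_api_macros.h"

/* TF_CAPI_EXPORT expands to __declspec(dllexport) or __declspec(dllimport)
   on Windows (depending on TF_COMPILE_LIBRARY) and to default symbol
   visibility elsewhere, so one declaration serves both building and
   consuming the library. */
TF_CAPI_EXPORT extern void TF_HypotheticalFeature(void);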

@@ -240,11 +240,6 @@ tf_cuda_cc_test(
"c_api_remote_test.cc",
],
extra_copts = tfe_xla_copts(),
tags = [
"guitar",
"multi_gpu",
"no_oss",
],
deps = [
":c_api",
":c_api_experimental",

@@ -1587,6 +1587,7 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
// require TFE_Op* and just convert it internally to a NameAttrValue, so
// consider adding an overload to the C API to make this case easier.
TFE_OpSetAttrFunction(op, attr_name, func_op);
TFE_DeleteOp(func_op);
} break;
case tensorflow::AttrValue::kList:
TF_FALLTHROUGH_INTENDED;

@@ -129,7 +129,45 @@ void TestRemoteExecute(bool async) {
TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
void TestRemoteExecuteSilentCopies(bool async, bool remote) {
string MatMulFunction() {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
" signature {"
" name: 'MatMulFunction'"
" input_arg {"
" name: 'a'"
" type: DT_FLOAT"
" }"
" input_arg {"
" name: 'b'"
" type: DT_FLOAT"
" }"
" output_arg {"
" name: 'm'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'matmul'"
" op: 'MatMul'"
" input: 'a'"
" input: 'b'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'm'"
" value: 'matmul:product'"
" }",
&def));
return def.SerializeAsString();
}
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
@@ -169,12 +207,36 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Op* matmul = nullptr;
if (func) {
string function_def = MatMulFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
matmul = TFE_NewOp(ctx, "MatMulFunction", status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(matmul, h0_task0, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(matmul, h1_task2, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
} else {
// Handles are on task0 (local), and task2, but op is on task1.
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task2);
matmul = MatMulOp(ctx, h0_task0, h1_task2);
}
if (remote) {
TFE_OpSetDevice(matmul, task1_name, status);
}
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
} else if (!async) {
// Set the local device to CPU to easily validate mirroring
string cpu_device_name;
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle);
// The input handles should never change since they have been mirrored.
ASSERT_FALSE(remote_arg->HasLocalMirror(nullptr));
}
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
@@ -182,12 +244,10 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TODO(gjn): Add support for waiting on async local mirrors
if (!async) {
if (!remote && !async) {
auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle);
tensorflow::EagerOperation* op =
tensorflow::OperationFromInterface(matmul->operation);
// The input handles should never change since they have been mirrored.
ASSERT_EQ(op->Inputs()[1], remote_arg);
ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr));
}
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
@@ -217,6 +277,9 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteExecutor(executor);
if (func) {
TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
}
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
@@ -227,16 +290,22 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
}
TEST(CAPI, RemoteExecuteSilentCopies) {
TestRemoteExecuteSilentCopies(false, true);
TestRemoteExecuteSilentCopies(false, true, false);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
TestRemoteExecuteSilentCopies(true, true);
TestRemoteExecuteSilentCopies(true, true, false);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) {
TestRemoteExecuteSilentCopies(true, true, true);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocal) {
TestRemoteExecuteSilentCopies(false, false);
TestRemoteExecuteSilentCopies(false, false, false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) {
TestRemoteExecuteSilentCopies(true, false);
TestRemoteExecuteSilentCopies(true, false, false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) {
TestRemoteExecuteSilentCopies(true, false, true);
}
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {

@@ -78,11 +78,18 @@ void BM_Execute(int iters, int async) {
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
TFE_Op* matmul = MatMulOp(ctx, m, m);
TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
TFE_OpReset(matmul, "MatMul", nullptr, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(matmul, m, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(matmul, m, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
}
@@ -113,11 +120,15 @@ void BM_Execute_Identity(int iters, int async) {
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
TFE_Op* identity = IdentityOp(ctx, m);
TFE_Op* identity = TFE_NewOp(ctx, "Identity", status);
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
TFE_OpReset(identity, "Identity", nullptr, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(identity, m, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Execute(identity, &retvals[0], &num_retvals, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
}
@@ -405,6 +416,11 @@ void TensorHandleSilentCopy(bool async,
hcpu, ctx, gpu_device_name.c_str(), status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
auto cpu_arg = tensorflow::TensorHandleFromInterface(hcpu->handle);
auto gpu_arg = tensorflow::TensorHandleFromInterface(hgpu->handle);
auto gpu_device = absl::get<tensorflow::Device*>(gpu_arg->device());
ASSERT_FALSE(cpu_arg->HasLocalMirror(gpu_device));
TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
if (cpu_op) {
string cpu_device_name;
@@ -420,15 +436,8 @@ void TensorHandleSilentCopy(bool async,
TFE_Execute(matmul, &retvals[0], &num_retvals, status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
// Validate if the input was replaced with a different TensorHandle
auto arg0 = tensorflow::TensorHandleFromInterface(hcpu->handle);
auto arg1 = tensorflow::TensorHandleFromInterface(hgpu->handle);
tensorflow::EagerOperation* op =
tensorflow::OperationFromInterface(matmul->operation);
// The input handles should never change since they have been mirrored.
EXPECT_EQ(op->Inputs()[0], arg0);
EXPECT_EQ(op->Inputs()[1], arg1);
// The CPU handle should have been copied and have a mirror on the GPU
ASSERT_TRUE(cpu_arg->HasLocalMirror(gpu_device));
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(retvals[0]);
@@ -626,17 +635,6 @@ void ExecuteAdd(bool async, bool forward_input, bool tfrt) {
}
int num_retvals = 1;
if (async) {
// Enqueue dummy ops so we backlog async execution & actually test async.
for (int i = 0; i < 10000; ++i) {
TFE_TensorHandle* dummy = nullptr;
TFE_Execute(add_op, &dummy, &num_retvals, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(dummy);
}
}
TFE_TensorHandle* retval = nullptr;
TFE_Execute(add_op, &retval, &num_retvals, status);
EXPECT_EQ(1, num_retvals);

@@ -38,97 +38,160 @@ typedef void (*ExecuteOperation)(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs,
TF_OutputList* o, TF_ExecutionContext* ctx,
TF_Status* s);
struct TF_ExecutionContext {
explicit TF_ExecutionContext() {}
absl::variant<TFE_Context*, TF_GraphContext*> ctx;
ExecuteOperation execution_callback;
};
// Needed to implement our own version of RTTI since dynamic_cast is not
// supported in mobile builds.
enum ExecutionContextKind { GraphContext, EagerContext };
explicit TF_ExecutionContext(ExecutionContextKind kind) : k(kind) {}
ExecutionContextKind getKind() const { return k; }
struct TF_AbstractTensor {
absl::variant<TFE_TensorHandle*, TF_GraphTensor*> t;
};
virtual void ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs,
TF_OutputList* o, TF_Status* s) = 0;
virtual TF_AbstractOp* CreateOperation() = 0;
virtual void RegisterFunction(TF_AbstractFunction* func, TF_Status* s) = 0;
virtual ~TF_ExecutionContext() {}
struct TF_AbstractOp {
string op_type;
string op_name;
private:
const ExecutionContextKind k;
};
TF_ExecutionContext* TF_NewExecutionContext() {
return new TF_ExecutionContext();
}
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete c; }
TF_AbstractOp* TF_NewAbstractOp() {
TF_AbstractOp* op = new TF_AbstractOp;
return op;
template <typename T, typename S>
T* dynamic_cast_helper(S source) {
if (source->getKind() != T::kKind) {
return nullptr;
}
return tensorflow::down_cast<T*>(source);
}
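// Illustrative use of the helper above (a sketch, not part of this diff):
// each subclass exposes a static `kKind` constant, so callers can recover
// the concrete type without dynamic_cast, e.g.
//   if (TF_GraphOp* graph_op = dynamic_cast_helper<TF_GraphOp>(op)) {
//     // `op` was created for a graph context; graph-only state is safe here.
//   }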
void TF_DeleteAbstractOp(TF_AbstractOp* op) { delete op; }
TF_AbstractTensor* TF_NewAbstractTensor() {
TF_AbstractTensor* t = new TF_AbstractTensor;
return t;
}
void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { delete t; }
struct TF_GraphContext {
TF_Graph* graph;
// TODO(srbs): Handle captures.
};
TF_GraphContext* TF_NewGraphContext(TF_Graph* g) {
auto ctx = new TF_GraphContext;
ctx->graph = g;
return ctx;
}
void TF_DeleteGraphContext(TF_GraphContext* ctx) { delete ctx; }
class TF_GraphContext;
class TF_EagerContext;
struct TF_GraphTensor {
TF_Output output;
TF_GraphContext* ctx;
};
TF_GraphTensor* TF_NewGraphTensor(TF_GraphContext* ctx, TF_Output output,
TF_Status* s) {
TF_GraphTensor* t = new TF_GraphTensor;
t->output = output;
t->ctx = ctx;
return t;
}
TF_Output TF_GraphTensorToOutput(const TF_GraphTensor* const t, TF_Status* s) {
return t->output;
}
void TF_DeleteGraphTensor(TF_GraphTensor* t) { delete t; }
void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t,
TF_Status* s) {
at->t = t;
}
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
TF_Status* s) {
if (!absl::holds_alternative<TFE_TensorHandle*>(at->t)) {
string msg = absl::StrCat("Not an eager tensor handle.",
reinterpret_cast<uintptr_t>(at));
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return nullptr;
struct TF_AbstractTensor {
absl::variant<TFE_TensorHandle*, TF_GraphTensor*> t;
~TF_AbstractTensor() {
if (absl::holds_alternative<TFE_TensorHandle*>(t)) {
TFE_DeleteTensorHandle(absl::get<TFE_TensorHandle*>(t));
} else if (absl::holds_alternative<TF_GraphTensor*>(t)) {
delete absl::get<TF_GraphTensor*>(t);
}
return absl::get<TFE_TensorHandle*>(at->t);
}
void TF_AbstractTensorSetGraphTensor(TF_AbstractTensor* at, TF_GraphTensor* t,
TF_Status* s) {
at->t = t;
}
TF_GraphTensor* TF_AbstractTensorGetGraphTensor(TF_AbstractTensor* at,
TF_Status* s) {
if (!absl::holds_alternative<TF_GraphTensor*>(at->t)) {
string msg = absl::StrCat("Not a graph tensor handle.");
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return nullptr;
}
return absl::get<TF_GraphTensor*>(at->t);
};
struct TF_AbstractOp {
// Needed to implement our own version of RTTI since dynamic_cast is not
// supported in mobile builds.
enum AbstractOpKind { GraphOp, EagerOp };
explicit TF_AbstractOp(AbstractOpKind kind) : k(kind) {}
AbstractOpKind getKind() const { return k; }
virtual void SetOpType(const char* const op_type, TF_Status* s) = 0;
virtual void SetOpName(const char* const op_name, TF_Status* s) = 0;
virtual void SetAttrType(const char* const attr_name, TF_DataType value,
TF_Status* s) = 0;
virtual ~TF_AbstractOp() {}
private:
const AbstractOpKind k;
};
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
return c->CreateOperation();
}
void TF_DeleteAbstractOp(TF_AbstractOp* op) { delete op; }
class TF_GraphOp : public TF_AbstractOp {
public:
explicit TF_GraphOp(TF_Graph* g) : TF_AbstractOp(kKind), g_(g) {}
void SetOpType(const char* const op_type, TF_Status* s) override {
if (op_) {
TF_SetStatus(
s, TF_FAILED_PRECONDITION,
absl::StrCat("SetOpType called on already built op.").c_str());
return;
}
if (op_name_ != nullptr) {
op_.reset(TF_NewOperation(g_, op_type, op_name_));
op_name_ = nullptr;
} else {
op_type_ = op_type;
}
}
void SetOpName(const char* const op_name, TF_Status* s) override {
if (op_) {
TF_SetStatus(
s, TF_FAILED_PRECONDITION,
absl::StrCat("SetOpName called on already built op.").c_str());
return;
}
if (op_type_ != nullptr) {
op_.reset(TF_NewOperation(g_, op_type_, op_name));
op_type_ = nullptr;
} else {
op_name_ = op_name;
}
}
void SetAttrType(const char* const attr_name, TF_DataType value,
TF_Status* s) override {
if (!op_) {
TF_SetStatus(
s, TF_FAILED_PRECONDITION,
"op_type and op_name must be specified before specifying attrs.");
return;
}
TF_SetAttrType(op_.get(), attr_name, value);
}
~TF_GraphOp() override {}
static constexpr AbstractOpKind kKind = GraphOp;
private:
friend class TF_GraphContext; // For access to op_.
TF_Graph* g_;
std::unique_ptr<TF_OperationDescription> op_;
// Hold `op_type` and `op_name` till both are available since we need both
// to build a graph operation.
const char* op_type_ = nullptr;
const char* op_name_ = nullptr;
};
class TF_EagerOp : public TF_AbstractOp {
public:
explicit TF_EagerOp(TFE_Context* ctx) : TF_AbstractOp(kKind), ctx_(ctx) {}
void SetOpType(const char* const op_type, TF_Status* s) override {
op_ = TFE_NewOp(ctx_, op_type, s);
}
void SetOpName(const char* const op_name, TF_Status* s) override {
// Name is ignored in eager mode.
}
void SetAttrType(const char* const attr_name, TF_DataType value,
TF_Status* s) override {
if (op_ == nullptr) {
TF_SetStatus(s, TF_FAILED_PRECONDITION,
"op_type must be specified before specifying attrs.");
return;
}
TFE_OpSetAttrType(op_, attr_name, value);
}
~TF_EagerOp() override { TFE_DeleteOp(op_); }
static constexpr AbstractOpKind kKind = EagerOp;
private:
friend class TF_EagerContext; // For access to op_.
TFE_Op* op_ = nullptr;
TFE_Context* ctx_;
};
bool IsEagerTensor(const TF_AbstractTensor* const t) {
return absl::holds_alternative<TFE_TensorHandle*>(t->t);
}
@@ -138,22 +201,35 @@ struct TF_OutputList {
int expected_num_outputs = -1;
};
TF_OutputList* TF_NewOutputList() { return new TF_OutputList; }
void TF_DeleteOutputList(TF_OutputList* o) { delete o; }
void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs,
TF_Status* s) {
o->expected_num_outputs = num_outputs;
}
int TF_OutputListNumOutputs(TF_OutputList* o) { return o->outputs.size(); }
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) {
return o->outputs[i];
}
struct TF_AbstractFunction {
TF_Function* func = nullptr;
void ExecuteOperationEager(TF_AbstractOp* op, int num_inputs,
~TF_AbstractFunction() { TF_DeleteFunction(func); }
};
class TF_EagerContext : public TF_ExecutionContext {
public:
TF_EagerContext() : TF_ExecutionContext(kKind) {}
void Build(TFE_ContextOptions* options, TF_Status* status) {
eager_ctx_ = TFE_NewContext(options, status);
}
TF_AbstractOp* CreateOperation() override {
// TODO(srbs): Should the lifetime of this op be tied to the context.
return new TF_EagerOp(eager_ctx_);
}
void ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_ExecutionContext* ctx, TF_Status* s) {
auto* tfe_op =
TFE_NewOp(absl::get<TFE_Context*>(ctx->ctx), op->op_type.c_str(), s);
TF_Status* s) override {
auto* eager_op = dynamic_cast_helper<TF_EagerOp>(op);
if (eager_op == nullptr) {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Unable to cast TF_AbstractOp to TF_EagerOp.");
return;
}
auto* tfe_op = eager_op->op_;
if (TF_GetCode(s) != TF_OK) return;
for (int i = 0; i < num_inputs; ++i) {
if (!IsEagerTensor(inputs[i])) {
@@ -174,30 +250,58 @@ void ExecuteOperationEager(TF_AbstractOp* op, int num_inputs,
int num_retvals = o->expected_num_outputs;
retvals.resize(num_retvals);
TFE_Execute(tfe_op, retvals.data(), &num_retvals, s);
TFE_DeleteOp(tfe_op);
if (TF_GetCode(s) != TF_OK) {
return;
}
o->outputs.clear();
o->outputs.reserve(num_retvals);
for (int i = 0; i < num_retvals; ++i) {
auto* t = TF_NewAbstractTensor();
auto* t = new TF_AbstractTensor();
t->t = retvals[i];
o->outputs.push_back(t);
}
}
}
void RegisterFunction(TF_AbstractFunction* func, TF_Status* s) override {
TFE_ContextAddFunction(eager_ctx_, func->func, s);
}
~TF_EagerContext() override { TFE_DeleteContext(eager_ctx_); }
static constexpr ExecutionContextKind kKind = EagerContext;
private:
friend TFE_Context* TF_ExecutionContextGetTFEContext(
TF_ExecutionContext* ctx);
TFE_Context* eager_ctx_;
};
void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { delete t; }
TF_GraphContext* GetGraphContext(TF_AbstractTensor const* t) {
return absl::get<TF_GraphTensor*>(t->t)->ctx;
}
void ExecuteOperationGraph(TF_AbstractOp* op, int num_inputs,
class TF_GraphContext : public TF_ExecutionContext {
public:
TF_GraphContext()
: TF_ExecutionContext(kKind), graph_(new TF_Graph(), TF_DeleteGraph) {}
TF_AbstractOp* CreateOperation() override {
// TODO(srbs): Should the lifetime of this op be tied to the context.
return new TF_GraphOp(graph_.get());
}
void ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_ExecutionContext* ctx, TF_Status* s) {
TF_GraphContext* graph_ctx = absl::get<TF_GraphContext*>(ctx->ctx);
TF_Graph* g = graph_ctx->graph;
auto* tf_opdesc =
TF_NewOperation(g, op->op_type.c_str(), op->op_name.c_str());
TF_Status* s) override {
auto* graph_op = dynamic_cast_helper<TF_GraphOp>(op);
if (graph_op == nullptr) {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Unable to cast TF_AbstractOp to TF_GraphOp.");
return;
}
auto* tf_opdesc = graph_op->op_.release();
for (int i = 0; i < num_inputs; ++i) {
auto* input = inputs[i];
if (IsEagerTensor(input)) {
@@ -205,7 +309,7 @@ void ExecuteOperationGraph(TF_AbstractOp* op, int num_inputs,
"Capturing eager tensors is not supported yet.");
return;
} else {
if (GetGraphContext(input) != graph_ctx) {
if (GetGraphContext(input) != this) {
TF_SetStatus(
s, TF_INVALID_ARGUMENT,
"Capturing tensors from other graphs is not supported yet.");
@@ -215,47 +319,182 @@ void ExecuteOperationGraph(TF_AbstractOp* op, int num_inputs,
}
}
auto* operation = TF_FinishOperation(tf_opdesc, s);
// TF_FinishOperation deletes `tf_opdesc` so clear its reference.
graph_op->op_ = nullptr;
if (TF_GetCode(s) != TF_OK) return;
int num_outputs = TF_OperationNumOutputs(operation);
o->outputs.clear();
o->outputs.reserve(num_outputs);
for (int i = 0; i < num_outputs; ++i) {
auto* t = TF_NewAbstractTensor();
TF_GraphTensor* output_t = TF_NewGraphTensor(graph_ctx, {operation, i}, s);
if (TF_GetCode(s) != TF_OK) {
return;
}
t->t = output_t;
auto* t = new TF_AbstractTensor;
TF_GraphTensor* graph_t = new TF_GraphTensor;
graph_t->ctx = this;
graph_t->output = {operation, i};
t->t = graph_t;
o->outputs.push_back(t);
}
}
TF_Function* ToFunction(const char* fn_name, int num_inputs,
const TF_AbstractTensor* inputs, int num_outputs,
const TF_AbstractTensor* outputs,
TF_Status* status) const {
std::vector<TF_Output> graph_inputs;
graph_inputs.resize(num_inputs);
std::vector<TF_Output> graph_outputs;
graph_outputs.resize(num_outputs);
for (int i = 0; i < num_inputs; i++) {
graph_inputs[i] = absl::get<TF_GraphTensor*>(inputs[i].t)->output;
}
for (int i = 0; i < num_outputs; i++) {
graph_outputs[i] = absl::get<TF_GraphTensor*>(outputs[i].t)->output;
}
return TF_GraphToFunction(graph_.get(), fn_name, 0, -1, nullptr,
graph_inputs.size(), graph_inputs.data(),
graph_outputs.size(), graph_outputs.data(),
nullptr, nullptr, fn_name, status);
}
void RegisterFunction(TF_AbstractFunction* func, TF_Status* s) override {
TF_SetStatus(s, TF_UNIMPLEMENTED,
"Registering graph functions has not been implemented yet.");
}
~TF_GraphContext() override {}
static constexpr ExecutionContextKind kKind = GraphContext;
private:
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
};
struct TF_GraphContextOptions {};
struct TF_EagerContextOptions {
explicit TF_EagerContextOptions(TFE_ContextOptions* options)
: options(options) {}
TFE_ContextOptions* options; // Not owned.
};
struct TF_ExecutionContextOptions {
absl::variant<TF_GraphContextOptions*, TF_EagerContextOptions*> options;
~TF_ExecutionContextOptions() {
if (absl::holds_alternative<TF_GraphContextOptions*>(options)) {
delete absl::get<TF_GraphContextOptions*>(options);
} else if (absl::holds_alternative<TF_EagerContextOptions*>(options)) {
delete absl::get<TF_EagerContextOptions*>(options);
}
}
};
TF_ExecutionContextOptions* TF_NewGraphContextOptions() {
auto* options = new TF_ExecutionContextOptions();
options->options = new TF_GraphContextOptions();
return options;
}
void TF_ExecutionContextSetEagerContext(TF_ExecutionContext* context,
TFE_Context* eager_context,
TF_Status* s) {
context->ctx = eager_context;
context->execution_callback = &ExecuteOperationEager;
void TF_DeleteExecutionContextOptions(TF_ExecutionContextOptions* options) {
delete options;
}
void TF_ExecutionContextSetGraphContext(TF_ExecutionContext* context,
TF_GraphContext* graph_context,
TF_ExecutionContextOptions* TF_NewEagerContextOptions(
TFE_ContextOptions* tfe_options) {
auto* options = new TF_ExecutionContextOptions();
options->options = new TF_EagerContextOptions(tfe_options);
return options;
}
TF_ExecutionContext* TF_NewExecutionContext(TF_ExecutionContextOptions* options,
TF_Status* s) {
context->ctx = graph_context;
context->execution_callback = &ExecuteOperationGraph;
if (absl::holds_alternative<TF_EagerContextOptions*>(options->options)) {
auto* ctx = new TF_EagerContext();
ctx->Build(absl::get<TF_EagerContextOptions*>(options->options)->options,
s);
return ctx;
} else {
return new TF_GraphContext();
}
}
TF_OutputList* TF_NewOutputList() { return new TF_OutputList; }
void TF_DeleteOutputList(TF_OutputList* o) { delete o; }
void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs,
TF_Status* s) {
o->expected_num_outputs = num_outputs;
}
int TF_OutputListNumOutputs(TF_OutputList* o) { return o->outputs.size(); }
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) {
return o->outputs[i];
}
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
TF_Status* s) {
op->op_type = op_type;
op->SetOpType(op_type, s);
}
void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
TF_Status* s) {
op->op_name = op_name;
op->SetOpName(op_name, s);
}
void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
TF_DataType value, TF_Status* s) {
op->SetAttrType(attr_name, value, s);
}
void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_ExecutionContext* ctx, TF_Status* s) {
ctx->execution_callback(op, num_inputs, inputs, o, ctx, s);
ctx->ExecuteOperation(op, num_inputs, inputs, o, s);
}
TF_AbstractFunction* TF_ExecutionContextToFunction(
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
const TF_AbstractTensor* inputs, int num_outputs,
const TF_AbstractTensor* outputs, TF_Status* status) {
auto* graph_ctx = dynamic_cast_helper<const TF_GraphContext>(fn_body);
if (graph_ctx == nullptr) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"fn_body is not a TF_GraphContext.");
return nullptr;
}
TF_AbstractFunction* func = new TF_AbstractFunction;
func->func = graph_ctx->ToFunction(fn_name, num_inputs, inputs, num_outputs,
outputs, status);
return func;
}
void TF_DeleteAbstractFunction(TF_AbstractFunction* func) { delete func; }
void TF_ExecutionContextRegisterFunction(TF_ExecutionContext* ctx,
TF_AbstractFunction* func,
TF_Status* s) {
ctx->RegisterFunction(func, s);
}
// Temporary APIs till we figure out how to create scalar valued Eager
// tensors and how to get value out of eager abstract tensors.
TF_AbstractTensor* TF_NewAbstractTensor() {
TF_AbstractTensor* t = new TF_AbstractTensor;
return t;
}
void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t,
TF_Status* s) {
at->t = t;
}
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
TF_Status* s) {
if (!absl::holds_alternative<TFE_TensorHandle*>(at->t)) {
string msg = absl::StrCat("Not an eager tensor handle.",
reinterpret_cast<uintptr_t>(at));
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return nullptr;
}
return absl::get<TFE_TensorHandle*>(at->t);
}
TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext* ctx) {
return dynamic_cast_helper<TF_EagerContext>(ctx)->eager_ctx_;
}

@@ -15,8 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_
#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/tf_status.h"
#ifdef __cplusplus
extern "C" {
@@ -41,32 +41,19 @@ typedef struct TF_AbstractTensor TF_AbstractTensor;
// could contain the op type and other attributes.
typedef struct TF_AbstractOp TF_AbstractOp;
TF_ExecutionContext* TF_NewExecutionContext();
// `TF_ExecutionContextOptions` define what type of `TF_ExecutionContext` is
// created. It can be used to pass context specific params.
typedef struct TF_ExecutionContextOptions TF_ExecutionContextOptions;
void TF_DeleteExecutionContextOptions(TF_ExecutionContextOptions*);
TF_ExecutionContext* TF_NewExecutionContext(TF_ExecutionContextOptions*,
TF_Status* s);
void TF_DeleteExecutionContext(TF_ExecutionContext*);
TF_AbstractOp* TF_NewAbstractOp();
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* ctx);
void TF_DeleteAbstractOp(TF_AbstractOp*);
TF_AbstractTensor* TF_NewAbstractTensor();
void TF_DeleteAbstractTensor(TF_AbstractTensor*);
// -----------------------------------------------------------------------------
// APIs for Eager and graph modes
// -----------------------------------------------------------------------------
// Keeps track of the current graph and other state e.g. captures etc.
typedef struct TF_GraphContext TF_GraphContext;
TF_GraphContext* TF_NewGraphContext(TF_Graph*);
void TF_DeleteGraphContext(TF_GraphContext*);
// `eager_context` must outlive `context`.
void TF_ExecutionContextSetEagerContext(TF_ExecutionContext* context,
TFE_Context* eager_context, TF_Status*);
// `graph_context` must outlive `context`.
void TF_ExecutionContextSetGraphContext(TF_ExecutionContext* context,
TF_GraphContext* graph_context,
TF_Status*);
// TODO(srbs): Add APIs for specifying attrs etc.
// `op_type` must outlive `op`.
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
@@ -74,25 +61,9 @@ void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
// `op_name` must outlive `op`.
void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
TF_Status* s);
// Wrapper for TF_Output but contains a pointer to TF_GraphContext as well.
typedef struct TF_GraphTensor TF_GraphTensor;
TF_GraphTensor* TF_NewGraphTensor(TF_GraphContext* c, TF_Output t,
TF_Status* s);
TF_Output TF_GraphTensorToOutput(const TF_GraphTensor* const t, TF_Status* s);
void TF_DeleteGraphTensor(TF_GraphTensor* t);
// `t` must outlive `at`.
void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t,
TF_Status* s);
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
TF_Status* s);
// `t` must outlive `at`.
void TF_AbstractTensorSetGraphTensor(TF_AbstractTensor* at, TF_GraphTensor* t,
TF_Status* s);
TF_GraphTensor* TF_AbstractTensorGetGraphTensor(TF_AbstractTensor* at,
TF_Status* s);
// `attr_name` must outlive `op`.
void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
TF_DataType value, TF_Status* s);
// TF_OutputList just lets us not specify the number of outputs of an operation
// beforehand. This forces a memory allocation in the runtime, which is bad, but
@@ -104,6 +75,17 @@ void TF_OutputListSetNumOutputs(TF_OutputList* o, int, TF_Status*);
int TF_OutputListNumOutputs(TF_OutputList* o);
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i);
// Stores a function representation that can be used for execution or for
// setting functional attributes of other composite ops e.g. control flow.
typedef struct TF_AbstractFunction TF_AbstractFunction;
TF_AbstractFunction* TF_ExecutionContextToFunction(
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
const TF_AbstractTensor* inputs, int num_outputs,
const TF_AbstractTensor* outputs, TF_Status* status);
void TF_DeleteAbstractFunction(TF_AbstractFunction*);
void TF_ExecutionContextRegisterFunction(TF_ExecutionContext*,
TF_AbstractFunction*, TF_Status*);
// TF_ExecuteOperation will, if in eager mode, execute, if in graph mode, maybe
// capture some inputs and then add a node in the graph, and after
// execution/node creation it'll go and record things that happened in any tape
@@ -112,6 +94,23 @@ void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_ExecutionContext* ctx, TF_Status* s);
// -----------------------------------------------------------------------------
// APIs specific to Eager and graph modes
// -----------------------------------------------------------------------------
TF_ExecutionContextOptions* TF_NewGraphContextOptions();
TF_ExecutionContextOptions* TF_NewEagerContextOptions(TFE_ContextOptions*);
// Temporary APIs till we figure out how to create scalar valued Eager
// tensors and how to get value out of eager abstract tensors.
TF_AbstractTensor* TF_NewAbstractTensor();
void TF_AbstractTensorSetEagerTensor(
TF_AbstractTensor* at, TFE_TensorHandle* t,
TF_Status* s); // `at` takes ownership of `t`.
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
TF_Status* s);
TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext*);
#ifdef __cplusplus
} /* end extern "C" */
#endif

@@ -33,26 +33,25 @@ namespace tensorflow {
namespace {
TEST(UnifedCAPI, TestBasicEager) {
TF_ExecutionContext* ctx = TF_NewExecutionContext();
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* eager_ctx = TFE_NewContext(opts, status.get());
TF_ExecutionContextOptions* options = TF_NewEagerContextOptions(opts);
TF_ExecutionContext* ctx = TF_NewExecutionContext(options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
// Enter the eager context.
TF_ExecutionContextSetEagerContext(ctx, eager_ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract input tensor.
TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx);
TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f);
TF_AbstractTensor* at = TF_NewAbstractTensor();
TF_AbstractTensorSetEagerTensor(at, t, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract operation.
auto* op = TF_NewAbstractOp();
auto* op = TF_NewAbstractOp(ctx);
TF_AbstractOpSetOpType(op, "Add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
@@ -69,7 +68,6 @@ TEST(UnifedCAPI, TestBasicEager) {
// Clean up operation and inputs.
TF_DeleteAbstractOp(op);
TF_DeleteAbstractTensor(at);
TFE_DeleteTensorHandle(t);
// Verify the results.
ASSERT_EQ(1, TF_OutputListNumOutputs(o));
@@ -83,100 +81,98 @@ TEST(UnifedCAPI, TestBasicEager) {
TF_DeleteTensor(result_tensor);
TF_DeleteAbstractTensor(result);
TFE_DeleteTensorHandle(result_t);
TF_DeleteOutputList(o);
TFE_DeleteContext(eager_ctx);
TF_DeleteExecutionContext(ctx);
TF_DeleteExecutionContextOptions(options);
}
TEST(UnifedCAPI, TestBasicGraph) {
TF_ExecutionContext* ctx = TF_NewExecutionContext();
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
// Enter a graph context.
TF_Graph* g = TF_NewGraph();
TF_GraphContext* graph_context = TF_NewGraphContext(g);
TF_ExecutionContextSetGraphContext(ctx, graph_context, status.get());
TF_ExecutionContextOptions* options = TF_NewGraphContextOptions();
TF_ExecutionContext* graph_ctx =
TF_NewExecutionContext(options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph.
auto* placeholder_op = TF_NewOperation(g, "Placeholder", "Placeholder");
TF_SetAttrType(placeholder_op, "dtype", TF_FLOAT);
auto* operation = TF_FinishOperation(placeholder_op, status.get());
auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_Output placeholder_t = {operation, 0};
TF_GraphTensor* graph_t =
TF_NewGraphTensor(graph_context, placeholder_t, status.get());
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractTensor* t = TF_NewAbstractTensor();
TF_AbstractTensorSetGraphTensor(t, graph_t, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract operation.
auto* op = TF_NewAbstractOp();
TF_AbstractOpSetOpType(op, "Add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(op, "my_add", status.get());
TF_AbstractOpSetAttrType(placeholder_op, "dtype", TF_FLOAT, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {t, t};
TF_OutputList* o = TF_NewOutputList();
TF_OutputList* placeholder_outputs = TF_NewOutputList();
// Execute.
TF_ExecuteOperation(op, 2, inputs, o, ctx, status.get());
TF_ExecuteOperation(placeholder_op, 0, nullptr, placeholder_outputs,
graph_ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(1, TF_OutputListNumOutputs(placeholder_outputs));
TF_AbstractTensor* placeholder_t = TF_OutputListGet(placeholder_outputs, 0);
// Delete placeholder op.
TF_DeleteAbstractOp(placeholder_op);
// Build an abstract operation.
auto* add_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(add_op, "Add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(add_op, "my_add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {placeholder_t, placeholder_t};
TF_OutputList* add_outputs = TF_NewOutputList();
// Execute.
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractTensor* output_t = TF_OutputListGet(add_outputs, 0);
// Clean up operation and inputs.
TF_DeleteAbstractOp(op);
TF_DeleteAbstractTensor(t);
TF_DeleteGraphTensor(graph_t);
TF_DeleteAbstractOp(add_op);
TF_AbstractTensor* result = TF_OutputListGet(o, 0);
TF_GraphTensor* result_graph_tensor =
TF_AbstractTensorGetGraphTensor(result, status.get());
TF_DeleteAbstractTensor(result);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_Output result_output =
TF_GraphTensorToOutput(result_graph_tensor, status.get());
TF_DeleteGraphTensor(result_graph_tensor);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
string fn_name = "double";
TF_Function* f = TF_GraphToFunction(
g, fn_name.c_str(), 0, -1, nullptr, 1, &placeholder_t, 1, &result_output,
nullptr, nullptr, fn_name.c_str(), status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractFunction* func = TF_ExecutionContextToFunction(
graph_ctx, fn_name.c_str(), 1, placeholder_t, 1, output_t, status.get());
TF_DeleteAbstractTensor(placeholder_t);
TF_DeleteAbstractTensor(output_t);
// Build an eager context to run the function.
// Build eager context.
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* eager_ctx = TFE_NewContext(opts, status.get());
TF_ExecutionContextOptions* eager_ctx_options =
TF_NewEagerContextOptions(opts);
TF_ExecutionContext* eager_execution_ctx =
TF_NewExecutionContext(eager_ctx_options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
// Build the abstract op to run the function.
TFE_ContextAddFunction(eager_ctx, f, status.get());
TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOp* fn_op = TF_NewAbstractOp();
// Build the abstract op to run the function.
TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx);
TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract input tensor.
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f);
TF_AbstractTensor* input_t = TF_NewAbstractTensor();
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(eager_execution_ctx);
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f);
TF_AbstractTensorSetEagerTensor(input_t, input_eager, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Enter the eager context.
TF_ExecutionContextSetEagerContext(ctx, eager_ctx, status.get());
TF_OutputListSetNumOutputs(add_outputs, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_OutputListSetNumOutputs(o, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_ExecuteOperation(fn_op, 1, &input_t, o, ctx, status.get());
TF_ExecuteOperation(fn_op, 1, &input_t, add_outputs, eager_execution_ctx,
status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(1, TF_OutputListNumOutputs(o));
TF_AbstractTensor* final_result = TF_OutputListGet(o, 0);
ASSERT_EQ(1, TF_OutputListNumOutputs(add_outputs));
TF_AbstractTensor* final_result = TF_OutputListGet(add_outputs, 0);
TFE_TensorHandle* final =
TF_AbstractTensorGetEagerTensor(final_result, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
@@ -185,19 +181,201 @@ TEST(UnifedCAPI, TestBasicGraph) {
float* f_value = static_cast<float*>(TF_TensorData(f_t));
ASSERT_EQ(*f_value, 4.0);
TF_DeleteOutputList(o);
TF_DeleteOutputList(add_outputs);
TF_DeleteOutputList(placeholder_outputs);
TF_DeleteAbstractOp(fn_op);
TF_DeleteAbstractTensor(input_t);
TFE_DeleteTensorHandle(input_eager);
TF_DeleteAbstractTensor(final_result);
TFE_DeleteTensorHandle(final);
TF_DeleteTensor(f_t);
TF_DeleteFunction(f);
TF_DeleteAbstractFunction(func);
TF_DeleteExecutionContext(graph_ctx);
TF_DeleteExecutionContext(eager_execution_ctx);
TF_DeleteExecutionContextOptions(eager_ctx_options);
TF_DeleteExecutionContextOptions(options);
}
TEST(UnifedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContextOptions* options = TF_NewEagerContextOptions(opts);
TF_ExecutionContext* ctx = TF_NewExecutionContext(options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
TF_AbstractFunction* func = TF_ExecutionContextToFunction(
ctx, nullptr, 0, nullptr, 0, nullptr, status.get());
ASSERT_EQ(nullptr, func);
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
TF_DeleteGraphContext(graph_context);
TF_DeleteGraph(g);
TFE_DeleteContext(eager_ctx);
TF_DeleteExecutionContext(ctx);
TF_DeleteExecutionContextOptions(options);
}
TEST(UnifedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContextOptions* options = TF_NewGraphContextOptions();
TF_ExecutionContext* graph_ctx =
TF_NewExecutionContext(options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph.
auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// This should fail.
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
ASSERT_EQ(TF_FAILED_PRECONDITION, TF_GetCode(status.get()));
TF_DeleteAbstractOp(placeholder_op);
TF_DeleteExecutionContext(graph_ctx);
TF_DeleteExecutionContextOptions(options);
}
TEST(UnifedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContextOptions* options = TF_NewGraphContextOptions();
TF_ExecutionContext* graph_ctx =
TF_NewExecutionContext(options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph.
auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// This should fail.
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
ASSERT_EQ(TF_FAILED_PRECONDITION, TF_GetCode(status.get()));
TF_DeleteAbstractOp(placeholder_op);
TF_DeleteExecutionContext(graph_ctx);
TF_DeleteExecutionContextOptions(options);
}
TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) {
// Build an Eager context.
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContextOptions* options = TF_NewEagerContextOptions(opts);
TF_ExecutionContext* ctx = TF_NewExecutionContext(options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an Eager operation.
auto* op = TF_NewAbstractOp(ctx);
TF_AbstractOpSetOpType(op, "Add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract input tensor.
TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx);
TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f);
TF_AbstractTensor* at = TF_NewAbstractTensor();
TF_AbstractTensorSetEagerTensor(at, t, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {at, at};
TF_OutputList* o = TF_NewOutputList();
TF_OutputListSetNumOutputs(o, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build a Graph context.
TF_ExecutionContextOptions* graph_options = TF_NewGraphContextOptions();
TF_ExecutionContext* graph_ctx =
TF_NewExecutionContext(graph_options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Execute eager op using graph context.
TF_ExecuteOperation(op, 2, inputs, o, graph_ctx, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
// Clean up operation and inputs.
TF_DeleteAbstractOp(op);
TF_DeleteAbstractTensor(at);
TF_DeleteOutputList(o);
TF_DeleteExecutionContext(ctx);
TF_DeleteExecutionContextOptions(options);
TF_DeleteExecutionContext(graph_ctx);
TF_DeleteExecutionContextOptions(graph_options);
}
TEST(UnifedCAPI, TestExecutingGraphOpInEagerModeRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContextOptions* options = TF_NewGraphContextOptions();
TF_ExecutionContext* graph_ctx =
TF_NewExecutionContext(options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph.
auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetAttrType(placeholder_op, "dtype", TF_FLOAT, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_OutputList* placeholder_outputs = TF_NewOutputList();
// Execute.
TF_ExecuteOperation(placeholder_op, 0, nullptr, placeholder_outputs,
graph_ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(1, TF_OutputListNumOutputs(placeholder_outputs));
TF_AbstractTensor* placeholder_t = TF_OutputListGet(placeholder_outputs, 0);
// Delete placeholder op.
TF_DeleteAbstractOp(placeholder_op);
// Build an abstract operation.
auto* add_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(add_op, "Add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(add_op, "my_add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {placeholder_t, placeholder_t};
TF_OutputList* add_outputs = TF_NewOutputList();
// Build eager context.
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContextOptions* eager_ctx_options =
TF_NewEagerContextOptions(opts);
TF_ExecutionContext* eager_execution_ctx =
TF_NewExecutionContext(eager_ctx_options, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
// Execute.
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, eager_execution_ctx,
status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
// Clean up operation and inputs.
TF_DeleteAbstractTensor(placeholder_t);
TF_DeleteAbstractOp(add_op);
TF_DeleteOutputList(add_outputs);
TF_DeleteOutputList(placeholder_outputs);
TF_DeleteExecutionContext(graph_ctx);
TF_DeleteExecutionContext(eager_execution_ctx);
TF_DeleteExecutionContextOptions(eager_ctx_options);
TF_DeleteExecutionContextOptions(options);
}
} // namespace

View File

@ -16,6 +16,7 @@ limitations under the License.
// A simple logging device to test custom device registration.
#include <memory>
#include "absl/strings/match.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
@ -25,7 +26,6 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/test.h"
TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
@ -176,7 +176,7 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
TEST(CUSTOM_DEVICE, AccessVariableOnCustomDevice) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
@ -226,16 +226,21 @@ TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
// Read the variable's value.
op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpAddInput(op.get(), var_handle, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
executed = false;
num_retvals = 1;
TFE_TensorHandle* var_value = nullptr;
TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
EXPECT_FALSE(TF_GetCode(status.get()) == TF_OK)
<< "Execution should fail because the variable is being used on the "
"wrong device.";
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
ASSERT_EQ(
tensorflow::string(name),
tensorflow::string(TFE_TensorHandleDeviceName(var_value, status.get())));
TFE_DeleteTensorHandle(var_value);
// Free the backing buffer for the variable.
op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
TFE_OpAddInput(op.get(), var_handle, status.get());
@ -246,6 +251,79 @@ TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
TEST(CUSTOM_DEVICE, InputBasedPlacement) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const char* custom0 = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
const char* custom1 = "/job:localhost/replica:0/task:0/device:CUSTOM:1";
bool arrived = false;
bool executed = false;
RegisterLoggingDevice(context.get(), custom0, &arrived, &executed,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
RegisterLoggingDevice(context.get(), custom1, &arrived, &executed,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcpu(
TestMatrixTensorHandle(context.get()), TFE_DeleteTensorHandle);
ASSERT_FALSE(arrived);
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcustom0(
TFE_TensorHandleCopyToDevice(hcpu.get(), context.get(), custom0,
status.get()),
TFE_DeleteTensorHandle);
ASSERT_TRUE(arrived);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
arrived = false;
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcustom1(
TFE_TensorHandleCopyToDevice(hcpu.get(), context.get(), custom1,
status.get()),
TFE_DeleteTensorHandle);
ASSERT_TRUE(arrived);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Base case: two CPU inputs execute fine.
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> matmul(
MatMulOp(context.get(), hcpu.get(), hcpu.get()), TFE_DeleteOp);
TFE_TensorHandle* retval;
int num_retvals = 1;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_DeleteTensorHandle(retval);
// Custom device: inputs on the same custom device work.
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcustom0.get()));
num_retvals = 1;
executed = false;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
TFE_DeleteTensorHandle(retval);
// Custom device: inputs on different custom devices fail.
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcustom1.get()));
num_retvals = 1;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
ASSERT_NE(TF_OK, TF_GetCode(status.get()));
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom1));
// Custom device: mix of custom/physical fails.
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get()));
num_retvals = 1;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
ASSERT_NE(TF_OK, TF_GetCode(status.get()));
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
ASSERT_TRUE(
absl::StrContains(TF_Message(status.get()), "[]")); // kVariantDeviceNull
}
TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);

View File

@ -42,7 +42,28 @@ class AbstractOperationInterface {
virtual Status Reset(const char* op, const char* raw_device_name) = 0;
virtual const string& Name() const = 0;
// Returns the operation's device name.
//
// The value returned may be different from the one set by SetDeviceName, but
// it will be compatible with it: the name will be updated by device placement
// logic to refer to the specific device chosen.
//
// Example: If one calls `op->SetDeviceName("/device:GPU")`, the value
// returned by DeviceName should be "/device:GPU:*" until a particular GPU is
// chosen for the operation by the device placement logic in the
// executor. After that, the value returned by DeviceName will be a full
// device name such as "/job:localhost/replica:0/task:0/device:GPU:1".
virtual const string& DeviceName() const = 0;
// Sets the operation device name.
//
// The given `name` must be parseable by DeviceNameUtils::ParseFullName, and
// the result will be used as a constraint for device placement. See the
// documentation for DeviceName for more details.
//
// The value will override the previous value - that is, no "merging" of
// existing and given constraints will be performed.
virtual Status SetDeviceName(const char* name) = 0;
virtual Status AddInput(AbstractTensorHandleInterface* input) = 0;

View File

@ -0,0 +1,66 @@
# Tensorflow C SavedModel API
## Overview
These are the new experimental C SavedModel APIs for loading and running
SavedModels in a TF2-idiomatic fashion. See
[RFC 207](https://github.com/tensorflow/community/pull/207) for additional
context.
The directory structure is as follows:
```none
saved_model/
public/
internal/
core/
```
## saved_model/public
`saved_model/public` is intended to house *only the public headers* of the
SavedModel C API.
These headers:
1. declare opaque C types (like `TF_SavedModel`),
2. declare the functions that operate on these types (like `TF_LoadSavedModel`).
Once they leave experimental, these APIs should be considered stable for use
by external clients.
These headers are in a separate directory to make it obvious to clients which
headers they should depend on, and which headers are implementation details.
Separating these public headers by directory also allows future programmatic
checks to ensure that TF public headers only `#include` other public TF headers.
## saved_model/internal
`saved_model/internal` is the "glue" between the C API and the internal C++
implementation.
Its role is to:
1. implement the C API functions declared in `saved_model/public`
2. define the C API types declared in `saved_model/public`
The files fulfilling 1. are named `*.cc` (e.g. `concrete_function.cc`), while
the files fulfilling 2. are `*type.h` (e.g. `concrete_function_type.h`).
The headers exposing the internal implementation of the opaque C types are only
visible to other implementors of the C API. This is similar to how other
TF C API implementations use `tf_status_internal.h` (to extract the underlying
`tensorflow::Status`). All other targets in this directory are private.
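For concreteness, here is a minimal sketch of that split for the `TF_SavedModel`
type introduced in this change (declarations abbreviated):

```c++
// public/saved_model_api.h: the opaque declaration clients compile against.
typedef struct TF_SavedModel TF_SavedModel;
TF_CAPI_EXPORT extern void TF_DeleteSavedModel(TF_SavedModel* model);

// internal/saved_model_api_type.h: the definition, visible only to the
// C API implementation.
struct TF_SavedModel {
  std::unique_ptr<tensorflow::SavedModelAPI> saved_model;
};
```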
## saved_model/core
`saved_model/core` contains pure C++ "Classes" underlying the C API types
in `saved_model/public/`. These are implementation
details subject to change, and have limited visibility to implementors only.
This is the bottom-most layer of the `C++ -> C -> C++` sandwich.

View File

@ -0,0 +1,46 @@
# Experimental SavedModel C APIs for TensorFlow. See RFC
# https://github.com/tensorflow/community/pull/207
# Targets in this directory are pure C++ "Classes" underlying the C API types
# under tf/c/experimental/saved_model/public/. They are subject to change and
# have visibility limited to Tensorflow's implementation only.
package(
default_visibility = [
"//tensorflow/c/experimental/saved_model/internal:__pkg__",
],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "concrete_function",
srcs = [
"concrete_function.cc",
],
hdrs = [
"concrete_function.h",
],
deps = [
":function_metadata",
"//tensorflow/c/eager:operation_interface",
"//tensorflow/c/eager:tensor_handle_interface",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "function_metadata",
hdrs = [
"function_metadata.h",
],
)
cc_library(
name = "saved_model_api",
hdrs = [
"saved_model_api.h",
],
deps = [
":concrete_function",
"//tensorflow/core:lib",
],
)

View File

@ -0,0 +1,32 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
namespace tensorflow {
const std::vector<tensorflow::AbstractTensorHandleInterface*>&
ConcreteFunction::Captures() const {
return captures_;
}
const FunctionMetadata& ConcreteFunction::GetFunctionMetadata() const {
return metadata_;
}
} // namespace tensorflow

View File

@ -0,0 +1,55 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_
#include <vector>
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/core/framework/function.pb.h"
namespace tensorflow {
// Note that ConcreteFunctions' lifetimes are effectively bound
// to the SavedModel they are loaded from, since they retain pointers
// to the TensorHandles owned by the SavedModel, and the FunctionDef
// of the SavedModel.
// Note(bmzhao): This class is only TEMPORARILY virtual, as a way to unblock
// TFRT integration with TF Serving. Do not add more virtual implementations of
// this class. Eventually we want to remove this virtual base class indirection
// and have only a single implementation.
class ConcreteFunction {
public:
virtual ~ConcreteFunction() = 0;
// This method returns the "Call" Op used to execute the function.
virtual AbstractOperationInterface* GetFunctionOp() = 0;
const std::vector<tensorflow::AbstractTensorHandleInterface*>& Captures()
const;
const FunctionMetadata& GetFunctionMetadata() const;
private:
FunctionMetadata metadata_;
std::vector<tensorflow::AbstractTensorHandleInterface*> captures_;
FunctionDef* function_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_

View File

@ -1,4 +1,4 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -13,13 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_TOOLS_BENCHMARK_LOGGING_H_
#define TENSORFLOW_LITE_TOOLS_BENCHMARK_LOGGING_H_
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_FUNCTION_METADATA_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_FUNCTION_METADATA_H_
// TODO(b/149482807): completely remove this file from the code base.
#include "tensorflow/lite/tools/logging.h"
namespace tensorflow {
#define TFLITE_BENCHMARK_CHECK(condition) TFLITE_TOOLS_CHECK(condition)
#define TFLITE_BENCHMARK_CHECK_EQ(a, b) TFLITE_TOOLS_CHECK(a == b)
class FunctionMetadata {
// TODO(bmzhao): Fill in with fields as necessary
};
#endif // TENSORFLOW_LITE_TOOLS_BENCHMARK_LOGGING_H_
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_FUNCTION_METADATA_H_

View File

@ -0,0 +1,55 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_API_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_API_H_
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
// Note(bmzhao): This class is only TEMPORARILY virtual, as a way to unblock
// TFRT integration with TF Serving. Do not add more virtual implementations of
// this class. Eventually we want to remove this virtual base class indirection
// and have only a single implementation.
class SavedModelAPI {
public:
// Retrieve a function from the TF2 SavedModel, using the "path" to a function
// in a TF2 SavedModel.
// Note: `function` is a double pointer, so that implementations are
// able to return a pointer to an internal member.
virtual Status GetFunction(const std::string& function_path,
ConcreteFunction** function) = 0;
// Retrieve a function from a SavedModel, using the key of the
// SignatureDef map:
// https://github.com/tensorflow/tensorflow/blob/69b08900b1e991d84bce31f3b404f5ed768f339f/tensorflow/core/protobuf/meta_graph.proto#L89
virtual Status GetSignatureDefFunction(const std::string& signature_def_key,
ConcreteFunction** function) = 0;
virtual const std::vector<ConcreteFunction*>& ListFunctions() = 0;
virtual ~SavedModelAPI() = default;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_API_H_

View File

@ -0,0 +1,157 @@
# Experimental Implementation of SavedModel C APIs for TensorFlow. See RFC
# https://github.com/tensorflow/community/pull/207
# External clients should not worry about this directory; all contents are implementation details.
# Code in this directory is intended to form the glue between the C API and the internal C++
# implementation by
# 1. mapping C API calls onto corresponding methods of C++ objects
# 2. mapping opaque C types onto C++ classes
# Note(bmzhao): The *.cc files in this directory form the direct implementation of the
# C API functions exposed in tf/c/experimental/saved_model/public/.
# Note(bmzhao): All *type.h files in this directory are the internal definitions of
# the opaque C types. These headers should only be visible to internal tensorflow
# implementors.
package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "conversion_macros",
hdrs = [
"conversion_macros.h",
],
)
cc_library(
name = "concrete_function",
srcs = [
"concrete_function.cc",
],
hdrs = [
"//tensorflow/c/experimental/saved_model/public:concrete_function.h",
],
# TODO(bmzhao): Remove this as we refactor C API to granular targets,
# so that we can depend on c/eager/c_api_unified_experimental.h.
features = ["-layering_check"],
visibility = [
"//tensorflow/c/experimental/saved_model/public:__pkg__",
],
deps = [
":concrete_function_type",
":function_metadata",
":function_metadata_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_internal",
"//tensorflow/c/experimental/saved_model/core:concrete_function",
"//tensorflow/c/experimental/saved_model/core:function_metadata",
],
)
cc_library(
name = "concrete_function_list",
srcs = [
"concrete_function_list.cc",
],
hdrs = [
"//tensorflow/c/experimental/saved_model/public:concrete_function_list.h",
],
visibility = [
"//tensorflow/c/experimental/saved_model/public:__pkg__",
],
deps = [
":concrete_function",
":concrete_function_list_type",
":concrete_function_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c/experimental/saved_model/core:concrete_function",
],
)
cc_library(
name = "concrete_function_list_type",
hdrs = [
"concrete_function_list_type.h",
],
deps = [
":conversion_macros",
"//tensorflow/c/experimental/saved_model/core:concrete_function",
],
)
cc_library(
name = "concrete_function_type",
hdrs = [
"concrete_function_type.h",
],
deps = [
":conversion_macros",
"//tensorflow/c/experimental/saved_model/core:concrete_function",
],
)
cc_library(
name = "function_metadata",
srcs = [
"function_metadata.cc",
],
hdrs = [
"//tensorflow/c/experimental/saved_model/public:function_metadata.h",
],
visibility = [
"//tensorflow/c/experimental/saved_model/public:__pkg__",
],
deps = [
":function_metadata_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c/experimental/saved_model/core:function_metadata",
],
)
cc_library(
name = "function_metadata_type",
hdrs = [
"function_metadata_type.h",
],
deps = [
":conversion_macros",
"//tensorflow/c/experimental/saved_model/core:function_metadata",
],
)
cc_library(
name = "saved_model_api",
srcs = [
"saved_model_api.cc",
],
hdrs = [
"//tensorflow/c/experimental/saved_model/public:saved_model_api.h",
],
visibility = [
"//tensorflow/c/experimental/saved_model/public:__pkg__",
],
deps = [
":concrete_function",
":concrete_function_list",
":concrete_function_list_type",
":concrete_function_type",
":saved_model_api_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_internal",
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
"//tensorflow/core:lib",
],
)
cc_library(
name = "saved_model_api_type",
hdrs = [
"saved_model_api_type.h",
],
deps = [
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
],
)

View File

@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
#include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h"
extern "C" {
TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(TF_ConcreteFunction* func) {
return tensorflow::wrap(&tensorflow::unwrap(func)->GetFunctionMetadata());
}
TF_OutputList* TF_ConcreteFunctionGetCaptures(TF_ConcreteFunction* func) {
// TODO(bmzhao): Refactor TF_OutputList struct definition into a separate
// internal header, and implement this function.
return nullptr;
}
TFE_Op* TF_ConcreteFunctionGetOperation(TF_ConcreteFunction* func) {
return new TFE_Op{tensorflow::unwrap(func)->GetFunctionOp()};
}
} // end extern "C"

View File

@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stddef.h>
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
extern "C" {
size_t TF_ConcreteFunctionListSize(TF_ConcreteFunctionList* list) {
return tensorflow::unwrap(list)->size();
}
TF_ConcreteFunction* TF_ConcreteFunctionListGet(TF_ConcreteFunctionList* list,
int i) {
return tensorflow::wrap((*tensorflow::unwrap(list))[i]);
}
} // end extern "C"

View File

@ -0,0 +1,36 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_
#include <vector>
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/internal/conversion_macros.h"
// Internal structures used by the SavedModel C API. These are likely to change
// and should not be depended on.
typedef struct TF_ConcreteFunctionList TF_ConcreteFunctionList;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(std::vector<tensorflow::ConcreteFunction*>,
TF_ConcreteFunctionList)
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_

View File

@ -0,0 +1,36 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/internal/conversion_macros.h"
// Internal structures used by the SavedModel C API. These are likely to change
// and should not be depended on.
// It doesn't make sense to wrap tensorflow::ConcreteFunction* in a separate
// struct, since the lifetime of the struct and the raw pointer it wraps would
// be different. Therefore TF_ConcreteFunction* = tensorflow::ConcreteFunction*.
typedef struct TF_ConcreteFunction TF_ConcreteFunction;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ConcreteFunction, TF_ConcreteFunction)
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_

View File

@ -0,0 +1,28 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONVERSION_MACROS_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONVERSION_MACROS_H_
#define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper) \
inline cpp_impl *unwrap(wrapper *w) { \
return reinterpret_cast<cpp_impl *>(w); \
} \
\
inline wrapper *wrap(const cpp_impl *i) { \
return reinterpret_cast<wrapper *>(const_cast<cpp_impl *>(i)); \
}
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONVERSION_MACROS_H_
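For reference, instantiating `DEFINE_CONVERSION_FUNCTIONS(tensorflow::ConcreteFunction, TF_ConcreteFunction)` (as `concrete_function_type.h` above does) expands to roughly:

```c++
inline tensorflow::ConcreteFunction *unwrap(TF_ConcreteFunction *w) {
  return reinterpret_cast<tensorflow::ConcreteFunction *>(w);
}

inline TF_ConcreteFunction *wrap(const tensorflow::ConcreteFunction *i) {
  return reinterpret_cast<TF_ConcreteFunction *>(
      const_cast<tensorflow::ConcreteFunction *>(i));
}
```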

View File

@ -0,0 +1,20 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h"
// TODO(bmzhao): Add getter functions here as necessary.

View File

@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/internal/conversion_macros.h"
typedef struct TF_FunctionMetadata TF_FunctionMetadata;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::FunctionMetadata, TF_FunctionMetadata)
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_

View File

@ -0,0 +1,67 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/platform/status.h"
extern "C" {
TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx,
const char* const* tags, int tags_len,
TF_Status* status) {
// TODO(bmzhao): Add a virtual "LoadSavedModel" method to
// AbstractContextInterface, and call it here.
return nullptr;
}
void TF_DeleteSavedModel(TF_SavedModel* model) { delete model; }
TF_ConcreteFunction* TF_GetSavedModelFunction(TF_SavedModel* model,
char* function_path,
TF_Status* status) {
tensorflow::ConcreteFunction* result = nullptr;
tensorflow::Status get_function_status =
model->saved_model->GetFunction(function_path, &result);
status->status.Update(get_function_status);
if (!get_function_status.ok()) {
return nullptr;
}
return tensorflow::wrap(result);
}
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
TF_SavedModel* model, char* signature_def_key, TF_Status* status) {
tensorflow::ConcreteFunction* result = nullptr;
tensorflow::Status get_function_status =
model->saved_model->GetSignatureDefFunction(signature_def_key, &result);
status->status.Update(get_function_status);
if (!get_function_status.ok()) {
return nullptr;
}
return tensorflow::wrap(result);
}
TF_ConcreteFunctionList* TF_ListSavedModelFunctions(TF_SavedModel* model) {
return tensorflow::wrap(&model->saved_model->ListFunctions());
}
} // end extern "C"

View File

@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_
#include <memory>
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
// Internal structures used by the SavedModel C API. These are likely to change
// and should not be depended on.
struct TF_SavedModel {
std::unique_ptr<tensorflow::SavedModelAPI> saved_model;
};
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_

View File

@ -0,0 +1,63 @@
# Experimental SavedModel C APIs for TensorFlow.
# See RFC https://github.com/tensorflow/community/pull/207
# All headers are on the public surface of Tensorflow's C API.
# Once moved out of experimental, these will be stable.
# The idea behind a separate public/ directory is to make apparent
# which headers are part of TF's public interface (and which headers
# are implementation details). This structure allows us to also perform future
# programmatic checks that all "public" headers only include other "public"
# headers.
package(
# This is intentionally public
default_visibility = [
"//visibility:public",
],
licenses = ["notice"], # Apache 2.0
)
# TODO(bmzhao): Remove these exports_files and rules, swap with cc_public_library instead.
# cc_public_library would allow us to separate the header dep graph from header+srcs dep graph.
exports_files(
[
"concrete_function.h",
"concrete_function_list.h",
"function_metadata.h",
"saved_model_api.h",
],
visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"],
)
# The purpose of this header is to provide insulation against
# future changes where we rename/move a public header, without
# forcing all clients to change their "#includes".
cc_library(
name = "c_saved_model_api",
hdrs = ["c_saved_model_api.h"],
deps = [
":concrete_function",
":concrete_function_list",
":function_metadata",
":saved_model_api",
],
)
alias(
name = "concrete_function",
actual = "//tensorflow/c/experimental/saved_model/internal:concrete_function",
)
alias(
name = "concrete_function_list",
actual = "//tensorflow/c/experimental/saved_model/internal:concrete_function_list",
)
alias(
name = "function_metadata",
actual = "//tensorflow/c/experimental/saved_model/internal:function_metadata",
)
alias(
name = "saved_model_api",
actual = "//tensorflow/c/experimental/saved_model/internal:saved_model_api",
)

View File

@ -0,0 +1,26 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_
// IWYU pragma: begin_exports
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
// IWYU pragma: end_exports
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_
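With this umbrella header, a hypothetical client needs only a single include; a minimal sketch:

```c++
// Hypothetical client file: one include exposes the entire public SavedModel
// C API surface, insulating the client from renames/moves of individual
// headers.
#include "tensorflow/c/experimental/saved_model/public/c_saved_model_api.h"

void Cleanup(TF_SavedModel* model) { TF_DeleteSavedModel(model); }
```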

View File

@ -0,0 +1,53 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
// An opaque type that corresponds to a Function loaded from a SavedModel.
// TODO(bmzhao): Work together w/srbs@ to make sure this composes w/the
// C++ Unified Eager/Graph API's AbstractFunction
typedef struct TF_ConcreteFunction TF_ConcreteFunction;
// Returns FunctionMetadata associated with `func`. Metadata's lifetime is
// bound to `func`, which is bound to the TF_SavedModel it was loaded from.
TF_CAPI_EXPORT extern TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(
TF_ConcreteFunction* func);
// Returns a list of TensorHandles implicitly captured by this function.
TF_CAPI_EXPORT extern TF_OutputList* TF_ConcreteFunctionGetCaptures(
TF_ConcreteFunction* func);
// Returns a TFE_Op suitable for executing this function.
TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionGetOperation(
TF_ConcreteFunction* func);
// Deletes `func`.
TF_CAPI_EXPORT extern void TF_DeleteConcreteFunction(TF_ConcreteFunction* func);
#ifdef __cplusplus
} // end extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_

View File

@ -0,0 +1,35 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
#include <stddef.h>
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
// An opaque type that acts like a list of TF_ConcreteFunction pointers.
typedef struct TF_ConcreteFunctionList TF_ConcreteFunctionList;
#ifdef __cplusplus
extern "C" {
#endif  // __cplusplus
// Returns the size of `list`.
TF_CAPI_EXPORT size_t
TF_ConcreteFunctionListSize(TF_ConcreteFunctionList* list);
// Returns the `i`th TF_ConcreteFunction in the list.
TF_CAPI_EXPORT TF_ConcreteFunction* TF_ConcreteFunctionListGet(
TF_ConcreteFunctionList* list, int i);
#ifdef __cplusplus
}  // end extern "C"
#endif  // __cplusplus
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
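A minimal iteration sketch, assuming `model` is a successfully loaded `TF_SavedModel` (see `saved_model_api.h` below):

```c++
// The list's lifetime is bound to `model`, so it is not deleted separately.
TF_ConcreteFunctionList* fns = TF_ListSavedModelFunctions(model);
for (size_t i = 0; i < TF_ConcreteFunctionListSize(fns); ++i) {
  TF_ConcreteFunction* fn =
      TF_ConcreteFunctionListGet(fns, static_cast<int>(i));
  // ... e.g. inspect TF_ConcreteFunctionGetMetadata(fn) ...
}
```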

View File

@ -0,0 +1,35 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_FUNCTION_METADATA_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_FUNCTION_METADATA_H_
#include "tensorflow/c/c_api_macros.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
// An opaque type used to store any metadata associated with a function.
typedef struct TF_FunctionMetadata TF_FunctionMetadata;
// TODO(bmzhao): Add getters for fields as we determine what metadata
// we want to expose.
#ifdef __cplusplus
} // end extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_FUNCTION_METADATA_H_

View File

@ -0,0 +1,96 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SAVED_MODEL_API_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SAVED_MODEL_API_H_
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
#include "tensorflow/c/tf_status.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
// An opaque type representing a Tensorflow "SavedModel"
// (https://www.tensorflow.org/guide/saved_model) that we always pass by pointer
// to achieve ABI stability.
typedef struct TF_SavedModel TF_SavedModel;
// Load a SavedModel from `dirname`.
//
// Params:
// dirname - A directory filepath that the SavedModel is at.
// ctx - A TFE_Context containing optional load/TF runtime options.
// `ctx` must outlive the returned TF_SavedModel pointer.
// tags - Pointer to char* array of SavedModel tags. Optional if the SavedModel
// contains a single Metagraph, as for those exported from
// `tf.saved_model.save`.
// tags_len - number of elements in the `tags` array.
// status - Set to OK on success and an appropriate error on failure.
// Returns:
// If status is not OK, returns nullptr. Otherwise, returns a newly created
// TF_SavedModel instance. It must be deleted by calling TF_DeleteSavedModel.
TF_CAPI_EXPORT extern TF_SavedModel* TF_LoadSavedModel(const char* dirname,
TFE_Context* ctx,
const char* const* tags,
int tags_len,
TF_Status* status);
// Deletes a TF_SavedModel, and frees any resources owned by it.
TF_CAPI_EXPORT extern void TF_DeleteSavedModel(TF_SavedModel* model);
// Retrieve a function from the TF2 SavedModel via function path.
//
// Params:
// model - The TF2 SavedModel to load a function from.
// function_path - A string containing the path from the root saved python
// object to a tf.function method.
// TODO(bmzhao): Add a detailed example of this with a
// python tf.module before moving this out of experimental.
// status - Set to OK on success and an appropriate error on failure.
// Returns:
// If status is not OK, returns nullptr. Otherwise, returns a
// TF_ConcreteFunction instance. The lifetime of this instance is
// "conceptually" bound to `model`. Once `model` is deleted, all
// `TF_ConcreteFunctions` retrieved from it are invalid, and have been deleted.
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelFunction(
TF_SavedModel* model, char* function_path, TF_Status* status);
// Retrieve a function from the TF SavedModel via a SignatureDef key.
//
// Params:
// model - The SavedModel to load a function from.
// signature_def_key - The string key of the SignatureDef map of a SavedModel:
// https://github.com/tensorflow/tensorflow/blob/69b08900b1e991d84bce31f3b404f5ed768f339f/tensorflow/core/protobuf/meta_graph.proto#L89
// status - Set to OK on success and an appropriate error on failure.
// Returns:
// If status is not OK, returns nullptr. Otherwise, returns a
// TF_ConcreteFunction instance. Once `model` is deleted, all
// `TF_ConcreteFunctions` retrieved from it are invalid, and have been deleted.
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
TF_SavedModel* model, char* signature_def_key, TF_Status* status);
// Returns a list of all ConcreteFunctions stored in this SavedModel.
// The lifetime of the returned list is bound to `model`.
TF_CAPI_EXPORT extern TF_ConcreteFunctionList* TF_ListSavedModelFunctions(
TF_SavedModel* model);
#ifdef __cplusplus
} // end extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SAVED_MODEL_API_H_
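Putting the pieces together, a sketch of the intended call sequence (`TF_LoadSavedModel` is still a stub in this change, and the path and function name below are hypothetical):

```c++
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
TFE_DeleteContextOptions(opts);

// Tags may be omitted for single-MetaGraph SavedModels, per the docs above.
TF_SavedModel* model = TF_LoadSavedModel("/path/to/saved_model", ctx,
                                         /*tags=*/nullptr, /*tags_len=*/0,
                                         status);
if (TF_GetCode(status) == TF_OK) {
  char fn_path[] = "my_module.my_fn";  // hypothetical function path
  TF_ConcreteFunction* fn = TF_GetSavedModelFunction(model, fn_path, status);
  if (TF_GetCode(status) == TF_OK) {
    TFE_Op* op = TF_ConcreteFunctionGetOperation(fn);
    // ... add inputs with TFE_OpAddInput and run via TFE_Execute ...
  }
  TF_DeleteSavedModel(model);
}
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
```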

View File

@ -119,6 +119,9 @@ inline Tensor& TensorFromInterface(AbstractTensorInterface* tensor) {
return down_cast<TensorInterface*>(tensor)->Tensor();
}
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
} // namespace tensorflow
#endif // TENSORFLOW_C_TF_TENSOR_INTERNAL_H_

View File

@ -156,6 +156,7 @@ cc_library(
":array_grad",
":data_flow_grad",
":image_grad",
":manip_grad",
":math_grad",
":nn_grad",
],
@ -494,6 +495,32 @@ tf_cc_test(
],
)
cc_library(
name = "manip_grad",
srcs = ["gradients/manip_grad.cc"],
deps = [
":cc_ops",
":grad_op_registry",
":gradients",
],
alwayslink = 1,
)
tf_cc_test(
name = "gradients_manip_grad_test",
srcs = ["gradients/manip_grad_test.cc"],
deps = [
":array_ops",
":cc_ops",
":gradient_checker",
":manip_grad",
":testutil",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
# Generates separate libraries for array_ops and math_ops to reduce the dependency count of targets that depend on only these
tf_gen_op_wrappers_cc(
name = "math_ops",

View File

@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/manip_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
namespace tensorflow {
namespace ops {
namespace {
Status RollGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
auto shift = op.input(1);
auto axis = op.input(2);
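// The forward op moved x[i] to y[i + shift]; rolling the incoming gradient
// by -shift along the same axes routes each gradient value back to the
// input element that produced it.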
auto grad_op = Roll(scope, grad_inputs[0], Neg(scope, shift), axis);
grad_outputs->push_back(grad_op);
grad_outputs->push_back(NoGradient());
grad_outputs->push_back(NoGradient());
return scope.status();
}
REGISTER_GRADIENT_OP("Roll", RollGrad);
} // namespace
} // namespace ops
} // namespace tensorflow

View File

@ -0,0 +1,51 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/framework/gradient_checker.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/manip_ops.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace tensorflow {
namespace {
using ops::Placeholder;
using ops::Roll;
class ManipGradTest : public ::testing::Test {
protected:
ManipGradTest() : scope_(Scope::NewRootScope()) {}
void RunTest(const Output& x, const TensorShape& x_shape, const Output& y,
const TensorShape& y_shape) {
TF_ASSERT_OK(scope_.status());
float max_error;
TF_ASSERT_OK((ComputeGradientError<float, float, float>(
scope_, {x}, {x_shape}, {y}, {y_shape}, &max_error)));
EXPECT_LT(max_error, 1e-4);
}
Scope scope_;
};
TEST_F(ManipGradTest, RollGrad) {
TensorShape shape({5, 4, 3});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
auto y = Roll(scope_, x, {2, 1}, {0, 1});
RunTest(x, shape, y, shape);
}
} // namespace
} // namespace tensorflow

View File

@ -358,13 +358,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_,
resources_, constants_, /*lazy=*/false, &client, &variables, &kernel,
&executable);
if (!s.ok() && (platform_info_.device_type().type_string() == DEVICE_CPU ||
platform_info_.device_type().type_string() == DEVICE_GPU)) {
// Suggest auto jit if the failure was with GPU or CPU.
errors::AppendToMessage(&s,
xla::status_macros::kPossibleAutoJitAlternative);
}
OP_REQUIRES_OK(ctx, s);
}

View File

@ -1891,6 +1891,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"DynamicStitch",
"Einsum",
"EmptyTensorList",
"EnsureShape",
"ExtractImagePatches",
"Igamma",
"IgammaGradA",

View File

@ -145,16 +145,9 @@ Status XlaCompileOnDemandOp::Compile(
attrs.set_on_host(true);
TF_RETURN_IF_ERROR(ctx->allocate_temp(
device_tensor.dtype(), device_tensor.shape(), &host_tensor, attrs));
Notification n;
Status status;
ctx->op_device_context()->CopyDeviceTensorToCPU(
Status status = ctx->op_device_context()->CopyDeviceTensorToCPUSync(
&device_tensor, "ConstantArgument",
reinterpret_cast<Device*>(ctx->device()), &host_tensor,
[&](Status s) {
status = s;
n.Notify();
});
n.WaitForNotification();
reinterpret_cast<Device*>(ctx->device()), &host_tensor);
if (!status.ok()) {
LOG(ERROR) << "Copying tensor of shape "
<< device_tensor.shape().DebugString() << " from "

View File

@ -488,15 +488,8 @@ Status XlaDevice::MakeTensorFromProto(XlaDeviceContext* device_context,
mutex_lock lock(mu_);
Allocator* allocator = GetAllocatorLocked(alloc_attrs);
Tensor copy(allocator, parsed.dtype(), parsed.shape());
Notification n;
device_context->CopyCPUTensorToDevice(
&parsed, this, &copy,
[&n, &status](const Status& s) {
status = s;
n.Notify();
},
true /*sync_dst_compute*/);
n.WaitForNotification();
TF_RETURN_IF_ERROR(
device_context->CopyCPUTensorToDeviceSync(&parsed, this, &copy));
*tensor = copy;
}
VLOG(2) << "Allocated tensor at " << DMAHelper::base(tensor);

View File

@ -69,6 +69,7 @@ absl::optional<AllocatorStats> XlaDeviceAllocator::GetStats() {
tf_stats.bytes_reserved = se_stats->bytes_reserved;
tf_stats.peak_bytes_reserved = se_stats->peak_bytes_reserved;
tf_stats.bytes_reservable_limit = se_stats->bytes_reservable_limit;
tf_stats.largest_free_block_bytes = se_stats->largest_free_block_bytes;
return tf_stats;
}

View File

@ -479,6 +479,15 @@ Status XlaComputationLaunchContext::PopulateOutputs(
input_output_alias, output_num, ctx, i, shape, &output,
definition_event, stream, use_multiple_streams_));
} else {
auto program_shape =
kernel->computation->GetProgramShape().ValueOrDie();
if (program_shape.result().IsTuple() &&
program_shape.result().tuple_shapes(output_num).IsTuple()) {
return errors::Unimplemented(
"Support for TensorList or Stack crossing the XLA/TF boundary "
"is not implemented");
}
se::DeviceMemoryBase buffer = output.buffer({output_num});
Tensor output_tensor = GetOrCreateTensorForOutput(
output_num, ctx, missing_ctx_input_prefix, input_output_alias,

View File

@ -40,11 +40,11 @@ cc_library(
srcs = ["tf_mlir_opt_main.cc"],
deps = [
":init_mlir",
":passes",
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"@llvm-project//llvm:support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MlirOptLib",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
@ -55,6 +55,7 @@ cc_library(
cc_library(
name = "passes",
visibility = [
":__subpackages__",
"//tensorflow/python:__subpackages__",
],
deps = [
@ -76,24 +77,6 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:tensorflow_test_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
"//tensorflow/compiler/mlir/xla:buffer_assignment",
"//tensorflow/compiler/mlir/xla:hlo",
"//tensorflow/compiler/mlir/xla:hlo_legalize_to_lhlo",
"//tensorflow/compiler/mlir/xla:lhlo",
"//tensorflow/compiler/mlir/xla:lhlo_copy_removal",
"//tensorflow/compiler/mlir/xla:lhlo_fuse_linalg",
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine",
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_gpu",
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_parallel_loops",
"//tensorflow/compiler/mlir/xla:xla_dialect_registration",
"//tensorflow/compiler/mlir/xla:xla_legalize_control_flow",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
"//tensorflow/compiler/mlir/xla:xla_legalize_to_linalg",
"//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
"//tensorflow/compiler/mlir/xla:xla_lower",
"//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts",
"//tensorflow/compiler/mlir/xla:xla_test_passes",
],
)
@ -141,11 +124,14 @@ cc_library(
tf_cc_binary(
name = "tf-opt",
deps = [
":passes",
":tf_mlir_opt_main",
"//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass_registration",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
"//tensorflow/compiler/mlir/tfjs:tensorflow_js_dialect_registration",
"//tensorflow/compiler/mlir/xla:all_xla_passes_for_testing",
],
)

View File

@ -26,7 +26,7 @@ _ALWAYS_EXCLUDE = [
"**/* */**",
]
def _run_lit_test(name, data, size, tags, driver, features):
def _run_lit_test(name, data, size, tags, driver, features, exec_properties):
"""Runs lit on all tests it can find in `data` under tensorflow/compiler/mlir.
Note that, due to Bazel's hermetic builds, lit only sees the tests that
@ -64,6 +64,7 @@ def _run_lit_test(name, data, size, tags, driver, features):
],
size = size,
main = "lit.py",
exec_properties = exec_properties,
)
def glob_lit_tests(
@ -76,7 +77,8 @@ def glob_lit_tests(
default_tags = _default_tags,
tags_override = {},
driver = _default_driver,
features = []):
features = [],
exec_properties = {}):
"""Creates all plausible Lit tests (and their inputs) under this directory.
Args:
@ -92,6 +94,7 @@ def glob_lit_tests(
Note: use of a custom driver is not currently supported
and specifying a default driver will abort the tests.
features: [str], list of extra features to enable.
exec_properties: a dictionary of execution properties to pass on to the underlying test rule.
"""
# Ignore some patterns by default for tests and input data.
@ -115,6 +118,7 @@ def glob_lit_tests(
tags = default_tags + tags_override.pop(curr_test, []),
driver = driver,
features = features,
exec_properties = exec_properties,
)
def lit_test(
@ -123,7 +127,8 @@ def lit_test(
size = _default_size,
tags = _default_tags,
driver = _default_driver,
features = []):
features = [],
exec_properties = {}):
"""Runs test files under lit.
Args:
@ -136,4 +141,4 @@ def lit_test(
and specifying a default driver will abort the tests.
features: [str], list of extra features to enable.
"""
_run_lit_test(name + ".test", data + [name], size, tags, driver, features)
_run_lit_test(name + ".test", data + [name], size, tags, driver, features, exec_properties)

View File

@ -512,7 +512,7 @@ cc_library(
],
deps = [
":tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:status",
@ -562,19 +562,16 @@ cc_library(
)
cc_library(
name = "flatbuffer_translate_lib",
name = "flatbuffer_export",
srcs = [
"flatbuffer_export.cc",
"flatbuffer_import.cc",
"utils/convert_type.cc",
],
hdrs = [
"flatbuffer_export.h",
"flatbuffer_export_flags.h",
"flatbuffer_import.h",
"utils/convert_type.h",
],
deps = [
":convert_type",
":flatbuffer_tflite_operator_lib",
":stateful_ops_utils",
":tensorflow_lite",
@ -592,14 +589,12 @@ cc_library(
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:logging",
"//tensorflow/core/platform:status",
"//tensorflow/lite:framework",
"//tensorflow/lite:schema_fbs_version",
"//tensorflow/lite:string_util",
"//tensorflow/lite/delegates/flex:whitelisted_flex_ops_lib",
"//tensorflow/lite/kernels/internal:kernel_utils",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/tools/versioning",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
@ -614,6 +609,78 @@ cc_library(
],
)
cc_library(
name = "flatbuffer_import",
srcs = [
"flatbuffer_import.cc",
],
hdrs = [
"flatbuffer_import.h",
],
deps = [
":convert_type",
":flatbuffer_tflite_operator_lib",
":tensorflow_lite",
":tensorflow_lite_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:status",
"//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Translation",
],
)
cc_library(
name = "convert_type",
srcs = [
"utils/convert_type.cc",
],
hdrs = [
"utils/convert_type.h",
],
deps = [
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:errors",
"//tensorflow/lite/schema:schema_fbs",
"@llvm-project//mlir:IR",
],
)
cc_library(
name = "flatbuffer_translate_lib",
hdrs = [
"flatbuffer_export.h",
"flatbuffer_export_flags.h",
"flatbuffer_import.h",
"utils/convert_type.h",
],
deps = [
":flatbuffer_export",
":flatbuffer_import",
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:protos_all_cc",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings",
"@llvm-project//mlir:IR",
],
)
cc_library(
name = "flatbuffer_translate_registeration",
srcs = [

View File

@ -496,7 +496,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
auto &value = op.getOperand(i);
// Skip from the first variable-length operand for now. Otherwise the
// getOperand index used below doesn't match.
if (value.isVariadic()) break;
if (value.isVariableLength()) break;
if (!value.name.empty())
verify_ctx.addSubst(value.name, formatv("op->getOperand({0})", i));
}
@ -504,7 +504,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
auto &value = op.getResult(i);
// Skip from the first variable-length result for now. Otherwise the
// getResult index used below doesn't match.
if (value.isVariadic()) break;
if (value.isVariableLength()) break;
if (!value.name.empty())
verify_ctx.addSubst(value.name, formatv("op->getResult({0})", i));
}

View File

@ -9,7 +9,9 @@ cc_library(
name = "cost_estimators",
textual_hdrs = [
"estimator.h",
"cpu_estimators.h",
"gpu_estimators.h",
"hardware.h",
"arithmetic_count_util.h",
],
)

View File

@ -0,0 +1,45 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ARITHMETIC_COUNT_UTIL_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ARITHMETIC_COUNT_UTIL_H_
// For add/mul/div/sub and other broadcastable ops.
class ArithmeticCountUtilHelper {
public:
static bool GetArithmeticCountForBroadcastableOp(mlir::Operation* op,
int64_t* count) {
auto output = op->getResult(0);
auto output_type = output.getType().dyn_cast_or_null<RankedTensorType>();
if (!output_type || !output_type.hasStaticShape()) return false;
*count = output_type.getNumElements();
return true;
}
static bool GetInputTensorTotalSize(mlir::Operation* op, int64_t* count) {
int64_t total_count = 0;
for (auto input : op->getOperands()) {
auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
if (!input_type || !input_type.hasStaticShape()) {
return false;
}
total_count += input_type.getNumElements();
}
*count = total_count;
return true;
}
};
#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ARITHMETIC_COUNT_UTIL_H_
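
Both helpers above require fully static shapes; when they succeed, the counts they produce feed the per-hardware cost estimators in the following headers. A minimal standalone sketch of the same accounting, using plain shape vectors instead of MLIR types (all names below are illustrative and not part of the patch):

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Number of elements in a static shape: the product of its dimensions.
int64_t NumElements(const std::vector<int64_t>& shape) {
  return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

int main() {
  // tfl.add broadcasting [8, 16] with [1, 16] -> output [8, 16]: 128 ops.
  const int64_t arithmetic_count = NumElements({8, 16});
  // tfl.concatenation of two [8, 16] inputs copies 256 elements in total.
  const int64_t copy_count = NumElements({8, 16}) + NumElements({8, 16});
  return (arithmetic_count == 128 && copy_count == 256) ? 0 : 1;
}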

View File

@ -0,0 +1,103 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_CPU_ESTIMATORS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_CPU_ESTIMATORS_H_
// CPU
constexpr float kCPUArithmeticUnitCost = 1.0;
// This basically assumes pure load/store. This is just fake data.
constexpr float kCPUCopyUnitCost = 0.5;
constexpr float kCPUDefaultCost = 3.0f;
// Default values.
constexpr float kCPUDefaultFixedValuedCost = 10000.0;
// tfl.add
template <>
class TFLiteCostEstimator<AddOp, hardware::CPU> {
public:
static double GetCost(mlir::Operation* op) {
int64_t count;
if (ArithmeticCountUtilHelper::GetArithmeticCountForBroadcastableOp(op,
&count))
return kCPUArithmeticUnitCost * count;
return kCPUDefaultFixedValuedCost;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.mul
template <>
class TFLiteCostEstimator<MulOp, hardware::CPU> {
public:
static double GetCost(mlir::Operation* op) {
int64_t count;
if (ArithmeticCountUtilHelper::GetArithmeticCountForBroadcastableOp(op,
&count))
return kCPUArithmeticUnitCost * count;
return kCPUDefaultFixedValuedCost;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.concatenation
template <>
class TFLiteCostEstimator<ConcatenationOp, hardware::CPU> {
public:
static double GetCost(mlir::Operation* op) {
int64_t count;
if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count))
return kCPUCopyUnitCost * count;
return kCPUDefaultFixedValuedCost;
}
// TODO(renjieliu): We probably need to check for dynamic weights.
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.pack
template <>
class TFLiteCostEstimator<PackOp, hardware::CPU> {
public:
static double GetCost(mlir::Operation* op) {
int64_t count;
if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count))
return kCPUCopyUnitCost * count;
return kCPUDefaultFixedValuedCost;
}
// TODO(renjieliu): We probably need to check for dynamic weights.
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.reshape
template <>
class TFLiteCostEstimator<ReshapeOp, hardware::CPU> {
public:
static double GetCost(mlir::Operation* op) {
int64_t count;
if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count))
return kCPUCopyUnitCost * count;
return kCPUDefaultFixedValuedCost;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_CPU_ESTIMATORS_H_
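
The estimators in this header are full template specializations keyed on an (op, hardware) pair, so the cost model is resolved at compile time. A hedged sketch of that dispatch pattern with stand-in types — unlike the real estimators, which take an mlir::Operation*, the mock below takes a precomputed element count:

#include <cstdint>
#include <iostream>

// Stand-in tag and op types mirroring the header's structure.
namespace hardware { struct CPU {}; }
struct AddOp {};

// Primary template is intentionally left undefined; only specializations
// for supported (op, hardware) pairs exist.
template <typename Op, typename Hardware>
class TFLiteCostEstimator;

constexpr float kMockArithmeticUnitCost = 1.0f;
constexpr float kMockDefaultFixedValuedCost = 10000.0f;

template <>
class TFLiteCostEstimator<AddOp, hardware::CPU> {
 public:
  // One arithmetic unit per output element; fall back to the fixed cost
  // when the element count is unknown.
  static double GetCost(int64_t output_elements) {
    return output_elements > 0 ? kMockArithmeticUnitCost * output_elements
                               : kMockDefaultFixedValuedCost;
  }
};

int main() {
  std::cout << TFLiteCostEstimator<AddOp, hardware::CPU>::GetCost(8 * 16)
            << "\n";  // prints 128
}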

View File

@ -16,6 +16,16 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_
// GPU
constexpr float kGPUArithmeticUnitCost = 0.2;
// The copy can be a non-consecutive copy. This is just fake data.
constexpr float kGPUCopyUnitCost = 0.2;
constexpr float kGPUDefaultCost = 1.0f;
// Default values.
constexpr float kGPUDefaultFixedValuedCost = 10000.0;
// tfl.abs
template <>
class TFLiteCostEstimator<AbsOp, hardware::GPU> {
@ -34,9 +44,11 @@ template <>
class TFLiteCostEstimator<AddOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
int64_t count;
if (ArithmeticCountUtilHelper::GetArithmeticCountForBroadcastableOp(op,
&count))
return kGPUArithmeticUnitCost * count;
return kGPUDefaultFixedValuedCost;
}
static bool IsSupported(mlir::Operation* op) { return true; }
@ -60,9 +72,10 @@ template <>
class TFLiteCostEstimator<ConcatenationOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
int64_t count;
if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count))
return kGPUCopyUnitCost * count;
return kGPUDefaultFixedValuedCost;
}
// TODO(renjieliu): We probably need to check for dynamic weights.
@ -227,6 +240,33 @@ class TFLiteCostEstimator<MaximumOp, hardware::GPU> {
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.max_unpooling_2d
template <>
class TFLiteCostEstimator<MaxUnpooling2DOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.mean
template <>
class TFLiteCostEstimator<MeanOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
// TODO(renjieliu): check for constraints.
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.minimum
template <>
class TFLiteCostEstimator<MinimumOp, hardware::GPU> {
@ -245,9 +285,11 @@ template <>
class TFLiteCostEstimator<MulOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
int64_t count;
if (ArithmeticCountUtilHelper::GetArithmeticCountForBroadcastableOp(op,
&count))
return kGPUArithmeticUnitCost * count;
return kGPUDefaultFixedValuedCost;
}
static bool IsSupported(mlir::Operation* op) { return true; }
@ -321,6 +363,33 @@ class TFLiteCostEstimator<Relu6Op, hardware::GPU> {
// tfl.reshape
template <>
class TFLiteCostEstimator<ReshapeOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
int64_t count;
if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count))
return kGPUCopyUnitCost * count;
return kGPUDefaultFixedValuedCost;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.rsqrt
template <>
class TFLiteCostEstimator<RsqrtOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.sin
template <>
class TFLiteCostEstimator<SinOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
@ -357,6 +426,58 @@ class TFLiteCostEstimator<SoftmaxOp, hardware::GPU> {
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.space_to_depth
template <>
class TFLiteCostEstimator<SpaceToDepthOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.sqrt
template <>
class TFLiteCostEstimator<SqrtOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.square
template <>
class TFLiteCostEstimator<SquareOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.squared_difference
template <>
class TFLiteCostEstimator<SquaredDifferenceOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.strided_slice
template <>
class TFLiteCostEstimator<StridedSliceOp, hardware::GPU> {
@ -370,6 +491,19 @@ class TFLiteCostEstimator<StridedSliceOp, hardware::GPU> {
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.tanh
template <>
class TFLiteCostEstimator<TanhOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.transpose
template <>
class TFLiteCostEstimator<TransposeOp, hardware::GPU> {
@ -383,5 +517,18 @@ class TFLiteCostEstimator<TransposeOp, hardware::GPU> {
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.transpose_conv
template <>
class TFLiteCostEstimator<TransposeConvOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_
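
Taken together with the CPU table, these constants let a partitioner compare devices: under the fake unit costs above, a 128-element broadcastable op costs 128 * 1.0 on CPU but 128 * 0.2 on GPU. A small hedged sketch of such a comparison, with plain functions standing in for the template machinery:

#include <cstdint>
#include <iostream>
#include <string>

constexpr float kCPUArithmeticUnitCost = 1.0f;  // from cpu_estimators.h
constexpr float kGPUArithmeticUnitCost = 0.2f;  // from gpu_estimators.h

// Picks the cheaper device for a broadcastable op with a known output size.
std::string PickDevice(int64_t output_elements) {
  const double cpu_cost = kCPUArithmeticUnitCost * output_elements;
  const double gpu_cost = kGPUArithmeticUnitCost * output_elements;
  return gpu_cost < cpu_cost ? "GPU" : "CPU";
}

int main() { std::cout << PickDevice(128) << "\n"; }  // prints "GPU"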

View File

@ -59,13 +59,11 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Support/Functional.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Translation.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/compiler/xla/statusor.h"
@ -676,8 +674,8 @@ template <typename ContainerType>
mlir::NamedAttribute BuildTFEntryFunctionAttribute(
const tflite::SubGraphT& subgraph, Builder* builder, const std::string name,
const ContainerType indices) {
llvm::SmallVector<std::string, 8> tensor_names = mlir::functional::map(
[&](int i) { return subgraph.tensors.at(i)->name; }, indices);
auto tensor_names = llvm::map_range(
indices, [&](int i) { return subgraph.tensors.at(i)->name; });
return builder->getNamedAttr(
name, builder->getStringAttr(llvm::join(tensor_names, ",")));
}
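
The removed mlir::functional::map built an eager llvm::SmallVector, while llvm::map_range adapts the underlying range lazily; that is sufficient here because llvm::join only iterates the result once. A hedged standalone sketch of the idiom (it assumes the LLVM ADT headers are available):

#include <string>
#include <vector>

#include "llvm/ADT/STLExtras.h"     // llvm::map_range
#include "llvm/ADT/StringExtras.h"  // llvm::join

int main() {
  std::vector<int> indices = {0, 1, 2};
  std::vector<std::string> names = {"a", "b", "c"};
  // Lazily maps each index to its name; no intermediate vector is built.
  auto mapped = llvm::map_range(indices, [&](int i) { return names[i]; });
  const std::string joined = llvm::join(mapped, ",");  // "a,b,c"
  return joined == "a,b,c" ? 0 : 1;
}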

View File

@ -25,7 +25,7 @@ limitations under the License.
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"

View File

@ -28,7 +28,6 @@ limitations under the License.
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project
#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project
#include "mlir/Interfaces/SideEffects.h" // from @llvm-project
#include "mlir/Support/Functional.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/lite/schema/schema_generated.h"
@ -54,6 +53,8 @@ class TensorFlowLiteDialect : public Dialect {
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc"
// Include all specialized estimators below this line
#include "tensorflow/compiler/mlir/lite/experimental/estimators/arithmetic_count_util.h"
#include "tensorflow/compiler/mlir/lite/experimental/estimators/cpu_estimators.h"
#include "tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h"
} // end namespace TFL

View File

@ -450,7 +450,7 @@ retained with length 1.
}
def TFL_TransposeConvOp:
TFL_Op<"transpose_conv", [NoSideEffect]> {
TFL_Op<"transpose_conv", [NoSideEffect, TFL_GpuTargetOp]> {
let summary = "Transpose convolution operator";
let description = [{
@ -1658,7 +1658,7 @@ def TFL_MaxPoolingWithArgMax2DOp :
}
def TFL_MaxUnpooling2DOp :
Op<TFL_Dialect, "max_unpooling_2d", [NoSideEffect]> {
Op<TFL_Dialect, "max_unpooling_2d", [NoSideEffect, TFL_GpuTargetOp]> {
let summary = "Max Unpool 2D";
let description = [{
@ -1711,7 +1711,7 @@ def TFL_MaximumOp : TFL_Op<"maximum", [
let hasOptions = 0;
}
def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect]> {
def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect, TFL_GpuTargetOp]> {
let summary = "Mean operator";
let description = [{
@ -2116,7 +2116,9 @@ def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape,
let builders = [TFL_BroadcastableBinaryBuilder];
}
def TFL_PReluOp : TFL_Op<"prelu", [NoSideEffect, TFL_GpuTargetOp]> {
def TFL_PReluOp : TFL_Op<"prelu", [NoSideEffect,
TFL_GpuTargetOp,
SameOperandsAndResultsScale]> {
let summary = "Parameterized Relu operator";
let description = [{
@ -2165,6 +2167,17 @@ def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect,
let arguments = (ins TFL_TensorOf<[F32, QUI8, I8]>:$x);
let results = (outs TFL_TensorOf<[F32, QUI8, I8]>:$y);
// This builder doesn't work with quantized types, so it can only be used by
// non-quantization tablegen patterns. Currently, it is used by the
// elementwise-move reordering pattern in optimize_patterns.td.
let builders = [OpBuilder<
"Builder *, OperationState &state, Value input",
[{
state.addOperands({input});
state.addTypes(input.getType());
}]>
];
}
def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect,
@ -2181,6 +2194,17 @@ def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect,
let arguments = (ins TFL_TensorOf<[F32, QUI8, I8]>:$x);
let results = (outs TFL_TensorOf<[F32, QUI8, I8]>:$y);
// This builder doesn't work with quantized types, so it can only be used by
// non-quantization tablegen patterns. Currently, it is used by the
// elementwise-move reordering pattern in optimize_patterns.td.
let builders = [OpBuilder<
"Builder *, OperationState &state, Value input",
[{
state.addOperands({input});
state.addTypes(input.getType());
}]>
];
}
def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [NoSideEffect,
@ -2196,6 +2220,17 @@ def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [NoSideEffect,
let arguments = (ins TFL_TensorOf<[F32, QUI8, I8]>:$x);
let results = (outs TFL_TensorOf<[F32, QUI8, I8]>:$y);
// This builder doesn't work with quantized types, so it can only be used by
// non-quantization tablegen patterns. Currently, it is used by the
// elementwise-move reordering pattern in optimize_patterns.td.
let builders = [OpBuilder<
"Builder *, OperationState &state, Value input",
[{
state.addOperands({input});
state.addTypes(input.getType());
}]>
];
}
def TFL_ReshapeOp: TFL_Op<"reshape", [
@ -2247,7 +2282,10 @@ slice `i`, with the first `seq_lengths[i]` slices along dimension
let hasOptions = 1;
}
def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect,
SameOperandsAndResultType,
NoQuantizableResult,
TFL_GpuTargetOp]> {
let summary = "Reciprocal of square root operator";
let description = [{
@ -2395,7 +2433,10 @@ def TFL_SelectV2Op : TFL_Op<"select_v2", [NoSideEffect]> {
}
def TFL_SinOp: TFL_Op<"sin", [
NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
NoSideEffect,
SameOperandsAndResultType,
NoQuantizableResult,
TFL_GpuTargetOp]> {
let summary = "Sine operator";
let description = [{
@ -2437,7 +2478,10 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [
}
def TFL_SqrtOp: TFL_Op<"sqrt", [
NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
NoSideEffect,
SameOperandsAndResultType,
NoQuantizableResult,
TFL_GpuTargetOp]> {
let summary = "Square root operator";
let description = [{
@ -2452,7 +2496,10 @@ def TFL_SqrtOp: TFL_Op<"sqrt", [
}
def TFL_SquareOp: TFL_Op<"square", [
NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
NoSideEffect,
SameOperandsAndResultType,
NoQuantizableResult,
TFL_GpuTargetOp]> {
let summary = "Square operator";
let description = [{
@ -2496,7 +2543,10 @@ def TFL_SubOp : TFL_Op<"sub", [ResultsBroadcastableShape, NoSideEffect]> {
// TODO(jpienaar): Expand the kernel implementation to support all types besides
// I32 and F32.
def TFL_SquaredDifferenceOp : TFL_Op<"squared_difference", [
ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
ResultsBroadcastableShape,
NoSideEffect,
NoQuantizableResult,
TFL_GpuTargetOp]> {
let summary = "Squared difference operator";
let description = [{
@ -2523,7 +2573,8 @@ def TFL_TanhOp: TFL_Op<"tanh", [
// zero_point = central_value
// scale = 1. / (central_value - min_value)
FixedResultScale<Int8UniformQuantizedType<0, 78125, -7>>,
FixedResultScale<UInt8UniformQuantizedType<128, 78125, -7>>]> {
FixedResultScale<UInt8UniformQuantizedType<128, 78125, -7>>,
TFL_GpuTargetOp]> {
let summary = "Hyperbolic tangent operator";
let description = [{
@ -2533,6 +2584,17 @@ def TFL_TanhOp: TFL_Op<"tanh", [
let arguments = (ins TFL_TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$x);
let results = (outs TFL_TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$y);
// This builder doesn't work with quantized types, so it can only be used by
// non-quantization tablegen patterns. Currently, it is used by the
// elementwise-move reordering pattern in optimize_patterns.td.
let builders = [OpBuilder<
"Builder *, OperationState &state, Value input",
[{
state.addOperands({input});
state.addTypes(input.getType());
}]>
];
}
def TFL_TileOp: TFL_Op<"tile", [NoSideEffect, SameOperandsAndResultsScale,
@ -2718,7 +2780,8 @@ def TFL_SpaceToDepthOp: TFL_Op<"space_to_depth", [
NoSideEffect,
SameOperandsAndResultsScale,
PredOpTrait<"input and output must have same element type",
TCresVTEtIsSameAsOp<0, 0>>
TCresVTEtIsSameAsOp<0, 0>>,
TFL_GpuTargetOp
]> {
let summary = "SpaceToDepth operator";
@ -2981,14 +3044,13 @@ def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [
}];
let arguments = (ins
// TODO: add uint8 support when ready.
TFL_TensorOf<[F32, I32, I64]>:$input,
TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8]>:$input,
TFL_TensorOf<[I32, I64]>:$pad,
TFL_MirrorPaddingAttr:$mode
);
let results = (outs
TFL_TensorOf<[F32, I32, I64]>:$output
TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8]>:$output
);
let hasOptions = 1;

View File

@ -1,4 +1,4 @@
load("//tensorflow:tensorflow.bzl", "tf_native_cc_binary")
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_native_cc_binary")
load(
"//tensorflow/core/platform:build_config.bzl",
"tf_proto_library",
@ -115,11 +115,22 @@ tf_native_cc_binary(
],
)
cc_library(
name = "numerical_utils",
srcs = ["numerical_utils.cc"],
hdrs = ["numerical_utils.h"],
deps = [
"@com_google_absl//absl/types:optional",
],
)
cc_library(
name = "device_target",
srcs = ["device_target.cc"],
hdrs = ["device_target.h"],
deps = [
":numerical_utils",
"@com_google_absl//absl/types:optional",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
@ -142,3 +153,13 @@ cc_library(
"@llvm-project//mlir:Support",
],
)
tf_cc_test(
name = "numerical_utils_test",
srcs = ["numerical_utils_test.cc"],
deps = [
":numerical_utils",
"@com_google_absl//absl/types:optional",
"@com_google_googletest//:gtest_main",
],
)

View File

@ -15,12 +15,18 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/quantization/device_target.h"
#include <algorithm>
#include "absl/types/optional.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/numerical_utils.h"
namespace mlir {
namespace quant {
@ -39,7 +45,7 @@ DeviceTarget::DeviceTarget(MLIRContext* ctx) : ctx_(ctx) {
assert(qi8n_ == qi8n_);
}
Optional<KernelSpec> DeviceTarget::Get(QuantizeRegionOp op) const {
Optional<KernelSpec> DeviceTarget::GetKernelSpec(QuantizeRegionOp op) const {
auto kernel_specs_it = specs_.find(op.logical_kernel());
if (kernel_specs_it == specs_.end()) return llvm::None;
@ -50,9 +56,15 @@ Optional<KernelSpec> DeviceTarget::Get(QuantizeRegionOp op) const {
return kernel_specs_it->getValue().Find(signature);
}
ScaleDecomposeFn DeviceTarget::GetDecomposeFn(QuantizeRegionOp op) const {
auto kernel_specs_it = specs_.find(op.logical_kernel());
if (kernel_specs_it == specs_.end()) return ScaleDecomposeFn(nullptr);
return kernel_specs_it->second.GetDecomposeFn();
}
LogicalResult DeviceTarget::RegisterKernel(
llvm::StringRef kernel, const KernelSpecs::Signature& signature,
const ScaleFn& fn) {
const ScaleFn& fn, const ScaleDecomposeFn& dfn) {
return specs_[kernel].Add(signature, {ScaleConstraintType::CustomScale, fn});
}
@ -78,5 +90,49 @@ void DeviceTarget::AppendToSignature(ArrayAttr specs_attr,
}
}
LogicalResult DeviceTarget::DecomposeMultiplyAccumulateScale(
Operation* op, quant::QuantizedMultipliers* input_multipliers,
quant::QuantizedMultipliers* output_multipliers,
quant::QuantizedRanges* output_ranges) {
auto rop = llvm::dyn_cast<quant::QuantizeRegionOp>(op);
if (!rop) return failure();
llvm::SmallVector<Type, 4> input_specs, out_specs;
for (auto spec : rop.input_specs()) {
input_specs.push_back(spec.cast<TypeAttr>().getValue());
}
for (auto spec : rop.output_specs()) {
out_specs.push_back(spec.cast<TypeAttr>().getValue());
}
auto in_spec = input_specs[0].dyn_cast<quant::UniformQuantizedType>();
// TODO(fengliuai): handles the PerAxis QuantizedType.
auto w_spec = input_specs[1].dyn_cast<quant::UniformQuantizedType>();
auto b_spec = input_specs[2].dyn_cast<quant::UniformQuantizedType>();
auto o_spec = out_specs[0].dyn_cast<quant::UniformQuantizedType>();
if (!in_spec || !w_spec || !b_spec || !o_spec) return failure();
double scale_product = in_spec.getScale() * w_spec.getScale();
if (fabs(scale_product - b_spec.getScale()) >= 1e-6) return failure();
// input multipliers
input_multipliers->append(3, kUnitQuantizedMultiplier);
// output multipliers
double real_multiplier = o_spec.getScale() / scale_product;
output_multipliers->push_back(quant::QuantizeMultiplier(real_multiplier));
// output ranges
auto min = rop.getAttrOfType<FloatAttr>("min");
auto max = rop.getAttrOfType<FloatAttr>("max");
output_ranges->push_back(quant::CalculateQuantizedRange(
o_spec.getScale(), o_spec.getZeroPoint(),
(min ? absl::optional<double>(min.getValueAsDouble()) : absl::nullopt),
(max ? absl::optional<double>(max.getValueAsDouble()) : absl::nullopt),
o_spec.getStorageTypeMin(), o_spec.getStorageTypeMax()));
return success();
}
} // namespace quant
} // namespace mlir
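
DecomposeMultiplyAccumulateScale boils down to one piece of arithmetic: with input scale s_in, weight scale s_w, and output scale s_out, the effective output multiplier is s_out / (s_in * s_w), which QuantizeMultiplier then splits into a 31-bit fixed-point value and a shift. A hedged numeric walk-through with made-up scales, restating QuantizeMultiplier in simplified standalone form:

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <utility>

// Simplified restatement of QuantizeMultiplier from numerical_utils.cc.
std::pair<int32_t, int32_t> QuantizeMultiplier(double m) {
  if (m < 1e-6) return {0, 0};
  int shift;
  const double q = std::frexp(m, &shift);
  auto q_fixed = static_cast<int64_t>(std::round(q * (1ll << 31)));
  if (q_fixed == (1ll << 31)) { q_fixed /= 2; ++shift; }
  if (shift < -31) { shift = 0; q_fixed = 0; }
  return {static_cast<int32_t>(q_fixed), shift};
}

int main() {
  const double input_scale = 0.5, weight_scale = 0.25, output_scale = 0.375;
  const double real_multiplier =
      output_scale / (input_scale * weight_scale);  // 3.0
  const auto [fixed, shift] = QuantizeMultiplier(real_multiplier);
  // 3.0 == (fixed / 2^31) * 2^shift, i.e. fixed == 1610612736, shift == 2.
  std::printf("%d %d\n", static_cast<int>(fixed), static_cast<int>(shift));
}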

View File

@ -17,13 +17,13 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_DEVICE_TARGET_H_
#include <functional>
#include <ostream>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
@ -33,6 +33,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/numerical_utils.h"
namespace mlir {
namespace quant {
@ -40,9 +41,17 @@ namespace quant {
class QuantizeContext;
using AdjacentOperations = llvm::SmallVectorImpl<Operation*>;
using QuantizedMultipliers = llvm::SmallVector<QuantizedMultiplier, 4>;
using QuantizedRanges = llvm::SmallVector<QuantizedRange, 4>;
using ScaleFn = std::function<LogicalResult(QuantizeContext*, Operation*,
AdjacentOperations*, bool*)>;
using ScaleDecomposeFn =
std::function<LogicalResult(Operation*, QuantizedMultipliers*,
QuantizedMultipliers*, QuantizedRanges*)>;
static const QuantizedMultiplier kUnitQuantizedMultiplier{1, 0};
enum class ScaleConstraintType {
OutputInputSameScale,
OutputInputFreeScale,
@ -73,12 +82,25 @@ class KernelSpecs {
}
}
ScaleDecomposeFn GetDecomposeFn() const { return decompose_fn_; }
// Adds the kernel signature with the kernel specification.
LogicalResult Add(const Signature& signature, const KernelSpec& spec) {
if (all_signatures_.insert({signature, spec}).second) return success();
return failure();
}
KernelSpecs& WithSignature(const KernelSpecs::Signature& signature,
const ScaleFn& fn) {
Add(signature, {ScaleConstraintType::CustomScale, fn});
return *this;
}
KernelSpecs& WithImpl(const ScaleDecomposeFn& dfn) {
decompose_fn_ = dfn;
return *this;
}
private:
// The signature is pattern match based.
struct SignatureInfo : public llvm::DenseMapInfo<Signature> {
@ -101,6 +123,10 @@ class KernelSpecs {
// Maps the signature to the kernel spec. Note that the matching is
// pattern match based.
llvm::DenseMap<Signature, KernelSpec, SignatureInfo> all_signatures_;
// A method to compute the effective multipliers. This is independent of the
// bits of the ports, thus all the signatures share the same one here.
ScaleDecomposeFn decompose_fn_;
};
class DeviceTarget {
@ -108,19 +134,26 @@ class DeviceTarget {
explicit DeviceTarget(MLIRContext* ctx);
// Retrieves the kernel spec for the quant region op.
Optional<KernelSpec> Get(quant::QuantizeRegionOp op) const;
Optional<KernelSpec> GetKernelSpec(quant::QuantizeRegionOp op) const;
// Retrieves the scale decomposition function for the quant region op.
ScaleDecomposeFn GetDecomposeFn(quant::QuantizeRegionOp op) const;
protected:
// Adds the kernel spec with the custom scale function for the kernel.
LogicalResult RegisterKernel(llvm::StringRef kernel,
const KernelSpecs::Signature& signature,
const ScaleFn& fn);
const ScaleFn& fn, const ScaleDecomposeFn& dfn);
// Adds the kernel spec with the scale constraint type for the kernel.
LogicalResult RegisterKernel(llvm::StringRef kernel,
const KernelSpecs::Signature& signature,
const ScaleConstraintType constraint);
// Adds the kernel with the name. Returns an existing one if it has been
// added before.
KernelSpecs& RegisterKernel(llvm::StringRef kernel) { return specs_[kernel]; }
// Converts the specification to a signature:
// - UniformQuantizedType -> AnyQuantizedType
// - AnyQuantizedType (int) -> AnyQuantizedType
@ -128,6 +161,13 @@ class DeviceTarget {
void AppendToSignature(ArrayAttr specs_attr,
KernelSpecs::Signature* signature) const;
// For "mulmat->add" type of kernels, convert the scales of all the ports to
// multipliers.
static LogicalResult DecomposeMultiplyAccumulateScale(
Operation* op, quant::QuantizedMultipliers* input_multipliers,
quant::QuantizedMultipliers* output_multipliers,
quant::QuantizedRanges* output_ranges);
// A set of parameters is required to build the signatures.
FloatType f32_;
IntegerType i8_;
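
The new fluent surface (RegisterKernel(name).WithSignature(...).WithImpl(...)) lets a concrete target wire a custom scale function and a scale-decompose function in one statement. A hedged sketch of a hypothetical subclass — CpuDeviceTarget and the "generic.mul_add" kernel name are invented for illustration; only the chaining itself comes from the header above:

// Hypothetical target; assumes the includes and namespace of device_target.h.
class CpuDeviceTarget : public DeviceTarget {
 public:
  explicit CpuDeviceTarget(MLIRContext* ctx) : DeviceTarget(ctx) {
    KernelSpecs::Signature signature = {/* port quantized types elided */};
    RegisterKernel("generic.mul_add")
        .WithSignature(signature,
                       [](QuantizeContext*, Operation*, AdjacentOperations*,
                          bool*) { return success(); })
        .WithImpl(DecomposeMultiplyAccumulateScale);
  }
};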

View File

@ -33,7 +33,6 @@ limitations under the License.
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/Functional.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_info.pb.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h"

View File

@ -0,0 +1,82 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/quantization/numerical_utils.h"
#include <assert.h>
#include <algorithm>
#include <cmath>
#include <limits>
#include "absl/types/optional.h"
namespace mlir {
namespace quant {
// This method is adapted from TFLite:
// ["tensorflow/lite/kernels/internal/quantization_util.cc"]
QuantizedMultiplier QuantizeMultiplier(double double_multiplier) {
if (double_multiplier < 1e-6) {
return {0, 0};
}
int32_t shift;
const double q = frexp(double_multiplier, &shift);
auto q_fixed = static_cast<int64_t>(round(q * (1ll << 31)));
assert(q_fixed <= (1ll << 31));
if (q_fixed == (1ll << 31)) {
q_fixed /= 2;
++shift;
}
assert(q_fixed <= std::numeric_limits<int32_t>::max());
// A shift amount smaller than -31 would cause all bits to be shifted out
// and thus all results would be zero. We implement that instead with
// q_fixed==0, so as to avoid hitting issues with right-shift
// operations with shift amounts greater than 31. Note that this happens
// roughly when abs(double_multiplier) < 2^-31 and the present handling means
// that we're effectively flushing tiny double_multiplier's to zero.
// We could conceivably handle values in the range (roughly) [32, 63]
// as 'denormals' i.e. (shift==0, q_fixed < 2^30). In that point of view
// the present handling is just doing 'flush denormals to zero'. We could
// reconsider and actually generate nonzero denormals if a need arises.
if (shift < -31) {
shift = 0;
q_fixed = 0;
}
return {static_cast<int32_t>(q_fixed), shift};
}
QuantizedRange CalculateQuantizedRange(double scale, int32_t zero_point,
absl::optional<double> rmin,
absl::optional<double> rmax,
int32_t qmin, int32_t qmax) {
auto quantize = [scale, zero_point](float f) {
return zero_point + static_cast<int32_t>(std::round(f / scale));
};
if (rmin.has_value() && rmax.has_value()) {
return {std::max(qmin, quantize(rmin.value())),
std::min(qmax, quantize(rmax.value()))};
} else if (rmin.has_value()) {
return {std::max(qmin, quantize(rmin.value())), qmax};
} else if (rmax.has_value()) {
return {qmin, std::min(qmax, quantize(rmax.value()))};
} else {
return {qmin, qmax};
}
}
} // namespace quant
} // namespace mlir

View File

@ -0,0 +1,45 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_NUMERICAL_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_NUMERICAL_UTILS_H_
#include <cstdint>
#include <utility>
#include "absl/types/optional.h"
namespace mlir {
namespace quant {
using QuantizedMultiplier = std::pair<int32_t, int32_t>;
using QuantizedRange = std::pair<int32_t, int32_t>;
// Decomposes a double-precision multiplier into an integer multiplier and an
// exponent: double_multiplier = int_multiplier * 2 ^ (-31 + exponent), where
// int_multiplier is in the range [2^30, 2^31).
QuantizedMultiplier QuantizeMultiplier(double double_multiplier);
// Calculates the effective quantized value range for the given scale and zero
// point. The range is the intersection of the range defined by [rmin, rmax]
// and the storage range [qmin, qmax].
QuantizedRange CalculateQuantizedRange(double scale, int32_t zero_point,
absl::optional<double> rmin,
absl::optional<double> rmax,
int32_t qmin, int32_t qmax);
} // namespace quant
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_NUMERICAL_UTILS_H_

View File

@ -0,0 +1,114 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/quantization/numerical_utils.h"
#include <cmath>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/types/optional.h"
namespace mlir {
namespace quant {
namespace {
double ComposeScale(const QuantizedMultiplier& input) {
return input.first * exp2(-31 + input.second);
}
TEST(NumericalUtils, QuantizeMultiplier) {
// Decompose multiplier larger than 1.
ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e6)), 1.0e6);
ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e3)), 1.0e3);
ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(10.)), 10.);
ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(5.)), 5.);
ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(2.)), 2.);
// Decompose multiplier between 1.0 and 1e-6.
ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(0.0)), 0.0);
ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0)), 1.0);
ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-1)), 1.0e-1);
ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-2)), 1.0e-2);
ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-3)), 1.0e-3);
ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-4)), 1.0e-4);
ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-5)), 1.0e-5);
ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-6)), 1.0e-6);
// When scale is smaller than 1.0e-6, it is decomposed to {0, 0}.
ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-7)), 0.0);
ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-8)), 0.0);
}
TEST(NumericalUtils, ActivationRange) {
// zero point = 0
auto a =
CalculateQuantizedRange(1e-6, 0, absl::nullopt, absl::nullopt, -128, 127);
ASSERT_EQ(a.first, -128);
ASSERT_EQ(a.second, 127);
auto b = CalculateQuantizedRange(1e-6, 0, 0.0, absl::nullopt, -128, 127);
ASSERT_EQ(b.first, 0);
ASSERT_EQ(b.second, 127);
auto c = CalculateQuantizedRange(1e-6, 0, -1.0, 1.0, -128, 127);
ASSERT_EQ(c.first, -128);
ASSERT_EQ(c.second, 127);
auto d = CalculateQuantizedRange(1e-6, 0, 0.0, 6.0, -128, 127);
ASSERT_EQ(d.first, 0);
ASSERT_EQ(d.second, 127);
// zero point = 100
auto e = CalculateQuantizedRange(1e-6, 100, absl::nullopt, absl::nullopt,
-128, 127);
ASSERT_EQ(e.first, -128);
ASSERT_EQ(e.second, 127);
auto f = CalculateQuantizedRange(1e-6, 100, 0.0, absl::nullopt, -128, 127);
ASSERT_EQ(f.first, 100);
ASSERT_EQ(f.second, 127);
auto g = CalculateQuantizedRange(1e-6, 100, -1.0, 1.0, -128, 127);
ASSERT_EQ(g.first, -128);
ASSERT_EQ(g.second, 127);
auto h = CalculateQuantizedRange(1e-6, 100, 0.0, 6.0, -128, 127);
ASSERT_EQ(h.first, 100);
ASSERT_EQ(h.second, 127);
// zero point = -100
auto i = CalculateQuantizedRange(1e-6, -100, absl::nullopt, absl::nullopt,
-128, 127);
ASSERT_EQ(i.first, -128);
ASSERT_EQ(i.second, 127);
auto j = CalculateQuantizedRange(1e-6, -100, 0.0, absl::nullopt, -128, 127);
ASSERT_EQ(j.first, -100);
ASSERT_EQ(j.second, 127);
auto k = CalculateQuantizedRange(1e-6, -100, -1.0, 1.0, -128, 127);
ASSERT_EQ(k.first, -128);
ASSERT_EQ(k.second, 127);
auto l = CalculateQuantizedRange(1e-6, -100, 0.0, 6.0, -128, 127);
ASSERT_EQ(l.first, -100);
ASSERT_EQ(l.second, 127);
}
} // namespace
} // namespace quant
} // namespace mlir

View File

@ -67,7 +67,7 @@ std::vector<quant::QuantizeRegionOp> QuantizeContext::GetAllOps() {
LogicalResult QuantizeContext::Handle(
quant::QuantizeRegionOp op, llvm::SmallVectorImpl<Operation *> *new_items,
bool *changed) {
auto spec = target_spec_.Get(op);
auto spec = target_spec_.GetKernelSpec(op);
if (!spec.hasValue()) {
op.emitWarning(
"Couldn't find kernel from the registeration for quantization.");

View File

@ -146,7 +146,7 @@ void LegalizeTFToQuant::runOnFunction() {
auto func = getFunction();
auto *ctx = func.getContext();
patterns.insert<PreparePerTensorFakeQuant, PreparePerChannelFakeQuant>(ctx);
applyPatternsGreedily(func, patterns);
applyPatternsAndFoldGreedily(func, patterns);
}
} // namespace

View File

@ -23,10 +23,10 @@ filegroup(
data = [
":importer_test_legacy_reshape",
":importer_test_min_max",
":test_schema.fbs",
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
"//tensorflow/compiler/mlir/lite:flatbuffer_translate",
"//tensorflow/compiler/mlir/lite:json_to_flatbuffer",
"//tensorflow/lite/schema:schema.fbs",
"@llvm-project//llvm:FileCheck",
],
)

View File

@ -1,4 +1,4 @@
// RUN: json_to_flatbuffer %p/../../../../../lite/schema/schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --dump-input-on-failure %s
// RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --dump-input-on-failure %s
// CHECK: %cst = constant unit
// CHECK: %[[RES0:.*]] = "tfl.conv_2d"(%arg0, %arg1, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 0 : i32, stride_w = 0 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, none) -> tensor<256x32x32x16xf32>

View File

@ -1,4 +1,4 @@
// RUN: json_to_flatbuffer %p/../../../../../lite/schema/schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --dump-input-on-failure %s
// RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --dump-input-on-failure %s
// This test checks that if the flatbuffer omits the last optional input `bias` of the tfl.conv_2d op, the flatbuffer_importer automatically adds a `none` value to tfl.conv_2d.

File diff suppressed because it is too large

View File

@ -52,14 +52,14 @@ func @squeezeAndReshape(%arg0: tensor<1x1x10xf32>, %arg1: tensor<?x10xf32>) -> i
%1 = "tf.Squeeze"(%arg1) : (tensor<?x10xf32>) -> tensor<*xf32>
%2 = "tf.Const"() { value = dense<[2, 5]> : tensor<2xi32> } : () -> tensor<2xi32>
%3 = "tf.Reshape" (%0, %2) : (tensor<1x10xf32>, tensor<2xi32>) -> tensor<2x5xf32>
%4 = "some_op"(%1, %3) : (tensor<*xf32>, tensor<2x5xf32>) -> i32
%4 = "tf.some_op"(%1, %3) : (tensor<*xf32>, tensor<2x5xf32>) -> i32
return %4 : i32
// CHECK-LABEL: squeezeAndReshape
// CHECK: "tfl.squeeze"(%arg0) {squeeze_dims = [0]} : (tensor<1x1x10xf32>) -> tensor<1x10xf32>
// CHECK: %1 = "tfl.squeeze"(%arg1) {squeeze_dims = []} : (tensor<?x10xf32>) -> tensor<*xf32>
// CHECK: %cst = constant dense<[2, 5]> : tensor<2xi32>
// CHECK: %2 = "tfl.reshape"(%0, %cst) : (tensor<1x10xf32>, tensor<2xi32>) -> tensor<2x5xf32>
// CHECK: %3 = "some_op"(%1, %2) : (tensor<*xf32>, tensor<2x5xf32>) -> i32
// CHECK: %3 = "tf.some_op"(%1, %2) : (tensor<*xf32>, tensor<2x5xf32>) -> i32
// CHECK: return
}

View File

@ -1,4 +1,4 @@
// RUN: tf-opt -tfl-load-recipe %s | FileCheck %s --dump-input-on-failure
// RUN: tf-opt -allow-unregistered-dialect -tfl-load-recipe %s | FileCheck %s --dump-input-on-failure
// CHECK-LABEL: testLstm
func @testLstm(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>, %arg3: tensor<*xf32>, %arg4: tensor<*xf32>, %arg5: tensor<*xf32>, %arg6: tensor<*xf32>, %arg7: tensor<*xf32>, %arg8: tensor<*xf32>, %arg9: tensor<*xf32>, %arg10: tensor<*xf32>, %arg11: tensor<*xf32>, %arg12: tensor<*xf32>, %arg13: tensor<*xf32>, %arg14: tensor<*xf32>, %arg15: tensor<*xf32>, %arg16: tensor<*xf32>, %arg17: tensor<*xf32>, %arg18: tensor<*xf32>, %arg19: tensor<*xf32>, %arg20: tensor<*xf32>, %arg21: tensor<*xf32>, %arg22: tensor<*xf32>, %arg23: tensor<*xf32>) -> tensor<*xf32> {

View File

@ -439,6 +439,31 @@ func @NotReorderReshapeAddIfNotTailingDim(%arg0: tensor<40x40x1xf32>) -> tensor<
// CHECK: return %[[rs2]]
}
// CHECK-LABEL: @ReorderElementwiseValueOpAndMoveOp
func @ReorderElementwiseValueOpAndMoveOp(%arg0: tensor<40x40x1xf32>) -> tensor<40x40xf32> {
%shape = constant dense<[40, 40]> : tensor<2xi32>
%1 = "tfl.reshape"(%arg0, %shape) : (tensor<40x40x1xf32>, tensor<2xi32>) -> tensor<40x40xf32>
%2 = "tfl.relu"(%1) : (tensor<40x40xf32>) -> tensor<40x40xf32>
return %2 : tensor<40x40xf32>
// CHECK: %[[rs1:.*]] = "tfl.relu"(%arg0
// CHECK: %[[rs2:.*]] = "tfl.reshape"(%[[rs1]]
// CHECK: return %[[rs2]]
}
// CHECK-LABEL: @NotReorderElementwiseValueOpAndMoveOp
func @NotReorderElementwiseValueOpAndMoveOp(%arg0: tensor<40x40x1xf32>) -> (tensor<40x40xf32>, tensor<40x40xf32>) {
%shape = constant dense<[40, 40]> : tensor<2xi32>
%1 = "tfl.reshape"(%arg0, %shape) : (tensor<40x40x1xf32>, tensor<2xi32>) -> tensor<40x40xf32>
%2 = "tfl.relu"(%1) : (tensor<40x40xf32>) -> tensor<40x40xf32>
return %1, %2 : tensor<40x40xf32>, tensor<40x40xf32>
// CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0
// CHECK: %[[rs2:.*]] = "tfl.relu"(%[[rs1]]
// CHECK: return %[[rs1]], %[[rs2]]
}
// CHECK-LABEL: @FuseFullyConnectedRelu
func @FuseFullyConnectedRelu(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> {
%0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32>
@ -450,6 +475,28 @@ func @FuseFullyConnectedRelu(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32
// CHECK: return %[[RES]]
}
// CHECK-LABEL: @FuseFullyConnectedRelu6
func @FuseFullyConnectedRelu6(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> {
%0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32>
%1 = "tfl.relu6"(%0) : (tensor<1x128xf32>) -> tensor<1x128xf32>
return %1 : tensor<1x128xf32>
// CHECK: %[[RES:[0-9].*]] = "tfl.fully_connected"
// CHECK-SAME: fused_activation_function = "RELU6"
// CHECK: return %[[RES]]
}
// CHECK-LABEL: @FuseFullyConnectedRelu1
func @FuseFullyConnectedRelu1(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> {
%0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32>
%1 = "tfl.relu_n1_to_1"(%0) : (tensor<1x128xf32>) -> tensor<1x128xf32>
return %1 : tensor<1x128xf32>
// CHECK: %[[RES:[0-9].*]] = "tfl.fully_connected"
// CHECK-SAME: fused_activation_function = "RELU_N1_TO_1"
// CHECK: return %[[RES]]
}
// CHECK-LABEL: @HardSwishPattern
func @HardSwishPattern(%arg0: tensor<1xf32>) -> tensor<1xf32> {
%three = constant dense<3.> : tensor<f32>

View File

@ -161,7 +161,7 @@ func @_functionalize_if_else_branch_01(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>
}
func @_functionalize_if_then_branch_01(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
%0 = "my_unknown_op.blah"() : () -> tensor<i1>
%0 = "tf.blah"() : () -> tensor<i1>
return %0 : tensor<i1>
}
@ -199,7 +199,7 @@ func @_functionalize_if_else_branch_02(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>
}
func @_functionalize_if_then_branch_02(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
%0 = "my_unknown_op.blah"() : () -> tensor<i1>
%0 = "tf.blah"() : () -> tensor<i1>
return %0 : tensor<i1>
}

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/AsmState.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
@ -128,6 +129,7 @@ int main(int argc, char **argv) {
// We need to disable duplicated ones to provide a cleaner command-line option
// interface. That also means we need to relay the value set in one option to
// all its aliases.
mlir::registerAsmPrinterCLOptions();
llvm::cl::ParseCommandLineOptions(
argc, argv, "TF GraphDef to TFLite FlatBuffer converter\n");

View File

@ -20,7 +20,6 @@ limitations under the License.
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/LLVM.h"
#include "absl/memory/memory.h"
#include "llvm/ADT/STLExtras.h"

View File

@ -30,7 +30,7 @@ void IdentifyDilatedConvPass::runOnFunction() {
patterns.insert<ConvertTFDilatedConvOp<TF::Conv2DOp>,
ConvertTFDilatedConvOp<TF::DepthwiseConv2dNativeOp>>(
&getContext());
applyPatternsGreedily(func, patterns);
applyPatternsAndFoldGreedily(func, patterns);
}
} // namespace

View File

@ -38,7 +38,6 @@ limitations under the License.
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Support/Functional.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

View File

@ -36,7 +36,6 @@ limitations under the License.
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/Functional.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
@ -288,12 +287,10 @@ LogicalResult ConvertTFSplitOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_split_op = cast<TF::SplitOp>(op);
auto output_types = functional::map([](Value v) { return v.getType(); },
tf_split_op.output());
// Number of splits cannot be negative.
auto num_split = rewriter.getI32IntegerAttr(tf_split_op.num_split());
rewriter.replaceOpWithNewOp<TFL::SplitOp>(op, output_types,
rewriter.replaceOpWithNewOp<TFL::SplitOp>(op, tf_split_op.output().getTypes(),
tf_split_op.split_dim(),
tf_split_op.value(), num_split);
return success();
@ -303,14 +300,12 @@ LogicalResult ConvertTFSplitVOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_splitv_op = cast<TF::SplitVOp>(op);
auto output_types = functional::map([](Value v) { return v.getType(); },
tf_splitv_op.output());
// Number of splits cannot be negative.
auto num_split = rewriter.getI32IntegerAttr(tf_splitv_op.num_split());
rewriter.replaceOpWithNewOp<TFL::SplitVOp>(
op, output_types, tf_splitv_op.value(), tf_splitv_op.size_splits(),
tf_splitv_op.split_dim(), num_split);
op, tf_splitv_op.output().getTypes(), tf_splitv_op.value(),
tf_splitv_op.size_splits(), tf_splitv_op.split_dim(), num_split);
return success();
}
@ -402,13 +397,12 @@ LogicalResult ConvertTFUnpackOp::matchAndRewrite(
auto tf_unpack_op = cast<TF::UnpackOp>(op);
auto input = tf_unpack_op.value();
auto output_types = functional::map([](Value v) { return v.getType(); },
tf_unpack_op.output());
auto num = rewriter.getI32IntegerAttr(tf_unpack_op.num());
// Axis can be negative.
auto axis = rewriter.getI32IntegerAttr(tf_unpack_op.axis().getSExtValue());
rewriter.replaceOpWithNewOp<UnpackOp>(op, output_types, input, num, axis);
rewriter.replaceOpWithNewOp<UnpackOp>(op, tf_unpack_op.output().getTypes(),
input, num, axis);
return success();
}

View File

@ -49,7 +49,6 @@ limitations under the License.
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Support/Functional.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project

View File

@ -37,7 +37,6 @@ limitations under the License.
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/Functional.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
@ -52,6 +51,9 @@ namespace TFL {
//===----------------------------------------------------------------------===//
// The actual Optimize Pass.
namespace {
constexpr char kRelu[] = "RELU";
constexpr char kRelu6[] = "RELU6";
constexpr char kRelu1[] = "RELU_N1_TO_1";
bool L2NormalizeReduceAxis(Value sq_op, DenseElementsAttr axis) {
if (sq_op.getType().cast<ShapedType>().getRank() - 1 ==
@ -301,10 +303,11 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern<TFL::AddOp> {
};
// TODO(b/136285429): Move to tablegen when variadic is supported.
struct FuseFullyConnectedAndRelu : public OpRewritePattern<TFL::ReluOp> {
using OpRewritePattern<TFL::ReluOp>::OpRewritePattern;
template <typename ReluXOp, char const *Act>
struct FuseFullyConnectedAndReluX : public OpRewritePattern<ReluXOp> {
using OpRewritePattern<ReluXOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TFL::ReluOp relu_op,
LogicalResult matchAndRewrite(ReluXOp relu_op,
PatternRewriter &rewriter) const override {
Operation *input = relu_op.getOperand().getDefiningOp();
if (!isa_and_nonnull<FullyConnectedOp>(input)) return failure();
@ -312,7 +315,7 @@ struct FuseFullyConnectedAndRelu : public OpRewritePattern<TFL::ReluOp> {
if (fully_connected_op.fused_activation_function() != "NONE")
return failure();
auto new_activation_func = rewriter.getStringAttr("RELU");
auto new_activation_func = rewriter.getStringAttr(Act);
auto new_weights_format =
rewriter.getStringAttr(fully_connected_op.weights_format());
auto new_keep_num_dims =
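In effect (an illustrative sketch, not part of this diff; SSA names are made up), instantiating the template for TFL::Relu6Op turns a pair such as

  %fc = "tfl.fully_connected"(%in, %w, %b) {fused_activation_function = "NONE"} : (tensor<1x8xf32>, tensor<4x8xf32>, tensor<4xf32>) -> tensor<1x4xf32>
  %r = "tfl.relu6"(%fc) : (tensor<1x4xf32>) -> tensor<1x4xf32>

into a single fused op

  %fc = "tfl.fully_connected"(%in, %w, %b) {fused_activation_function = "RELU6"} : (tensor<1x8xf32>, tensor<4x8xf32>, tensor<4xf32>) -> tensor<1x4xf32>

with the Act template parameter (kRelu, kRelu6, or kRelu1 above) supplying the activation string.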
@ -709,9 +712,12 @@ void Optimize::runOnFunction() {
// we explore these potentially first and then fuse the binary ops with the
// following ops in a second pattern match.
TFL::populateWithGenerated(ctx, &patterns);
patterns.insert<FuseFullyConnectedAndAdd, FuseFullyConnectedAndRelu,
patterns.insert<FuseFullyConnectedAndAdd,
FuseFullyConnectedAndReluX<TFL::ReluOp, kRelu>,
FuseFullyConnectedAndReluX<TFL::Relu6Op, kRelu6>,
FuseFullyConnectedAndReluX<TFL::Relu1Op, kRelu1>,
FuseFullyConnectedAndMul>(ctx);
applyPatternsGreedily(func, patterns);
applyPatternsAndFoldGreedily(func, patterns);
// Fuse the binary ops with the following ops.
patterns.insert<
@ -719,7 +725,7 @@ void Optimize::runOnFunction() {
FuseBinaryOpToFollowingFullyConnected, FuseConv2DAndMulWithQDQs,
FuseDepthwiseConv2DAndMulWithQDQs, ConvertTrivialTransposeOpToReshapeOp>(
ctx);
applyPatternsGreedily(func, patterns);
applyPatternsAndFoldGreedily(func, patterns);
}
} // namespace

View File

@ -187,7 +187,7 @@ void OptimizeFunctionalOpsPass::runOnOperation() {
patterns.insert<FoldIfOp>(&getContext(), &inlined_funcs);
ModuleOp module = getOperation();
applyPatternsGreedily(module, patterns);
applyPatternsAndFoldGreedily(module, patterns);
// Erase inlined functions that don't have any references.
//

View File

@ -378,6 +378,19 @@ foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
(IsTailOfShape $rhs, $input)]>;
}
// Reorder the element-wise value operations and the element move operations,
// such that the value operation happens before the move operation.
foreach ValueOp = [TFL_CeilOp, TFL_ExpOp, TFL_FloorOp, TFL_NegOp,
TFL_ReluOp, TFL_Relu1Op, TFL_Relu6Op, TFL_RoundOp,
TFL_TanhOp, TFL_SqrtOp, TFL_SquareOp] in {
foreach MoveOp = [TFL_DepthToSpaceOp, TFL_ExpandDimsOp, TFL_SqueezeOp,
TFL_ReshapeOp, TFL_TransposeOp] in {
def : Pat<(ValueOp:$value (MoveOp:$move $input, $move_def)),
(MoveOp (ValueOp $input), $move_def),
[(HasOneUse $move)]>;
}
}
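As a concrete illustration (a hand-written sketch, not taken from the diff), with ValueOp = TFL_ExpOp and MoveOp = TFL_ReshapeOp the generated pattern rewrites

  %r = "tfl.reshape"(%x, %shape) : (tensor<2x2xf32>, tensor<1xi32>) -> tensor<4xf32>
  %e = "tfl.exp"(%r) : (tensor<4xf32>) -> tensor<4xf32>

into

  %e = "tfl.exp"(%x) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  %r = "tfl.reshape"(%e, %shape) : (tensor<2x2xf32>, tensor<1xi32>) -> tensor<4xf32>

provided the move op's result has a single use (the HasOneUse constraint), so the value computation happens ahead of the layout change.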
// Returns shape of a ranked tensor.
// If called without a ranked tensor it will fail.
def GetShape: NativeCodeCall<"GetShape($0)">;
@ -394,8 +407,9 @@ def : Pat<(TFL_ExpandDimsOp:$expand_dims_op $input, $dim),
(ConstantOp (GetShape $expand_dims_op))),
[(AnyStaticShapeTensor $expand_dims_op)]>;
class ValueEquals<string val> : Constraint<CPred<
class FloatValueEquals<string val> : Constraint<CPred<
"$0.cast<DenseElementsAttr>().getNumElements() == 1 &&"
"$0.isa<DenseFPElementsAttr>() &&"
"*$0.cast<DenseElementsAttr>().getValues<float>().begin() == " # val>>;
// ReLU patterns
@ -403,13 +417,13 @@ def : Pat<(TFL_MinimumOp (TFL_MaximumOp $input,
(ConstantOp $NegOne)),
(ConstantOp $One)),
(TFL_Relu1Op $input),
[(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>;
[(FloatValueEquals<"-1"> $NegOne), (FloatValueEquals<"1"> $One)]>;
def : Pat<(TFL_MaximumOp (TFL_MinimumOp $input,
(ConstantOp $One)),
(ConstantOp $NegOne)),
(TFL_Relu1Op $input),
[(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>;
[(FloatValueEquals<"-1"> $NegOne), (FloatValueEquals<"1"> $One)]>;
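For example (an illustrative sketch with made-up SSA names, assuming splat float constants matched by ConstantOp):

  %neg_one = constant dense<-1.0> : tensor<f32>
  %one = constant dense<1.0> : tensor<f32>
  %max = "tfl.maximum"(%x, %neg_one) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
  %min = "tfl.minimum"(%max, %one) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>

collapses to

  %min = "tfl.relu_n1_to_1"(%x) : (tensor<4xf32>) -> tensor<4xf32>

The FloatValueEquals rename makes the single-element float check explicit: the added isa<DenseFPElementsAttr> guard means integer constants holding 1/-1 no longer satisfy the constraint before getValues<float>() is called.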
def : Pat<(TFL_MaximumOp (TFL_MulOp:$mul_out $input1,
(ConstantOp F32ElementsAttr:$alpha), TFL_AF_None),

View File

@ -125,7 +125,7 @@ void PostQuantizePass::runOnFunction() {
auto func = getFunction();
auto* ctx = func.getContext();
TFL::populateWithGenerated(ctx, &patterns);
applyPatternsGreedily(func, patterns);
applyPatternsAndFoldGreedily(func, patterns);
if (!emit_quant_adaptor_ops_) {
RemoveQuantizationAdaptorOps(getFunction());

View File

@ -267,7 +267,7 @@ void PrepareQuantizePass::runOnFunction() {
// Currently, only activation stats are imported, so narrow_range = false.
patterns.insert<PrepareQuantStats>(8, false, false, ctx);
}
applyPatternsGreedily(func, patterns);
applyPatternsAndFoldGreedily(func, patterns);
SanityCheckAndAdjustment(func);

View File

@ -46,7 +46,6 @@ limitations under the License.
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/Functional.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
@ -322,9 +321,10 @@ class ConvertTFConv2D : public ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp> {
// Create tensor type for the transpose result.
auto filter_type = filter.getType().cast<RankedTensorType>();
auto result_shape = functional::map(
[filter_type](int64_t dim) { return filter_type.getDimSize(dim); },
perm);
auto result_shape =
llvm::to_vector<4>(llvm::map_range(perm, [filter_type](int64_t dim) {
return filter_type.getDimSize(dim);
}));
auto elem_type = filter_type.getElementType();
auto result_type = RankedTensorType::get(result_shape, elem_type);
@ -619,8 +619,8 @@ void PrepareTFPass::runOnFunction() {
// This pattern was intended to use TFL QDQs to preserve the quantization
// parameters from the TF Quant ops, thus this pattern should run with the
// first `applyPatternsGreedily` method, which would otherwise removes the
// TF FakeQuant ops by the constant folding.
// first `applyPatternsAndFoldGreedily` method, which would otherwise remove
// the TF FakeQuant ops by constant folding.
patterns.insert<PreparePerTensorFakeQuant, PreparePerChannelFakeQuant>(ctx);
// This pattern will try to identify and optimize for dilated convolution.
@ -634,7 +634,7 @@ void PrepareTFPass::runOnFunction() {
// This will allow optimizing any TF_Mul->TF_Conv in the graph
// and any expanded from FusedBatchNorm. We need to do this
// before converting TF_Conv to TFL_Conv
applyPatternsGreedily(func, patterns);
applyPatternsAndFoldGreedily(func, patterns);
// Load the generated pattern again, so new quantization pass-through
// will be applied.
@ -646,7 +646,7 @@ void PrepareTFPass::runOnFunction() {
}
patterns.insert<TF::ConvertTFEinsumOp, ConvertTFConv2D,
ConvertTFDepthwiseConv2dNative, ConvertTFStridedSlice>(ctx);
applyPatternsGreedily(func, patterns);
applyPatternsAndFoldGreedily(func, patterns);
}
} // namespace

View File

@ -29,7 +29,6 @@ limitations under the License.
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/Functional.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
@ -88,7 +87,7 @@ void QuantizePass::runOnFunction() {
TFL::populateWithGenerated(ctx, &patterns);
patterns.insert<TFLFullQuantization>(
ctx, enable_numeric_verify, error_tolerance, enable_single_layer_verify);
applyPatternsGreedily(func, patterns);
applyPatternsAndFoldGreedily(func, patterns);
}
} // namespace

View File

@ -94,9 +94,10 @@ Value Transpose(OpBuilder* builder, Value value_to_transpose,
// Create tensor type for the transpose result.
auto transpose_type = original_type;
auto transpose_shape = functional::map(
[transpose_type](int32_t dim) { return transpose_type.getDimSize(dim); },
perm);
auto transpose_shape =
llvm::to_vector<8>(llvm::map_range(perm, [transpose_type](int32_t dim) {
return transpose_type.getDimSize(dim);
}));
auto elem_type = transpose_type.getElementType();
auto result_type = RankedTensorType::get(transpose_shape, elem_type);

View File

@ -127,6 +127,7 @@ Status MlirFunctionOptimizationPass::Run(
GraphImportConfig import_config;
import_config.graph_as_function = true;
import_config.control_outputs = *control_ret_node_names;
import_config.upgrade_legacy = true;
TF_ASSIGN_OR_RETURN(auto module_ref,
ConvertGraphToMlir(**graph, debug_info, *flib_def,
import_config, &context));
@ -149,7 +150,6 @@ Status MlirFunctionOptimizationPass::Run(
}
GraphExportConfig export_config;
export_config.graph_as_function = true;
absl::flat_hash_set<Node*> control_ret_nodes;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
ConvertMlirToGraph(*module_ref, export_config, graph, flib_def,

View File

@ -71,7 +71,8 @@ tool_dirs = config.mlir_tf_tools_dirs + [
tool_names = [
'mlir-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate',
'flatbuffer_to_string', 'flatbuffer_translate', 'tf-mlir-translate',
'mlir-tflite-runner', 'tfcompile', 'json_to_flatbuffer'
'mlir-tflite-runner', 'tfcompile', 'json_to_flatbuffer', 'xla-gpu-opt',
'xla-opt'
]
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
llvm_config.add_tool_substitutions(tools, tool_dirs)

View File

@ -45,7 +45,8 @@ mlir_tf_tools_dirs = [
'tensorflow/compiler/mlir/lite',
'tensorflow/compiler/mlir/tensorflow',
'tensorflow/compiler/mlir/xla',
'tensorflow/compiler/aot'
'tensorflow/compiler/aot',
'tensorflow/compiler/xla/service/mlir_gpu',
]
config.mlir_tf_tools_dirs = [
os.path.join(real_test_srcdir, os.environ['TEST_WORKSPACE'], s)

View File

@ -1292,6 +1292,45 @@ tf_cc_test(
],
)
cc_library(
name = "dump_graph",
srcs = ["utils/dump_graph.cc"],
hdrs = ["utils/dump_graph.h"],
deps = [
":convert_graphdef",
":error_util",
":tensorflow",
":tensorflow_dialect_registration",
":tensorflow_passes",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"@llvm-project//llvm:support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
],
)
tf_cc_test(
name = "dump_graph_test",
size = "small",
srcs = ["utils/dump_graph_test.cc"],
deps = [
":dump_graph",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/platform:test",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
],
)
cc_library(
name = "bridge_logger",
srcs = ["utils/bridge_logger.cc"],

View File

@ -40,7 +40,6 @@ limitations under the License.
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Support/STLExtras.h" // from @llvm-project
#include "mlir/Transforms/InliningUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/platform/logging.h"
@ -90,7 +89,7 @@ struct TFInlinerInterface : public DialectInlinerInterface {
// are perfectly forwarded to the block's terminator.
bool BlockWrapsSingleOp(Block* block) {
auto body = block->without_terminator();
if (!has_single_element(body)) return false;
if (!hasSingleElement(body)) return false;
Operation& wrapped_op = *body.begin();
Operation* terminator = block->getTerminator();
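For intuition (a sketch, not from this diff): the predicate holds for a block containing exactly one op plus the terminator, with that op's results yielded unchanged, as in a tf_executor island like

  %out, %ctl = tf_executor.island {
    %id = "tf.Identity"(%arg) : (tensor<f32>) -> tensor<f32>
    tf_executor.yield %id : tensor<f32>
  }

The has_single_element -> hasSingleElement rename appears to track the removal of mlir/Support/STLExtras.h (see the include deleted above) in favor of the equivalent LLVM STLExtras helper.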

Some files were not shown because too many files have changed in this diff.