Merge pull request #3 from tensorflow/master

downstream merge 2
tg-at-google 2020-06-28 20:31:50 -04:00 committed by GitHub
commit 9fdba01ead
1662 changed files with 64588 additions and 25346 deletions


@ -39,6 +39,7 @@
#
# Feature and Third party library support options:
# xla: Build TF with XLA
# tpu: Build TF with TPU support
# using_cuda: CUDA is available to build system.
# cuda: Build with full cuda support.
# rocm: Build with AMD GPU support (rocm).
@ -57,13 +58,12 @@
#
#
# Remote build execution options (only configured to work with TF team projects for now.)
# rbe: General RBE options shared by all flavors.
# rbe_linux: General RBE options used on all linux builds.
# rbe_win: General RBE options used on all windows builds.
# rbe: General RBE options shared by all flavors.
# rbe_linux: General RBE options used on all linux builds.
# rbe_win: General RBE options used on all windows builds.
#
# rbe_cpu_linux: RBE options to build with only CPU support.
# rbe_linux_cuda_nvcc: RBE options to build with GPU support using nvcc.
# rbe_gpu_linux: An alias for rbe_linux_cuda_nvcc
# rbe_cpu_linux: RBE options to build with only CPU support.
# rbe_linux_cuda_nvcc_py*: RBE options to build with GPU support using nvcc.
#
# rbe_linux_py2: Linux Python 2 RBE config.
# rbe_linux_py3: Linux Python 3 RBE config
@ -180,6 +180,9 @@ build:dbg --cxxopt -DTF_LITE_DISABLE_X86_NEON
# AWS SDK must be compiled in release mode. see: https://github.com/tensorflow/tensorflow/issues/37498
build:dbg --copt -DDEBUG_BUILD
# Config to build TPU backend
build:tpu --define=with_tpu_support=true
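# Illustrative usage (not part of this change): this config would typically be
# selected on the command line, e.g. `bazel build --config=tpu <target>`, which
# applies the define above and matches the `with_tpu_support` config_setting
# added elsewhere in this change.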
build:tensorrt --action_env TF_NEED_TENSORRT=1
build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
@ -396,33 +399,48 @@ build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1
build:rbe_linux_cuda_base --repo_env=TF_NEED_CUDA=1
test:rbe_linux_cuda_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --host_platform="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --platforms="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_nvcc --define=using_cuda_nvcc=true
test:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
build:rbe_linux_cuda10.1_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda10.1_nvcc_base --define=using_cuda_nvcc=true
build:rbe_linux_cuda10.1_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda10.1_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda10.1_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda10.1_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda10.1_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda10.1_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda10.1_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda10.1_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda10.1_nvcc_py2.7 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
build:rbe_linux_cuda10.1_nvcc_py3.5 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
build:rbe_linux_cuda10.1_nvcc_py3.6 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
build:rbe_linux_cuda10.1_nvcc_py3.7 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
build:rbe_linux_cuda10.1_nvcc_py3.8 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
build:rbe_linux_cuda_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_nvcc_base --define=using_cuda_nvcc=true
build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
build:rbe_linux_cuda11.0_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda11.0_nvcc_base --define=using_cuda_nvcc=true
build:rbe_linux_cuda11.0_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.0_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda11.0_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform"
build:rbe_linux_cuda11.0_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform"
build:rbe_linux_cuda11.0_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform"
build:rbe_linux_cuda11.0_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda"
build:rbe_linux_cuda11.0_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_tensorrt"
build:rbe_linux_cuda11.0_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_nccl"
build:rbe_linux_cuda11.0_nvcc_py2.7 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python2.7"
build:rbe_linux_cuda11.0_nvcc_py3.5 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.5"
build:rbe_linux_cuda11.0_nvcc_py3.6 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.6"
build:rbe_linux_cuda11.0_nvcc_py3.7 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.7"
build:rbe_linux_cuda11.0_nvcc_py3.8 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.8"
# Map default to CUDA 10.1.
build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda10.1_nvcc_py2.7
build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda10.1_nvcc_py3.5
build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda10.1_nvcc_py3.6
build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda10.1_nvcc_py3.7
build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda10.1_nvcc_py3.8
# Deprecated configs that people might still use.
build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_nvcc_py36
build:rbe_gpu_linux --config=rbe_linux_cuda_nvcc
build:rbe_linux_cuda_clang_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda_clang_base --crosstool_top="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
@ -440,8 +458,6 @@ build:rbe_linux_cuda_clang_py36 --config=rbe_linux_cuda_clang_base --repo_env=TF
build:rbe_linux_cuda_clang_py37 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
build:rbe_linux_cuda_clang_py38 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
common:rbe_gpu_linux --config=rbe_linux_cuda_nvcc
build:rbe_linux_py2 --config=rbe_linux
build:rbe_linux_py2 --repo_env=PYTHON_BIN_PATH="/usr/bin/python2"
build:rbe_linux_py2 --python_path="/usr/bin/python2"


@ -27,6 +27,19 @@ assignees:
# A list of assignees for compiler folder
compiler_assignees:
- joker-eph
# filesystem path
filesystem_path:
- tensorflow/c/experimental/filesystem
# security path
security_path:
- tensorflow/security
# words checklist
segfault_memory:
- segfault
- memory leaks
# assignees
filesystem_security_assignee:
- mihaimaruseac
# Cuda Comment
cuda_comment: >
From the template it looks like you are installing **TensorFlow** (TF) prebuilt binaries:


@ -95,6 +95,7 @@ for general questions and discussion, and please direct specific questions to
The TensorFlow project strives to abide by generally accepted best practices in
open-source software development:
[![Fuzzing Status](https://oss-fuzz-build-logs.storage.googleapis.com/badges/tensorflow.svg)](https://bugs.chromium.org/p/oss-fuzz/issues/list?sort=-opened&can=1&q=proj:tensorflow)
[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1486/badge)](https://bestpractices.coreinfrastructure.org/projects/1486)
[![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-v1.4%20adopted-ff69b4.svg)](CODE_OF_CONDUCT.md)


@ -1,3 +1,57 @@
# Release 2.4.0
<INSERT SMALL BLURB ABOUT RELEASE FOCUS AREA AND POTENTIAL TOOLCHAIN CHANGES>
## Breaking Changes
* <DOCUMENT BREAKING CHANGES HERE>
* <THIS SECTION SHOULD CONTAIN API, ABI AND BEHAVIORAL BREAKING CHANGES>
## Known Caveats
* <CAVEATS REGARDING THE RELEASE (BUT NOT BREAKING CHANGES). E.G. ADDING A NEW DEPENDENCY, BUMPING A DEPENDENCY NUMBER, LACK OF SUPPORT ON SOME PLATFORM, ETC>
## Major Features and Improvements
* <INSERT MAJOR FEATURE HERE, USING MARKDOWN SYNTAX>
* <IF RELEASE CONTAINS MULTIPLE FEATURES FROM SAME AREA, GROUP THEM TOGETHER>
## Bug Fixes and Other Changes
* <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
* <NOTES SHOULD BE GROUPED PER AREA>
* TF Core:
* <ADD RELEASE NOTES HERE>
* `tf.data`:
* <ADD RELEASE NOTES HERE>
* `tf.distribute`:
* <ADD RELEASE NOTES HERE>
* `tf.keras`:
* <ADD RELEASE NOTES HERE>
* `tf.function`/AutoGraph:
* <ADD RELEASE NOTES HERE>
* `tf.lite`:
* <ADD RELEASE NOTES HERE>
* `tf.random`:
* <ADD RELEASE NOTES HERE>
* Math and Linear Algebra:
* <ADD RELEASE NOTES HERE>
* TPU Enhancements:
* <ADD RELEASE NOTES HERE>
* XLA Support:
* <ADD RELEASE NOTES HERE>
* Tracing and Debugging:
* <ADD RELEASE NOTES HERE>
* Other:
* <ADD RELEASE NOTES HERE>
## Thanks to our Contributors
This release contains contributions from many people at Google, as well as:
<INSERT>, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
# Release 2.3.0
## Breaking Changes
@ -11,6 +65,15 @@
existing C++ kernel `ExtractGlimpse` does not change either, so saved
models will not be impacted.
## Bug Fixes and Other Changes
* `tf.keras`:
* Deprecated the `tf.keras.experimental.PeepholeLSTMCell` layer, which was
moved to `tensorflow_addons` as
`tensorflow_addons.rnn.PeepholeLSTMCell`. This experimental API is
expected to be removed from TF in the next public release (2.4).
* Mutable tables now restore checkpointed values when loaded from SavedModel.
# Release 2.1.1
## Bug Fixes and Other Changes


@ -36,18 +36,6 @@ load(
bazel_toolchains_repositories()
load(
"@io_bazel_rules_docker//repositories:repositories.bzl",
container_repositories = "repositories",
)
container_repositories()
load("//third_party/toolchains/preconfig/generate:workspace.bzl",
"remote_config_workspace")
remote_config_workspace()
# Use `swift_rules_dependencies` to fetch the toolchains. With the
# `git_repository` rules above, the following call will skip redefining them.
load("@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies")


@ -480,7 +480,7 @@ def check_bazel_version(min_version, max_version):
"""
if which('bazel') is None:
print('Cannot find bazel. Please install bazel.')
sys.exit(0)
sys.exit(1)
stderr = open(os.devnull, 'wb')
curr_version = run_shell(['bazel', '--version'],


@ -467,6 +467,13 @@ config_setting(
visibility = ["//visibility:public"],
)
# This flag enables experimental TPU support
config_setting(
name = "with_tpu_support",
values = {"define": "with_tpu_support=true"},
visibility = ["//visibility:public"],
)
# Specifies via a config setting if this is a mobile build or not, makes
# it easier to combine settings later.
selects.config_setting_group(


@ -173,6 +173,7 @@ tf_cuda_library(
copts = tf_copts(),
visibility = [
"//tensorflow/c:__subpackages__",
"//tensorflow/python:__subpackages__",
"//third_party/llvm/llvm-project:__subpackages__",
],
deps = [


@ -54,7 +54,7 @@ Status ProcessInputs(
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
input_tensors->reserve(ninputs);
for (int i = 0; i < ninputs; ++i) {
Node* node = &inputs[i].oper->node;
Node* node = inputs[i].oper ? &inputs[i].oper->node : nullptr;
int idx = inputs[i].index;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
@ -90,7 +90,7 @@ Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
output_tensors->reserve(noutputs);
for (int i = 0; i < noutputs; ++i) {
Node* node = &outputs[i].oper->node;
Node* node = outputs[i].oper ? &outputs[i].oper->node : nullptr;
int idx = outputs[i].index;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
fn_body->graph.IsValidOutputTensor(node, idx),


@ -158,9 +158,13 @@ cc_library(
"//tensorflow:internal",
],
deps = [
":abstract_context",
":abstract_operation",
":abstract_tensor_handle",
":c_api",
":c_api_experimental",
"//tensorflow/c:c_api_internal",
"//tensorflow/c:conversion_macros",
"//tensorflow/c:tf_status",
"//tensorflow/core/platform:casts",
"//tensorflow/core/platform:types",
@ -244,7 +248,6 @@ cc_library(
"//tensorflow:internal",
],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:status",
],
@ -542,6 +545,9 @@ tf_cuda_library(
":abstract_operation",
":abstract_context",
":abstract_tensor_handle",
":immediate_execution_tensor_handle",
":immediate_execution_context",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_internal",
"//tensorflow/core:core_cpu",
@ -560,6 +566,7 @@ tf_cuda_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/types:variant",
"//tensorflow/c:conversion_macros",
],
}) + select({
"//tensorflow:with_xla_support": [
@ -733,6 +740,10 @@ filegroup(
],
exclude = [
"c_api_experimental.cc",
"c_api_unified_experimental.cc",
"c_api_unified_experimental_eager.cc",
"c_api_unified_experimental_graph.cc",
"c_api_unified_experimental_internal.h",
"*test*",
"*dlpack*",
],


@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_
#define TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_
#include <vector>
#include <memory>
#include "tensorflow/c/eager/abstract_function.h"
#include "tensorflow/c/eager/abstract_operation.h"
@ -32,7 +32,7 @@ namespace tensorflow {
// environment, a traced representation etc.
class AbstractContext {
protected:
enum AbstractContextKind { kTracing, kImmediateExecution };
enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt };
explicit AbstractContext(AbstractContextKind kind) : kind_(kind) {}
virtual ~AbstractContext() {}
@ -64,6 +64,19 @@ class AbstractContext {
const AbstractContextKind kind_;
};
namespace internal {
struct AbstractContextDeleter {
void operator()(AbstractContext* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using AbstractContextPtr =
std::unique_ptr<AbstractContext, internal::AbstractContextDeleter>;
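// Illustrative sketch (not part of this change): AbstractContextPtr releases
// its pointee via Release() rather than `delete`, so ownership can be handled
// RAII-style. Assuming a hypothetical factory MakeContext():
//
//   AbstractContextPtr ctx(MakeContext());
//   ctx->RegisterFunction(func);  // use like a raw pointer
//   // Release() runs automatically when `ctx` goes out of scope.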
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_


@ -15,7 +15,6 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_
#define TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_
#include "tensorflow/c/c_api.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/platform/status.h"
@ -26,7 +25,7 @@ namespace tensorflow {
// function.
class AbstractFunction {
protected:
enum AbstractFunctionKind { kGraphFunc, kMlirFunc };
enum AbstractFunctionKind { kGraph, kMlir };
explicit AbstractFunction(AbstractFunctionKind kind) : kind_(kind) {}
public:


@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
#define TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
@ -28,7 +30,7 @@ namespace tensorflow {
// tracing or immediate execution mode.
class AbstractOperation {
protected:
enum AbstractOperationKind { kTracing, kImmediateExecution };
enum AbstractOperationKind { kGraph, kMlir, kEager, kTfrt };
explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {}
virtual ~AbstractOperation() {}
@ -110,6 +112,19 @@ class AbstractOperation {
const AbstractOperationKind kind_;
};
namespace internal {
struct AbstractOperationDeleter {
void operator()(AbstractOperation* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using AbstractOperationPtr =
std::unique_ptr<AbstractOperation, internal::AbstractOperationDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_


@ -15,13 +15,14 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_
#define TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_
#include <memory>
namespace tensorflow {
// Abstract interface to a Tensor handle in either tracing or immediate
// execution mode.
class AbstractTensorHandle {
protected:
enum AbstractTensorHandleKind { kTracing, kImmediateExecution };
enum AbstractTensorHandleKind { kGraph, kMlir, kEager, kTfrt };
explicit AbstractTensorHandle(AbstractTensorHandleKind kind) : kind_(kind) {}
virtual ~AbstractTensorHandle() {}
@ -40,6 +41,20 @@ class AbstractTensorHandle {
const AbstractTensorHandleKind kind_;
};
namespace internal {
struct AbstractTensorHandleDeleter {
void operator()(AbstractTensorHandle* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using AbstractTensorHandlePtr =
std::unique_ptr<AbstractTensorHandle,
internal::AbstractTensorHandleDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_


@ -60,6 +60,12 @@ void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
context->SetShouldStoreGraphs(false);
}
uint64_t TFE_GetContextId(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
return context->GetContextId();
}
void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
int64_t value) {
cell->cell.IncrementBy(value);


@ -300,6 +300,14 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetLazyRemoteInputsCopy(
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetTfrt(TFE_ContextOptions*,
bool use_tfrt);
// Returns the context_id from the EagerContext which is used by the
// EagerService to maintain consistency between client and worker. The
// context_id is initialized with a dummy value and is later set when the worker
// is initialized (either locally or remotely). The context_id can change during
the process lifetime, although such a change should also cause the worker to be
reinitialized (e.g. caches cleared).
TF_CAPI_EXPORT extern uint64_t TFE_GetContextId(TFE_Context* ctx);
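// Illustrative usage (a sketch, not part of this change): since the id is a
// plain uint64_t, a client could snapshot it to detect reinitialization, e.g.
//
//   uint64_t id_before = TFE_GetContextId(ctx);
//   /* ... workers may be (re)initialized here ... */
//   if (TFE_GetContextId(ctx) != id_before) { /* clear cached state */ }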
// -----------------------------------------------------------------------------
// Cancellation APIs.


@ -22,15 +22,17 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"
using tensorflow::string;
using tensorflow::internal::OutputList;
using tensorflow::internal::unwrap;
namespace tensorflow {
namespace internal {
typedef absl::flat_hash_map<std::string, FactoryFunction> FactoriesMap;
namespace tracing {
typedef absl::flat_hash_map<std::string, tracing::FactoryFunction> FactoriesMap;
static FactoriesMap& GetFactories() {
static FactoriesMap* factories = new FactoriesMap;
@ -48,8 +50,8 @@ void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) {
void SetDefaultTracingEngine(const char* name) { default_factory = name; }
static ExecutionContext* CreateTracingExecutionContext(const char* fn_name,
TF_Status* s) {
static TracingContext* CreateTracingExecutionContext(const char* fn_name,
TF_Status* s) {
auto entry = GetFactories().find(default_factory);
if (entry != GetFactories().end()) return entry->second(fn_name, s);
string msg = absl::StrCat(
@ -70,7 +72,7 @@ static ExecutionContext* CreateTracingExecutionContext(const char* fn_name,
return nullptr;
}
} // end namespace internal
} // end namespace tracing
} // end namespace tensorflow
// =============================================================================
@ -83,43 +85,77 @@ static ExecutionContext* CreateTracingExecutionContext(const char* fn_name,
//
// =============================================================================
using tensorflow::AbstractFunction;
using tensorflow::AbstractTensorHandle;
using tensorflow::DataType;
using tensorflow::dyn_cast;
using tensorflow::OutputList;
using tensorflow::Status;
using tensorflow::unwrap;
using tensorflow::wrap;
using tensorflow::tracing::CreateTracingExecutionContext;
using tensorflow::tracing::SetDefaultTracingEngine;
using tensorflow::tracing::TracingContext;
using tensorflow::tracing::TracingOperation;
using tensorflow::tracing::TracingTensorHandle;
void TF_SetTracingImplementation(const char* name) {
tensorflow::internal::SetDefaultTracingEngine(name);
SetDefaultTracingEngine(name);
}
// Creates a new TensorFlow function; it is an execution context attached to a
// given tracing context.
TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* s) {
return wrap(tensorflow::internal::CreateTracingExecutionContext(fn_name, s));
return wrap(CreateTracingExecutionContext(fn_name, s));
}
TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
TF_OutputList* outputs, TF_Status* s) {
auto* func = wrap(unwrap(ctx)->Finalize(unwrap(outputs), s));
AbstractFunction* func;
TracingContext* tracing_ctx = dyn_cast<TracingContext>(unwrap(ctx));
if (!tracing_ctx) {
Set_TF_Status_from_Status(
s, tensorflow::errors::InvalidArgument(
"Only TracingContext can be converted into a function."));
return nullptr;
}
Set_TF_Status_from_Status(s, tracing_ctx->Finalize(unwrap(outputs), &func));
TF_DeleteExecutionContext(ctx);
return func;
return wrap(func);
}
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
TF_DataType dtype, TF_Status* s) {
return wrap(unwrap(func)->AddParameter(dtype, s));
TracingTensorHandle* t;
TracingContext* tracing_ctx = dyn_cast<TracingContext>(unwrap(func));
if (!tracing_ctx) {
Set_TF_Status_from_Status(
s, tensorflow::errors::InvalidArgument(
"TF_AddFunctionParameter must be called on a TracingContext."));
return nullptr;
}
Set_TF_Status_from_Status(
s, tracing_ctx->AddParameter(static_cast<DataType>(dtype), &t));
return wrap(t);
}
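// Illustrative tracing flow (a sketch, not part of this change; the tracing
// implementation name, the "AddV2" op and `status` are assumed for the
// example):
//
//   TF_SetTracingImplementation("graphdef");
//   TF_ExecutionContext* fn_ctx = TF_CreateFunction("my_fn", status);
//   TF_AbstractTensor* x = TF_AddFunctionParameter(fn_ctx, TF_FLOAT, status);
//   TF_AbstractOp* add = TF_NewAbstractOp(fn_ctx);
//   TF_AbstractOpSetOpType(add, "AddV2", status);
//   TF_AbstractOpSetOpName(add, "my_add", status);
//   TF_AbstractTensor* inputs[] = {x, x};
//   TF_OutputList* outputs = TF_NewOutputList();
//   TF_ExecuteOperation(add, 2, inputs, outputs, status);
//   TF_AbstractFunction* fn = TF_FinalizeFunction(fn_ctx, outputs, status);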
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete unwrap(c); }
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { unwrap(c)->Release(); }
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
return wrap(unwrap(c)->CreateOperation());
return wrap((unwrap(c)->CreateOperation()));
}
void TF_DeleteAbstractOp(TF_AbstractOp* op) { delete unwrap(op); }
void TF_DeleteAbstractOp(TF_AbstractOp* op) { unwrap(op)->Release(); }
void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { delete unwrap(t); }
void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { unwrap(t)->Release(); }
TF_OutputList* TF_NewOutputList() { return wrap(new OutputList); }
void TF_DeleteOutputList(TF_OutputList* o) { delete unwrap(o); }
void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs,
TF_Status* s) {
unwrap(o)->expected_num_outputs = num_outputs;
unwrap(o)->outputs.clear();
unwrap(o)->outputs.resize(num_outputs);
}
int TF_OutputListNumOutputs(TF_OutputList* o) {
return unwrap(o)->outputs.size();
@ -134,24 +170,46 @@ void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor,
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
TF_Status* s) {
unwrap(op)->SetOpType(op_type, s);
Set_TF_Status_from_Status(s, unwrap(op)->Reset(op_type,
/*raw_device_name=*/nullptr));
}
void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
TF_Status* s) {
unwrap(op)->SetOpName(op_name, s);
TracingOperation* tracing_op = dyn_cast<TracingOperation>(unwrap(op));
if (!tracing_op) {
Set_TF_Status_from_Status(
s, tensorflow::errors::InvalidArgument(
"TF_AbstractOpSetOpName must be called on a TracingOperation."));
return;
}
Set_TF_Status_from_Status(s, tracing_op->SetOpName(op_name));
}
void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
TF_DataType value, TF_Status* s) {
unwrap(op)->SetAttrType(attr_name, value, s);
Status status =
unwrap(op)->SetAttrType(attr_name, static_cast<DataType>(value));
TF_SetStatus(s, static_cast<TF_Code>(status.code()),
status.error_message().c_str());
}
void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_ExecutionContext* ctx, TF_Status* s) {
unwrap(ctx)->ExecuteOperation(unwrap(op), num_inputs, &unwrap(*inputs),
unwrap(o), s);
TF_Status* s) {
for (int i = 0; i < num_inputs; i++) {
Set_TF_Status_from_Status(s, unwrap(op)->AddInput(unwrap(inputs[i])));
if (TF_GetCode(s) != TF_OK) {
return;
}
}
int num_outputs = unwrap(o)->expected_num_outputs;
Set_TF_Status_from_Status(
s, unwrap(op)->Execute(
absl::MakeSpan(reinterpret_cast<AbstractTensorHandle**>(
unwrap(o)->outputs.data()),
unwrap(o)->outputs.size()),
&num_outputs));
}
void TF_DeleteAbstractFunction(TF_AbstractFunction* func) {
@ -161,5 +219,5 @@ void TF_DeleteAbstractFunction(TF_AbstractFunction* func) {
void TF_ExecutionContextRegisterFunction(TF_ExecutionContext* ctx,
TF_AbstractFunction* func,
TF_Status* s) {
unwrap(ctx)->RegisterFunction(unwrap(func), s);
Set_TF_Status_from_Status(s, unwrap(ctx)->RegisterFunction(unwrap(func)));
}


@ -110,7 +110,7 @@ void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor,
// Any active tape will observe the effects of this execution.
void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_ExecutionContext* ctx, TF_Status* s);
TF_Status* s);
// Creates a new TF_AbstractFunction from the current tracing states in the
// context. The provided `ctx` is consumed by this API call and deleted.
@ -137,7 +137,8 @@ TF_AbstractTensor* TF_CreateAbstractTensorFromEagerTensor(TFE_TensorHandle* t,
TF_Status* s);
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
TF_Status* s);
TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext*);
TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext*,
TF_Status* s);
#ifdef __cplusplus
} /* end extern "C" */


@ -15,180 +15,68 @@ limitations under the License.
#include <vector>
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/types.h"
using tensorflow::string;
namespace tensorflow {
namespace internal {
// Simple wrapper over a TFE_TensorHandle
struct EagerTensor : public AbstractTensor {
TFE_TensorHandle* t = nullptr;
EagerTensor() : AbstractTensor(kKind) {}
explicit EagerTensor(TFE_TensorHandle* t) : AbstractTensor(kKind), t(t) {}
~EagerTensor() override { TFE_DeleteTensorHandle(t); }
static constexpr AbstractTensorKind kKind = kEagerTensor;
};
// Simple wrapper over a TFE_Op
class EagerOp : public AbstractOp {
public:
explicit EagerOp(TFE_Context* ctx) : AbstractOp(kKind), ctx_(ctx) {}
void SetOpType(const char* const op_type, TF_Status* s) override {
op_ = TFE_NewOp(ctx_, op_type, s);
}
void SetOpName(const char* const op_name, TF_Status* s) override {
// Name is ignored in eager mode.
}
void SetAttrType(const char* const attr_name, TF_DataType value,
TF_Status* s) override {
if (op_ == nullptr) {
TF_SetStatus(s, TF_FAILED_PRECONDITION,
"op_type must be specified before specifying attrs.");
return;
}
TFE_OpSetAttrType(op_, attr_name, value);
}
~EagerOp() override { TFE_DeleteOp(op_); }
static constexpr AbstractOpKind kKind = kEagerOp;
private:
friend class EagerContext; // For access to op_.
TFE_Op* op_ = nullptr;
TFE_Context* ctx_;
};
// Wraps a TFE_Context and dispatch EagerOp with EagerTensor inputs.
class EagerContext : public ExecutionContext {
public:
EagerContext() : ExecutionContext(kKind) {}
void Build(TFE_ContextOptions* options, TF_Status* status) {
eager_ctx_ = TFE_NewContext(options, status);
}
AbstractOp* CreateOperation() override {
// TODO(srbs): Should the lifetime of this op be tied to the context.
return new EagerOp(eager_ctx_);
}
void ExecuteOperation(AbstractOp* op, int num_inputs,
AbstractTensor* const* inputs, OutputList* o,
TF_Status* s) override {
auto* eager_op = dyncast<EagerOp>(op);
if (eager_op == nullptr) {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Unable to cast AbstractOp to TF_EagerOp.");
return;
}
auto* tfe_op = eager_op->op_;
if (TF_GetCode(s) != TF_OK) return;
for (int i = 0; i < num_inputs; ++i) {
auto* eager_tensor = dyncast<const EagerTensor>(inputs[i]);
if (!eager_tensor) {
TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor.");
return;
}
TFE_OpAddInput(tfe_op, eager_tensor->t, s);
if (TF_GetCode(s) != TF_OK) return;
}
if (o->expected_num_outputs == -1) {
string msg =
"The number of outputs must be provided in eager mode. Use "
"TF_OutputListSetNumOutputs.";
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return;
}
tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals;
int num_retvals = o->expected_num_outputs;
retvals.resize(num_retvals);
TFE_Execute(tfe_op, retvals.data(), &num_retvals, s);
if (TF_GetCode(s) != TF_OK) {
return;
}
o->outputs.clear();
o->outputs.reserve(num_retvals);
for (int i = 0; i < num_retvals; ++i) {
o->outputs.push_back(new EagerTensor(retvals[i]));
}
}
AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Can't add function parameter on an eager context.");
return nullptr;
}
AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Can't use finalize function on an eager context.");
return nullptr;
}
void RegisterFunction(AbstractFunction* afunc, TF_Status* s) override {
auto* func = afunc->GetTfFunction(s);
if (!func) {
return;
}
TFE_ContextAddFunction(eager_ctx_, func, s);
}
~EagerContext() override { TFE_DeleteContext(eager_ctx_); }
static constexpr ExecutionContextKind kKind = kEagerContext;
private:
friend TFE_Context* ::TF_ExecutionContextGetTFEContext(
TF_ExecutionContext* ctx);
TFE_Context* eager_ctx_;
};
} // namespace internal
} // namespace tensorflow
// =============================================================================
// Public C API entry points
// These are only the entry points specific to the Eager API.
// =============================================================================
using tensorflow::internal::dyncast;
using tensorflow::internal::unwrap;
using tensorflow::AbstractContext;
using tensorflow::AbstractTensorHandle;
using tensorflow::dyn_cast;
using tensorflow::ImmediateExecutionContext;
using tensorflow::ImmediateExecutionTensorHandle;
using tensorflow::string;
using tensorflow::unwrap;
using tensorflow::wrap;
using tensorflow::strings::StrCat;
TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions* options,
TF_Status* s) {
auto* ctx = new tensorflow::internal::EagerContext();
ctx->Build(options, s);
return wrap(ctx);
TFE_Context* c_ctx = TFE_NewContext(options, s);
if (TF_GetCode(s) != TF_OK) {
return nullptr;
}
return wrap(static_cast<AbstractContext*>(unwrap(c_ctx)));
}
TF_AbstractTensor* TF_CreateAbstractTensorFromEagerTensor(TFE_TensorHandle* t,
TF_Status* s) {
return wrap(new tensorflow::internal::EagerTensor(t));
return wrap(static_cast<AbstractTensorHandle*>(unwrap(t)));
}
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
TF_Status* s) {
auto* eager_tensor = dyncast<tensorflow::internal::EagerTensor>(unwrap(at));
if (!eager_tensor) {
string msg = tensorflow::strings::StrCat("Not an eager tensor handle.",
reinterpret_cast<uintptr_t>(at));
auto handle = dyn_cast<ImmediateExecutionTensorHandle>(unwrap(at));
if (!handle) {
string msg =
StrCat("Not an eager tensor handle.", reinterpret_cast<uintptr_t>(at));
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return nullptr;
}
return eager_tensor->t;
return wrap(handle);
}
TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext* ctx) {
auto* eager_ctx = dyncast<tensorflow::internal::EagerContext>(unwrap(ctx));
if (!eager_ctx) return nullptr;
return eager_ctx->eager_ctx_;
TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext* ctx,
TF_Status* s) {
auto imm_ctx = dyn_cast<ImmediateExecutionContext>(unwrap(ctx));
if (!imm_ctx) {
string msg =
StrCat("Not an eager context.", reinterpret_cast<uintptr_t>(ctx));
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return nullptr;
}
return wrap(imm_ctx);
}


@ -18,77 +18,198 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/types.h"
using tensorflow::dyn_cast;
using tensorflow::string;
namespace tensorflow {
namespace internal {
namespace tracing {
namespace graph {
class GraphContext;
class GraphOperation;
class GraphTensor;
// GraphTensor wraps a `TF_Output`, i.e. a pointer to TF_Operation and the index
// into the list of outputs for the operation.
struct GraphTensor : public AbstractTensor {
TF_Output output{};
GraphContext* ctx = nullptr;
GraphTensor() : AbstractTensor(kKind) {}
GraphTensor(TF_Output output, GraphContext* ctx)
: AbstractTensor(kKind), output(output), ctx(ctx) {}
static constexpr AbstractTensorKind kKind = kGraphTensor;
class GraphTensor : public TracingTensorHandle {
public:
explicit GraphTensor(TF_Output output)
: TracingTensorHandle(kGraph), output_(output) {}
void Release() override { delete this; }
TF_Output output_;
// For LLVM style RTTI.
static bool classof(const AbstractTensorHandle* ptr) {
return ptr->getKind() == kGraph;
}
};
// GraphOp wraps and populate a TF_OperationDescription.
class GraphOp : public AbstractOp {
// GraphOperation wraps and populates a TF_OperationDescription.
class GraphOperation : public TracingOperation {
public:
explicit GraphOp(TF_Graph* g) : AbstractOp(kKind), g_(g) {}
void SetOpType(const char* const op_type, TF_Status* s) override {
explicit GraphOperation(TF_Graph* g) : TracingOperation(kGraph), g_(g) {}
void Release() override { delete this; }
Status Reset(const char* op, const char* raw_device_name) override {
if (op_) {
TF_SetStatus(
s, TF_FAILED_PRECONDITION,
strings::StrCat("SetOpType called on already built op.").c_str());
return;
return errors::FailedPrecondition("Reset called on already built op.");
}
if (op_name_ != nullptr) {
op_.reset(TF_NewOperation(g_, op_type, op_name_));
op_name_ = nullptr;
} else {
op_type_ = op_type;
if (raw_device_name) {
device_name_ = raw_device_name;
}
op_type_ = op;
return Status::OK();
}
void SetOpName(const char* const op_name, TF_Status* s) override {
Status SetOpName(const char* const op_name) override {
if (op_) {
TF_SetStatus(
s, TF_FAILED_PRECONDITION,
strings::StrCat("SetOpName called on already built op.").c_str());
return;
return errors::FailedPrecondition(
"SetOpName called on already built op.");
}
if (op_type_ != nullptr) {
op_.reset(TF_NewOperation(g_, op_type_, op_name));
op_type_ = nullptr;
} else {
op_name_ = op_name;
if (op_type_.empty()) {
return errors::FailedPrecondition(
"GraphOperation::Reset must be called before calling SetOpName.");
}
op_.reset(TF_NewOperation(g_, op_type_.c_str(), op_name));
return Status::OK();
}
void SetAttrType(const char* const attr_name, TF_DataType value,
TF_Status* s) override {
if (!op_) {
TF_SetStatus(
s, TF_FAILED_PRECONDITION,
"op_type and op_name must be specified before specifying attrs.");
return;
}
TF_SetAttrType(op_.get(), attr_name, value);
}
~GraphOp() override {}
const string& Name() const override { return op_type_; }
const string& DeviceName() const override { return device_name_; }
static constexpr AbstractOpKind kKind = kGraphOp;
Status SetDeviceName(const char* name) override {
// TODO(srbs): Implement this.
device_name_ = name;
return Status::OK();
}
Status AddInput(AbstractTensorHandle* input) override {
GraphTensor* t = dyn_cast<GraphTensor>(input);
if (!t) {
return tensorflow::errors::InvalidArgument(
"Unable to cast input to GraphTensor");
}
TF_AddInput(op_.get(), t->output_);
return Status::OK();
}
Status AddInputList(absl::Span<AbstractTensorHandle*> inputs) override {
return tensorflow::errors::Unimplemented(
"AddInputList has not been implemented yet.");
}
Status Execute(absl::Span<AbstractTensorHandle*> retvals,
int* num_retvals) override {
auto* tf_opdesc = op_.release();
if (tf_opdesc == nullptr) {
return errors::InvalidArgument("AbstractOp is incomplete.");
}
TF_Status* s = TF_NewStatus();
auto* operation = TF_FinishOperation(tf_opdesc, s);
TF_RETURN_IF_ERROR(StatusFromTF_Status(s));
TF_DeleteStatus(s);
*num_retvals = TF_OperationNumOutputs(operation);
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = new GraphTensor({operation, i});
}
return Status::OK();
}
Status SetAttrString(const char* attr_name, const char* data,
size_t length) override {
return tensorflow::errors::Unimplemented(
"SetAttrString has not been implemented yet.");
}
Status SetAttrInt(const char* attr_name, int64_t value) override {
return tensorflow::errors::Unimplemented(
"SetAttrInt has not been implemented yet.");
}
Status SetAttrFloat(const char* attr_name, float value) override {
return tensorflow::errors::Unimplemented(
"SetAttrFloat has not been implemented yet.");
}
Status SetAttrBool(const char* attr_name, bool value) override {
return tensorflow::errors::Unimplemented(
"SetAttrBool has not been implemented yet.");
}
Status SetAttrType(const char* const attr_name, DataType value) override {
if (!op_) {
return Status(
error::Code::FAILED_PRECONDITION,
"op_type and op_name must be specified before specifying attrs.");
}
op_->node_builder.Attr(attr_name, value);
return Status::OK();
}
Status SetAttrShape(const char* attr_name, const int64_t* dims,
const int num_dims) override {
return tensorflow::errors::Unimplemented(
"SetAttrShape has not been implemented yet.");
}
Status SetAttrFunction(const char* attr_name,
const AbstractOperation* value) override {
return tensorflow::errors::Unimplemented(
"SetAttrFunction has not been implemented yet.");
}
Status SetAttrFunctionName(const char* attr_name, const char* value,
size_t length) override {
return tensorflow::errors::Unimplemented(
"SetAttrFunctionName has not been implemented yet.");
}
Status SetAttrTensor(const char* attr_name,
AbstractTensorInterface* tensor) override {
return tensorflow::errors::Unimplemented(
"SetAttrTensor has not been implemented yet.");
}
Status SetAttrStringList(const char* attr_name, const void* const* values,
const size_t* lengths, int num_values) override {
return tensorflow::errors::Unimplemented(
"SetAttrStringList has not been implemented yet.");
}
Status SetAttrFloatList(const char* attr_name, const float* values,
int num_values) override {
return tensorflow::errors::Unimplemented(
"SetAttrFloatList has not been implemented yet.");
}
Status SetAttrIntList(const char* attr_name, const int64_t* values,
int num_values) override {
return tensorflow::errors::Unimplemented(
"SetAttrIntList has not been implemented yet.");
}
Status SetAttrTypeList(const char* attr_name, const DataType* values,
int num_values) override {
return tensorflow::errors::Unimplemented(
"SetAttrTypeList has not been implemented yet.");
}
Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
int num_values) override {
return tensorflow::errors::Unimplemented(
"SetAttrBoolList has not been implemented yet.");
}
Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
const int* num_dims, int num_values) override {
return tensorflow::errors::Unimplemented(
"SetAttrShapeList has not been implemented yet.");
}
Status SetAttrFunctionList(
const char* attr_name,
absl::Span<const AbstractOperation*> values) override {
return tensorflow::errors::Unimplemented(
"SetAttrFunctionList has not been implemented yet.");
}
// For LLVM style RTTI.
static bool classof(const AbstractOperation* ptr) {
return ptr->getKind() == kGraph;
}
~GraphOperation() override {}
private:
friend class GraphContext; // For access to op_.
@ -96,123 +217,109 @@ class GraphOp : public AbstractOp {
std::unique_ptr<TF_OperationDescription> op_;
// Hold `op_type` and `op_name` till both are available since we need both
// to build a graph operation.
const char* op_type_ = nullptr;
string op_type_;
const char* op_name_ = nullptr;
// TODO(srbs): Use this.
string device_name_;
};
// GraphFunction is a thin wrapper over a TF_Function.
struct GraphFunction : public AbstractFunction {
TF_Function* func = nullptr;
GraphFunction() : AbstractFunction(kKind) {}
GraphFunction() : AbstractFunction(kGraph) {}
explicit GraphFunction(TF_Function* func)
: AbstractFunction(kKind), func(func) {}
: AbstractFunction(kGraph), func(func) {}
~GraphFunction() override {
if (func) TF_DeleteFunction(func);
}
TF_Function* GetTfFunction(TF_Status* s) override { return func; }
Status GetFunctionDef(FunctionDef** fdef) override {
*fdef = &func->fdef;
return Status::OK();
}
static constexpr AbstractFunctionKind kKind = kGraphFunc;
// For LLVM style RTTI.
static bool classof(const AbstractFunction* ptr) {
return ptr->getKind() == kGraph;
}
};
// GraphContext wraps a TF_Graph modeling a single function and manages the
// "execution" of operation, i.e. adding them to the function.
class GraphContext : public ExecutionContext {
class GraphContext : public TracingContext {
public:
explicit GraphContext(const char* name)
: ExecutionContext(kKind),
: TracingContext(kGraph),
graph_(new TF_Graph(), TF_DeleteGraph),
name_(name) {}
AbstractOp* CreateOperation() override {
// TODO(srbs): Should the lifetime of this op be tied to the context.
return new GraphOp(graph_.get());
void Release() override { delete this; }
TracingOperation* CreateOperation() override {
return new GraphOperation(graph_.get());
}
void ExecuteOperation(AbstractOp* op, int num_inputs,
AbstractTensor* const* inputs, OutputList* o,
TF_Status* s) override {
auto* graph_op = dyncast<GraphOp>(op);
if (graph_op == nullptr) {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Unable to cast AbstractOp to TF_GraphOp.");
return;
Status AddParameter(DataType dtype, TracingTensorHandle** output) override {
auto operation = CreateOperation();
TF_RETURN_IF_ERROR(operation->Reset("Placeholder", nullptr));
TF_RETURN_IF_ERROR(
operation->SetOpName(absl::StrCat("_input_", inputs_.size()).c_str()));
TF_RETURN_IF_ERROR(operation->SetAttrType("dtype", dtype));
int num_outputs = 1;
std::vector<AbstractTensorHandle*> outputs(num_outputs);
TF_RETURN_IF_ERROR(operation->Execute(
absl::Span<AbstractTensorHandle*>(outputs), &num_outputs));
if (num_outputs != 1) {
return errors::Internal("Expected 1 output but found ", num_outputs);
}
auto* tf_opdesc = graph_op->op_.release();
if (tf_opdesc == nullptr) {
TF_SetStatus(s, TF_INVALID_ARGUMENT, "AbstractOp is incomplete.");
return;
}
for (int i = 0; i < num_inputs; ++i) {
auto* graph_tensor = dyncast<GraphTensor>(inputs[i]);
if (!graph_tensor) {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Capturing eager tensors is not supported yet.");
return;
} else {
if (graph_tensor->ctx != this) {
TF_SetStatus(
s, TF_INVALID_ARGUMENT,
"Capturing tensors from other graphs is not supported yet.");
return;
}
TF_AddInput(tf_opdesc, graph_tensor->output);
}
}
auto* operation = TF_FinishOperation(tf_opdesc, s);
// TF_FinishOperation deletes `tf_opdesc` so clear its reference.
graph_op->op_ = nullptr;
if (TF_GetCode(s) != TF_OK) return;
int num_outputs = TF_OperationNumOutputs(operation);
o->outputs.clear();
o->outputs.reserve(num_outputs);
for (int i = 0; i < num_outputs; ++i) {
o->outputs.push_back(new GraphTensor({operation, i}, this));
auto* t = dyn_cast<GraphTensor>(outputs[0]);
if (!t) {
return tensorflow::errors::InvalidArgument(
"Unable to cast input to GraphTensor");
}
inputs_.push_back(t->output_);
*output = tensorflow::down_cast<TracingTensorHandle*>(outputs[0]);
return Status::OK();
}
AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override {
TF_OperationDescription* opdesc =
TF_NewOperation(graph_.get(), "Placeholder",
absl::StrCat("_input_", inputs_.size()).c_str());
TF_SetAttrType(opdesc, "dtype", dtype);
auto* operation = TF_FinishOperation(opdesc, s);
if (!s->status.ok()) return nullptr;
inputs_.push_back(TF_Output{operation, 0});
return new GraphTensor(inputs_.back(), this);
}
AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override {
Status Finalize(OutputList* outputs, AbstractFunction** f) override {
std::unique_ptr<GraphFunction> func(new GraphFunction);
std::vector<TF_Output> graph_outputs;
graph_outputs.reserve(outputs->outputs.size());
for (AbstractTensor* abstract_output : outputs->outputs) {
GraphTensor* output = dyncast<GraphTensor>(abstract_output);
for (auto* abstract_output : outputs->outputs) {
GraphTensor* output = dyn_cast<GraphTensor>(abstract_output);
if (!output) {
TF_SetStatus(s, TF_UNIMPLEMENTED,
"Returning a non-graph tensor from a function has not "
"been implemented yet.");
return nullptr;
return errors::Unimplemented(
"Returning a non-graph tensor from a function has not "
"been implemented yet.");
}
graph_outputs.push_back(output->output);
graph_outputs.push_back(output->output_);
}
auto s = TF_NewStatus();
func->func = TF_GraphToFunction(
graph_.get(), name_, 0, -1, nullptr, inputs_.size(), inputs_.data(),
graph_outputs.size(), graph_outputs.data(), nullptr, nullptr, name_, s);
if (TF_GetCode(s) != TF_OK) return nullptr;
return func.release();
TF_RETURN_IF_ERROR(StatusFromTF_Status(s));
TF_DeleteStatus(s);
*f = func.release();
return Status::OK();
}
void RegisterFunction(AbstractFunction* func, TF_Status* s) override {
TF_SetStatus(s, TF_UNIMPLEMENTED,
"Registering graph functions has not been implemented yet.");
Status RegisterFunction(AbstractFunction* func) override {
return errors::Unimplemented(
"Registering graph functions has not been implemented yet.");
}
~GraphContext() override {}
static constexpr ExecutionContextKind kKind = kGraphContext;
Status RemoveFunction(const string& func) override {
return errors::Unimplemented(
"GraphContext::RemoveFunction has not been implemented yet.");
}
// For LLVM style RTTI.
static bool classof(const AbstractContext* ptr) {
return ptr->getKind() == kGraph;
}
private:
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
@ -220,7 +327,7 @@ class GraphContext : public ExecutionContext {
const char* name_;
};
static ExecutionContext* GraphTracingFactory(const char* name, TF_Status* s) {
static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) {
return new GraphContext(name);
}
@ -231,5 +338,6 @@ static bool register_tracing = [] {
return true;
}();
} // namespace internal
} // namespace graph
} // namespace tracing
} // namespace tensorflow


@ -19,6 +19,10 @@ limitations under the License.
#include <vector>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
@ -26,7 +30,14 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace internal {
// Represents the results of the execution of an operation.
struct OutputList {
std::vector<AbstractTensorHandle*> outputs;
int expected_num_outputs = -1;
};
namespace tracing {
// =============================================================================
// Implementation detail for the unified execution APIs for Eager and tracing
@ -37,165 +48,75 @@ namespace internal {
// `c_api_unified_experimental.h` header.
// =============================================================================
// We can't depend on C++ rtti, but we still want to be able to have a safe
// dynamic_cast to provide diagnostics to the user when the API is misused.
// Instead we model RTTI by listing all the possible subclasses for each
// abstract base. Each subclass initializes the base class with the right
// `kind`, which allows an equivalent to `std::dynamic_cast` provided by this
// utility.
template <typename T, typename S>
T* dyncast(S source) {
if (source->getKind() != T::kKind) {
return nullptr;
}
return tensorflow::down_cast<T*>(source);
}
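// Illustrative note (not part of this change): the hand-rolled dyncast above
// is replaced by LLVM-style RTTI, where each subclass provides a classof()
// predicate and casts go through tensorflow::dyn_cast, e.g.
//
//   TracingContext* t_ctx = dyn_cast<TracingContext>(abstract_ctx);
//   if (!t_ctx) { /* not a tracing context: report InvalidArgument */ }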
// Represents either an EagerTensor or a GraphTensor.
// Represents either a MlirTensor or a GraphTensor.
// This base class does not expose any public methods other than to distinguish
// which subclass it actually is. The user is responsible to use the right
// type of AbstractTensor in their context (do not pass an EagerTensor to a
// type of AbstractTensor in their context (do not pass an MlirTensor to a
// GraphContext and vice-versa).
class AbstractTensor {
class TracingTensorHandle : public AbstractTensorHandle {
protected:
enum AbstractTensorKind { kMlirTensor, kGraphTensor, kEagerTensor };
explicit AbstractTensor(AbstractTensorKind kind) : kind_(kind) {}
explicit TracingTensorHandle(AbstractTensorHandleKind kind)
: AbstractTensorHandle(kind) {}
public:
// Returns which subclass is this instance of.
AbstractTensorKind getKind() const { return kind_; }
virtual ~AbstractTensor() = default;
private:
const AbstractTensorKind kind_;
};
// Represents the results of the execution of an operation.
struct OutputList {
std::vector<AbstractTensor*> outputs;
int expected_num_outputs = -1;
};
// Holds the result of tracing a function.
class AbstractFunction {
protected:
enum AbstractFunctionKind { kGraphFunc };
explicit AbstractFunction(AbstractFunctionKind kind) : kind_(kind) {}
public:
// Returns which subclass is this instance of.
AbstractFunctionKind getKind() const { return kind_; }
virtual ~AbstractFunction() = default;
// Temporary API till we figure the right abstraction for AbstractFunction.
// At the moment both Eager and Graph needs access to a "TF_Function" object.
virtual TF_Function* GetTfFunction(TF_Status* s) = 0;
private:
const AbstractFunctionKind kind_;
// For LLVM style RTTI.
static bool classof(const AbstractTensorHandle* ptr) {
return ptr->getKind() == kGraph || ptr->getKind() == kMlir;
}
};
// An abstract operation describes an operation by its type, name, and
// attributes. It can be "executed" by the context with some input tensors.
// It is allowed to reuse the same abstract operation for multiple executions
// on a given context, with the same or different input tensors.
class AbstractOp {
class TracingOperation : public AbstractOperation {
protected:
enum AbstractOpKind { kMlirOp, kGraphOp, kEagerOp };
explicit AbstractOp(AbstractOpKind kind) : kind_(kind) {}
explicit TracingOperation(AbstractOperationKind kind)
: AbstractOperation(kind) {}
public:
// Returns which subclass is this instance of.
AbstractOpKind getKind() const { return kind_; }
virtual ~AbstractOp() = default;
// Sets the type of the operation (for example `AddV2`).
virtual void SetOpType(const char* op_type, TF_Status* s) = 0;
// Sets the name of the operation: this is an optional identifier that is
// not intended to carry semantics and is preserved/propagated without
// guarantees.
virtual void SetOpName(const char* op_name, TF_Status* s) = 0;
virtual Status SetOpName(const char* op_name) = 0;
// Add a `TypeAttribute` on the operation.
virtual void SetAttrType(const char* attr_name, TF_DataType value,
TF_Status* s) = 0;
private:
const AbstractOpKind kind_;
// For LLVM style RTTI.
static bool classof(const AbstractOperation* ptr) {
return ptr->getKind() == kGraph || ptr->getKind() == kMlir;
}
};
// This holds the context for the execution: dispatching operations either to an
// eager implementation or to a graph implementation.
struct ExecutionContext {
// MLIR implementation or to a graph implementation.
class TracingContext : public AbstractContext {
protected:
enum ExecutionContextKind { kMlirContext, kGraphContext, kEagerContext };
explicit ExecutionContext(ExecutionContextKind kind) : k(kind) {}
explicit TracingContext(AbstractContextKind kind) : AbstractContext(kind) {}
public:
// Returns which subclass this instance is.
ExecutionContextKind getKind() const { return k; }
virtual ~ExecutionContext() = default;
// Executes the operation on the provided inputs and populates the OutputList
// with the results. The input tensors must match the current context.
// The effect of "executing" an operation depends on the context: in an Eager
// context it will dispatch it to the runtime for execution, while in a
// tracing context it will add the operation to the current function.
virtual void ExecuteOperation(AbstractOp* op, int num_inputs,
AbstractTensor* const* inputs, OutputList* o,
TF_Status* s) = 0;
// Creates an empty AbstractOperation suitable to use with this context.
virtual AbstractOp* CreateOperation() = 0;
// Add a function parameter and return the corresponding tensor.
// This is only valid with an ExecutionContext obtained from a TracingContext;
// it'll always error out with an eager context.
virtual AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) = 0;
virtual Status AddParameter(DataType dtype, TracingTensorHandle**) = 0;
// Finalize this context and make a function out of it. The context is in an
// invalid state after this call and must be destroyed.
// This is only valid with an ExecutionContext obtained from a TracingContext;
// it'll always error out with an eager context.
virtual AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) = 0;
virtual Status Finalize(OutputList* outputs, AbstractFunction**) = 0;
// Registers a function with this context; after this, the function is
// available to be called/referenced by its name in this context.
virtual void RegisterFunction(AbstractFunction* func, TF_Status* s) = 0;
private:
const ExecutionContextKind k;
// For LLVM style RTTI.
static bool classof(const AbstractContext* ptr) {
return ptr->getKind() == kGraph || ptr->getKind() == kMlir;
}
};
typedef ExecutionContext* (*FactoryFunction)(const char* fn_name, TF_Status*);
typedef TracingContext* (*FactoryFunction)(const char* fn_name, TF_Status*);
void SetDefaultTracingEngine(const char* name);
void RegisterTracingEngineFactory(const ::tensorflow::string& name,
FactoryFunction factory);
} // namespace tracing
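For orientation, a condensed sketch of how a tracing context is driven through the public C API, using the same entry points exercised by the tests further below (TF_CreateFunction, TF_AddFunctionParameter, TF_NewAbstractOp, TF_ExecuteOperation, TF_FinalizeFunction). Error checks and part of the cleanup sequence are elided, so treat it as illustrative rather than canonical:

#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"

TF_AbstractFunction* TraceAddFunction() {
  TF_Status* status = TF_NewStatus();
  // Start tracing a function body.
  TF_ExecutionContext* graph_ctx = TF_CreateFunction("example_add_fn", status);
  // Add a float parameter to the function being traced.
  TF_AbstractTensor* arg = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status);
  // Describe and trace one Add node.
  TF_AbstractOp* add_op = TF_NewAbstractOp(graph_ctx);
  TF_AbstractOpSetOpType(add_op, "Add", status);
  TF_AbstractOpSetOpName(add_op, "example_add", status);
  TF_AbstractTensor* inputs[2] = {arg, arg};
  TF_OutputList* outputs = TF_NewOutputList();
  TF_OutputListSetNumOutputs(outputs, 1, status);
  TF_ExecuteOperation(add_op, 2, inputs, outputs, status);  // adds a node
  TF_DeleteAbstractOp(add_op);
  TF_DeleteAbstractTensor(arg);
  // Turn the traced body into a function; see the tests below for the full
  // cleanup expected around this call.
  TF_AbstractFunction* func = TF_FinalizeFunction(graph_ctx, outputs, status);
  TF_DeleteOutputList(outputs);
  TF_DeleteStatus(status);
  return func;
}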
// Wrap/unwrap utilities: these convert between the C opaque types and the
// C++ implementation types, and back.
#define MAKE_WRAP_UNWRAP(C_TYPEDEF, CPP_CLASS) \
static inline CPP_CLASS* const& unwrap(C_TYPEDEF* const& o) { \
return reinterpret_cast<CPP_CLASS* const&>(o); \
} \
static inline const CPP_CLASS* const& unwrap(const C_TYPEDEF* const& o) { \
return reinterpret_cast<const CPP_CLASS* const&>(o); \
} \
static inline C_TYPEDEF* const& wrap(CPP_CLASS* const& o) { \
return reinterpret_cast<C_TYPEDEF* const&>(o); \
} \
static inline const C_TYPEDEF* const& wrap(const CPP_CLASS* const& o) { \
return reinterpret_cast<const C_TYPEDEF* const&>(o); \
}
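To illustrate how these shims are used, a C entry point typically unwraps the opaque handle, calls into the C++ object, and wraps any result on the way back. The body below is a hypothetical sketch only; the real definition of this function lives elsewhere in the C API implementation:

// Hypothetical sketch of a C entry point built on unwrap()/wrap().
int TF_OutputListNumOutputs(TF_OutputList* o) {
  return static_cast<int>(unwrap(o)->outputs.size());
}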
MAKE_WRAP_UNWRAP(TF_ExecutionContext, ExecutionContext)
MAKE_WRAP_UNWRAP(TF_AbstractFunction, AbstractFunction)
MAKE_WRAP_UNWRAP(TF_AbstractTensor, AbstractTensor)
MAKE_WRAP_UNWRAP(TF_AbstractOp, AbstractOp)
MAKE_WRAP_UNWRAP(TF_OutputList, OutputList)
} // namespace internal
DEFINE_CONVERSION_FUNCTIONS(AbstractContext, TF_ExecutionContext)
DEFINE_CONVERSION_FUNCTIONS(AbstractTensorHandle, TF_AbstractTensor)
DEFINE_CONVERSION_FUNCTIONS(AbstractFunction, TF_AbstractFunction)
DEFINE_CONVERSION_FUNCTIONS(AbstractOperation, TF_AbstractOp)
DEFINE_CONVERSION_FUNCTIONS(OutputList, TF_OutputList)
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
@ -29,15 +30,19 @@ using tensorflow::string;
namespace tensorflow {
namespace {
class UnifiedCAPI : public ::testing::TestWithParam<const char*> {
class UnifiedCAPI
: public ::testing::TestWithParam<std::tuple<const char*, bool>> {
protected:
void SetUp() override { TF_SetTracingImplementation(GetParam()); }
void SetUp() override {
TF_SetTracingImplementation(std::get<0>(GetParam()));
}
};
TEST_P(UnifiedCAPI, TestBasicEager) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetTfrt(opts, std::get<1>(GetParam()));
TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
@ -45,7 +50,8 @@ TEST_P(UnifiedCAPI, TestBasicEager) {
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract input tensor.
TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx);
TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f);
TF_AbstractTensor* at =
TF_CreateAbstractTensorFromEagerTensor(t, status.get());
@ -63,7 +69,7 @@ TEST_P(UnifiedCAPI, TestBasicEager) {
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Execute.
TF_ExecuteOperation(op, 2, inputs, o, ctx, status.get());
TF_ExecuteOperation(op, 2, inputs, o, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Clean up operation and inputs.
@ -109,9 +115,11 @@ TEST_P(UnifiedCAPI, TestBasicGraph) {
// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {placeholder_t, placeholder_t};
TF_OutputList* add_outputs = TF_NewOutputList();
TF_OutputListSetNumOutputs(add_outputs, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Execute.
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, status.get());
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Clean up operation and inputs.
@ -123,6 +131,7 @@ TEST_P(UnifiedCAPI, TestBasicGraph) {
// Build eager context.
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetTfrt(opts, std::get<1>(GetParam()));
TF_ExecutionContext* eager_execution_ctx =
TF_NewEagerExecutionContext(opts, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
@ -137,16 +146,14 @@ TEST_P(UnifiedCAPI, TestBasicGraph) {
// Build an abstract input tensor.
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(eager_execution_ctx);
TF_ExecutionContextGetTFEContext(eager_execution_ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f);
TF_AbstractTensor* input_t =
TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_OutputListSetNumOutputs(add_outputs, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_ExecuteOperation(fn_op, 1, &input_t, add_outputs, eager_execution_ctx,
status.get());
TF_ExecuteOperation(fn_op, 1, &input_t, add_outputs, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(1, TF_OutputListNumOutputs(add_outputs));
@ -195,8 +202,10 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractTensor* inputs[2] = {arg0, arg1};
TF_OutputList* add_outputs = TF_NewOutputList();
TF_OutputListSetNumOutputs(add_outputs, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Trace the operation now (create a node in the graph).
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s);
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteAbstractOp(add_op);
// Extract the resulting tensor.
@ -215,8 +224,10 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractTensor* inputs[2] = {arg1, arg1};
TF_OutputList* add_outputs = TF_NewOutputList();
TF_OutputListSetNumOutputs(add_outputs, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Trace the operation now (create a node in the graph).
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s);
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteAbstractOp(add_op);
// Extract the resulting tensor.
@ -256,6 +267,7 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
// Build eager context.
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetTfrt(opts, std::get<1>(GetParam()));
TF_ExecutionContext* eager_execution_ctx =
TF_NewEagerExecutionContext(opts, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
@ -273,7 +285,8 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
std::vector<TF_AbstractTensor*> func_args;
{
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(eager_execution_ctx);
TF_ExecutionContextGetTFEContext(eager_execution_ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f);
func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s));
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
@ -286,7 +299,7 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
TF_OutputListSetNumOutputs(func_outputs, 2, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_ExecuteOperation(fn_op, func_args.size(), func_args.data(), func_outputs,
eager_execution_ctx, s);
s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteAbstractOp(fn_op);
for (TF_AbstractTensor* t : func_args) TF_DeleteAbstractTensor(t);
@ -314,20 +327,21 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
TF_DeleteAbstractFunction(func);
}
TEST(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
TEST_P(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetTfrt(opts, std::get<1>(GetParam()));
TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
TF_AbstractFunction* func = TF_FinalizeFunction(ctx, nullptr, status.get());
ASSERT_EQ(nullptr, func);
TF_AbstractFunction* f = TF_FinalizeFunction(ctx, nullptr, status.get());
ASSERT_EQ(nullptr, f);
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
}
TEST_P(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
TEST_P(UnifiedCAPI, TF_AbstractOpSetOpTypeAfterFinishingOpBuildingRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
@ -348,7 +362,7 @@ TEST_P(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
TF_DeleteExecutionContext(graph_ctx);
}
TEST_P(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
TEST_P(UnifiedCAPI, TF_AbstractOpSetOpNameAfterFinishingOpBuildingRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
@ -369,116 +383,44 @@ TEST_P(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
TF_DeleteExecutionContext(graph_ctx);
}
TEST_P(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
// Build an Eager context.
TEST_P(UnifiedCAPI, TF_AbstractTensorGetEagerTensorOnGraphTensorRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an Eager operation.
auto* op = TF_NewAbstractOp(ctx);
TF_AbstractOpSetOpType(op, "Add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract input tensor.
TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx);
TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f);
TF_AbstractTensor* at =
TF_CreateAbstractTensorFromEagerTensor(t, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {at, at};
TF_OutputList* o = TF_NewOutputList();
TF_OutputListSetNumOutputs(o, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build a Graph context.
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Execute eager op using graph context.
TF_ExecuteOperation(op, 2, inputs, o, graph_ctx, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
// Clean up operation and inputs.
TF_DeleteAbstractOp(op);
TF_DeleteAbstractTensor(at);
TF_DeleteOutputList(o);
TF_DeleteExecutionContext(ctx);
TF_DeleteExecutionContext(graph_ctx);
}
TEST_P(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph.
auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetAttrType(placeholder_op, "dtype", TF_FLOAT, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_OutputList* placeholder_outputs = TF_NewOutputList();
// Execute.
TF_ExecuteOperation(placeholder_op, 0, nullptr, placeholder_outputs,
graph_ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(1, TF_OutputListNumOutputs(placeholder_outputs));
TF_AbstractTensor* placeholder_t = TF_OutputListGet(placeholder_outputs, 0);
// Delete placeholder op.
TF_DeleteAbstractOp(placeholder_op);
// Build an abstract operation.
auto* add_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(add_op, "Add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(add_op, "my_add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {placeholder_t, placeholder_t};
TF_OutputList* add_outputs = TF_NewOutputList();
// Build eager context.
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContext* eager_execution_ctx =
TF_NewEagerExecutionContext(opts, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
// Execute.
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, eager_execution_ctx,
status.get());
auto placeholder_t =
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
TF_AbstractTensorGetEagerTensor(placeholder_t, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
// Clean up operation and inputs.
TF_DeleteAbstractTensor(placeholder_t);
TF_DeleteAbstractOp(add_op);
TF_DeleteOutputList(add_outputs);
TF_DeleteOutputList(placeholder_outputs);
TF_DeleteExecutionContext(graph_ctx);
TF_DeleteExecutionContext(eager_execution_ctx);
}
TEST_P(UnifiedCAPI, TF_ExecutionContextGetTFEContextFromFunctionContextRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_ExecutionContextGetTFEContext(graph_ctx, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
TF_DeleteExecutionContext(graph_ctx);
}
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI,
::testing::Values("graphdef", "mlir"));
::testing::Combine(::testing::Values("graphdef",
"mlir"),
::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI,
::testing::Combine(::testing::Values("graphdef",
"mlir"),
::testing::Values(false)));
#endif
} // namespace
} // namespace tensorflow

View File

@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
#include <memory>
#include <vector>
#include "absl/types/optional.h"
@ -37,7 +38,6 @@ namespace tensorflow {
// TensorHandles & Operations.
class ImmediateExecutionContext : public AbstractContext {
public:
static constexpr AbstractContextKind kKind = kImmediateExecution;
// Optimized scalar creation functions
virtual AbstractTensorInterface* CreateInt64Scalar(int64 value) = 0;
virtual AbstractTensorInterface* CreateUint64Scalar(uint64 value) = 0;
@ -102,11 +102,31 @@ class ImmediateExecutionContext : public AbstractContext {
// already exists.
virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
// For LLVM style RTTI.
static bool classof(const AbstractContext* ptr) {
return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
}
protected:
ImmediateExecutionContext() : AbstractContext(kKind) {}
explicit ImmediateExecutionContext(AbstractContextKind kind)
: AbstractContext(kind) {}
~ImmediateExecutionContext() override {}
};
namespace internal {
struct ImmediateExecutionContextDeleter {
void operator()(ImmediateExecutionContext* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using ImmediateContextPtr =
std::unique_ptr<ImmediateExecutionContext,
internal::ImmediateExecutionContextDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
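The ImmediateContextPtr alias above relies on a custom deleter that calls Release() instead of delete. A minimal self-contained sketch of that pattern follows (stand-in types; Release() here simply deletes, whereas the TensorFlow objects manage their own lifetime):

#include <memory>

class ExampleRefCounted {
 public:
  // Stand-in: the real classes typically drop a reference here.
  void Release() { delete this; }
};

struct ExampleDeleter {
  void operator()(ExampleRefCounted* p) const {
    if (p != nullptr) p->Release();
  }
};

using ExamplePtr = std::unique_ptr<ExampleRefCounted, ExampleDeleter>;

int main() {
  ExamplePtr ptr(new ExampleRefCounted());  // Release() runs at scope exit
  return 0;
}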

View File

@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_
#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
@ -32,7 +34,6 @@ namespace tensorflow {
// Abstract interface to an operation.
class ImmediateExecutionOperation : public AbstractOperation {
public:
static constexpr AbstractOperationKind kKind = kImmediateExecution;
virtual void Clear() = 0;
virtual const tensorflow::OpDef* OpDef() const = 0;
@ -43,11 +44,31 @@ class ImmediateExecutionOperation : public AbstractOperation {
// Experimental
virtual Status SetUseXla(bool enable) = 0;
// For LLVM style RTTI.
static bool classof(const AbstractOperation* ptr) {
return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
}
protected:
ImmediateExecutionOperation() : AbstractOperation(kKind) {}
explicit ImmediateExecutionOperation(AbstractOperationKind kind)
: AbstractOperation(kind) {}
~ImmediateExecutionOperation() override {}
};
namespace internal {
struct ImmediateExecutionOperationDeleter {
void operator()(ImmediateExecutionOperation* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using ImmediateOpPtr =
std::unique_ptr<ImmediateExecutionOperation,
internal::ImmediateExecutionOperationDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_

View File

@ -33,8 +33,6 @@ namespace tensorflow {
// is needed a static_cast can be applied.
class ImmediateExecutionTensorHandle : public AbstractTensorHandle {
public:
static constexpr AbstractTensorHandleKind kKind = kImmediateExecution;
// Returns tensor dtype.
virtual tensorflow::DataType DataType() const = 0;
// Returns number of dimensions.
@ -54,11 +52,31 @@ class ImmediateExecutionTensorHandle : public AbstractTensorHandle {
// Return a copy of the handle.
virtual ImmediateExecutionTensorHandle* Copy() = 0;
// For LLVM style RTTI.
static bool classof(const AbstractTensorHandle* ptr) {
return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
}
protected:
ImmediateExecutionTensorHandle() : AbstractTensorHandle(kKind) {}
explicit ImmediateExecutionTensorHandle(AbstractTensorHandleKind kind)
: AbstractTensorHandle(kind) {}
~ImmediateExecutionTensorHandle() override {}
};
namespace internal {
struct ImmediateExecutionTensorHandleDeleter {
void operator()(ImmediateExecutionTensorHandle* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using ImmediateTensorHandlePtr =
std::unique_ptr<ImmediateExecutionTensorHandle,
internal::ImmediateExecutionTensorHandleDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_

View File

@ -262,14 +262,14 @@ std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
components.reserve(underlying_devices_.size());
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
int64_t* device_id = new int64_t;
int32_t* device_id = new int32_t;
*device_id = device_index;
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(
TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
sizeof(int64_t),
TF_INT32, /*dims=*/nullptr, /*num_dims=*/0, device_id,
sizeof(int32_t),
[](void* data, size_t, void* arg) {
delete reinterpret_cast<int64_t*>(data);
delete reinterpret_cast<int32_t*>(data);
},
nullptr),
TF_DeleteTensor);
@ -283,7 +283,7 @@ std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT32);
TFE_TensorHandle* device_handle;
int num_outputs = 1;
TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);

View File

@ -296,8 +296,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
TFE_DeleteTensorHandle(result_handle);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<int64_t>(components[0].get(), 0);
ExpectScalarEq<int64_t>(components[1].get(), 1);
ExpectScalarEq<int32_t>(components[0].get(), 0);
ExpectScalarEq<int32_t>(components[1].get(), 1);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);

View File

@ -1,5 +1,5 @@
# Experimental gcs filesystem plugin.
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object")
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object", "tf_cc_test")
package(
licenses = ["notice"], # Apache 2.0
@ -19,14 +19,44 @@ tf_cc_shared_object(
cc_library(
name = "gcs_filesystem_impl",
srcs = ["gcs_filesystem.cc"],
hdrs = ["gcs_filesystem.h"],
copts = select({
"//conditions:default": [],
"//tensorflow:windows": get_win_copts(),
}),
deps = [
":gcs_helper",
"//tensorflow/c:env",
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "gcs_helper",
srcs = ["gcs_helper.cc"],
hdrs = ["gcs_helper.h"],
linkstatic = 1,
deps = [
"//tensorflow/c:env",
],
)
tf_cc_test(
name = "gcs_filesystem_test",
srcs = [
"gcs_filesystem.cc",
"gcs_filesystem_test.cc",
],
tags = [
"manual",
"notap",
],
deps = [
":gcs_filesystem_impl",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core/platform:stacktrace_handler",
"//tensorflow/core/platform:test",
],
)

View File

@ -12,12 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h"
#include <stdlib.h>
#include <string.h>
#include "absl/strings/string_view.h"
#include "google/cloud/storage/client.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/env.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
#include "tensorflow/c/tf_status.h"
// Implementation of a filesystem for GCS environments.
@ -36,8 +38,8 @@ static inline void TF_SetStatusFromGCSStatus(
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); }
static void ParseGCSPath(absl::string_view fname, bool object_empty_ok,
char** bucket, char** object, TF_Status* status) {
void ParseGCSPath(const std::string& fname, bool object_empty_ok,
std::string* bucket, std::string* object, TF_Status* status) {
size_t scheme_end = fname.find("://") + 2;
if (fname.substr(0, scheme_end + 1) != "gs://") {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
@ -46,33 +48,19 @@ static void ParseGCSPath(absl::string_view fname, bool object_empty_ok,
}
size_t bucket_end = fname.find("/", scheme_end + 1);
if (bucket_end == absl::string_view::npos) {
if (bucket_end == std::string::npos) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"GCS path doesn't contain a bucket name.");
return;
}
absl::string_view bucket_view =
fname.substr(scheme_end + 1, bucket_end - scheme_end - 1);
*bucket =
static_cast<char*>(plugin_memory_allocate(bucket_view.length() + 1));
memcpy(*bucket, bucket_view.data(), bucket_view.length());
(*bucket)[bucket_view.length()] = '\0';
absl::string_view object_view = fname.substr(bucket_end + 1);
if (object_view.empty()) {
if (object_empty_ok) {
*object = nullptr;
return;
} else {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"GCS path doesn't contain an object name.");
return;
}
*bucket = fname.substr(scheme_end + 1, bucket_end - scheme_end - 1);
*object = fname.substr(bucket_end + 1);
if (object->empty() && !object_empty_ok) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"GCS path doesn't contain an object name.");
}
*object =
static_cast<char*>(plugin_memory_allocate(object_view.length() + 1));
// object_view.data() is a null-terminated string_view because fname is.
strcpy(*object, object_view.data());
}
// SECTION 1. Implementation for `TF_RandomAccessFile`
@ -86,6 +74,18 @@ namespace tf_random_access_file {
// SECTION 2. Implementation for `TF_WritableFile`
// ----------------------------------------------------------------------------
namespace tf_writable_file {
typedef struct GCSFile {
const std::string bucket;
const std::string object;
gcs::Client* gcs_client; // not owned
TempFile outfile;
bool sync_need;
} GCSFile;
static void Cleanup(TF_WritableFile* file) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
delete gcs_file;
}
// TODO(vnvo2409): Implement later
@ -104,7 +104,7 @@ namespace tf_read_only_memory_region {
namespace tf_gcs_filesystem {
// TODO(vnvo2409): Add lazy-loading and customizing parameters.
static void Init(TF_Filesystem* filesystem, TF_Status* status) {
void Init(TF_Filesystem* filesystem, TF_Status* status) {
google::cloud::StatusOr<gcs::Client> client =
gcs::Client::CreateDefaultClient();
if (!client) {
@ -117,8 +117,52 @@ static void Init(TF_Filesystem* filesystem, TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
}
void Cleanup(TF_Filesystem* filesystem) {
plugin_memory_free(filesystem->plugin_filesystem);
}
// TODO(vnvo2409): Implement later
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_client = static_cast<gcs::Client*>(filesystem->plugin_filesystem);
char* temp_file_name = TF_GetTempFileName("");
file->plugin_file = new tf_writable_file::GCSFile(
{std::move(bucket), std::move(object), gcs_client,
TempFile(temp_file_name, std::ios::binary | std::ios::out), true});
// We are responsible for freeing the pointer returned by TF_GetTempFileName
free(temp_file_name);
TF_SetStatus(status, TF_OK, "");
}
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_client = static_cast<gcs::Client*>(filesystem->plugin_filesystem);
char* temp_file_name = TF_GetTempFileName("");
auto gcs_status = gcs_client->DownloadToFile(bucket, object, temp_file_name);
TF_SetStatusFromGCSStatus(gcs_status, status);
auto status_code = TF_GetCode(status);
if (status_code != TF_OK && status_code != TF_NOT_FOUND) {
return;
}
// If this file does not exist on the server, we will need to sync it.
bool sync_need = (status_code == TF_NOT_FOUND);
file->plugin_file = new tf_writable_file::GCSFile(
{std::move(bucket), std::move(object), gcs_client,
TempFile(temp_file_name, std::ios::binary | std::ios::app), sync_need});
free(temp_file_name);
TF_SetStatus(status, TF_OK, "");
}
} // namespace tf_gcs_filesystem
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
@ -126,9 +170,17 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
TF_SetFilesystemVersionMetadata(ops);
ops->scheme = strdup(uri);
ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
ops->filesystem_ops->init = tf_gcs_filesystem::Init;
ops->filesystem_ops->cleanup = tf_gcs_filesystem::Cleanup;
ops->filesystem_ops->new_writable_file = tf_gcs_filesystem::NewWritableFile;
ops->filesystem_ops->new_appendable_file =
tf_gcs_filesystem::NewAppendableFile;
}
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {

View File

@ -0,0 +1,34 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_
#include "google/cloud/storage/client.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h"
void ParseGCSPath(const std::string& fname, bool object_empty_ok,
std::string* bucket, std::string* object, TF_Status* status);
namespace tf_gcs_filesystem {
void Init(TF_Filesystem* filesystem, TF_Status* status);
void Cleanup(TF_Filesystem* filesystem);
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status);
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status);
} // namespace tf_gcs_filesystem
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_

View File

@ -0,0 +1,75 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/platform/stacktrace_handler.h"
#include "tensorflow/core/platform/test.h"
#define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x))
namespace tensorflow {
namespace {
class GCSFilesystemTest : public ::testing::Test {
public:
void SetUp() override {
status_ = TF_NewStatus();
filesystem_ = new TF_Filesystem;
tf_gcs_filesystem::Init(filesystem_, status_);
ASSERT_TF_OK(status_) << "Can not initialize filesystem. "
<< TF_Message(status_);
}
void TearDown() override {
TF_DeleteStatus(status_);
tf_gcs_filesystem::Cleanup(filesystem_);
delete filesystem_;
}
protected:
TF_Filesystem* filesystem_;
TF_Status* status_;
};
TEST_F(GCSFilesystemTest, ParseGCSPath) {
std::string bucket, object;
ParseGCSPath("gs://bucket/path/to/object", false, &bucket, &object, status_);
ASSERT_TF_OK(status_);
ASSERT_EQ(bucket, "bucket");
ASSERT_EQ(object, "path/to/object");
ParseGCSPath("gs://bucket/", true, &bucket, &object, status_);
ASSERT_TF_OK(status_);
ASSERT_EQ(bucket, "bucket");
ParseGCSPath("bucket/path/to/object", false, &bucket, &object, status_);
ASSERT_EQ(TF_GetCode(status_), TF_INVALID_ARGUMENT);
// bucket name must end with "/"
ParseGCSPath("gs://bucket", true, &bucket, &object, status_);
ASSERT_EQ(TF_GetCode(status_), TF_INVALID_ARGUMENT);
ParseGCSPath("gs://bucket/", false, &bucket, &object, status_);
ASSERT_EQ(TF_GetCode(status_), TF_INVALID_ARGUMENT);
}
} // namespace
} // namespace tensorflow
GTEST_API_ int main(int argc, char** argv) {
tensorflow::testing::InstallStacktraceHandler();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -0,0 +1,34 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
#include <stdio.h>
#include <fstream>
#include <string>
#include <utility>
TempFile::TempFile(const char* temp_file_name, std::ios::openmode mode)
: std::fstream(temp_file_name, mode), name_(temp_file_name) {}
TempFile::TempFile(TempFile&& rhs)
: std::fstream(std::move(rhs)), name_(std::move(rhs.name_)) {}
TempFile::~TempFile() {
std::fstream::close();
std::remove(name_.c_str());
}
const std::string TempFile::getName() const { return name_; }
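A hypothetical usage sketch of TempFile as staging space before an upload; the path, the data, and the upload step are illustrative assumptions, not part of this change:

#include <ios>
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"

void StageForUploadExample() {
  // TempFile is an std::fstream whose destructor also removes the file.
  TempFile staged("/tmp/gcs_staging_example", std::ios::binary | std::ios::out);
  staged << "bytes written locally before being copied to GCS";
  // staged.getName() would be the local path handed to the upload call.
}  // closed and std::remove()'d here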

View File

@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_
#include <fstream>
#include <string>
class TempFile : public std::fstream {
public:
// An openmode must be specified each time a TempFile is constructed.
TempFile(const char* temp_file_name, std::ios::openmode mode);
TempFile(TempFile&& rhs);
~TempFile() override;
const std::string getName() const;
private:
const std::string name_;
};
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_

View File

@ -3,6 +3,10 @@
# Targets in this directory are pure C++ "Classes" underlying the C API types
# under tf/c/experimental/saved_model/public/. They are subject to change and
# have visibility limited to Tensorflow's implementation only.
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
)
package(
default_visibility = [
@ -47,6 +51,46 @@ cc_library(
],
)
cc_library(
name = "saved_model_utils",
srcs = [
"saved_model_utils.cc",
],
hdrs = [
"saved_model_utils.h",
],
deps = [
"//tensorflow/c:tf_tensor_internal",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
"//tensorflow/c/experimental/saved_model/core/revived_types:variable",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "test_utils",
testonly = True,
srcs = [
"test_utils.cc",
],
hdrs = [
"test_utils.h",
],
deps = [
"//tensorflow/c:tensor_interface",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core/common_runtime:core_cpu_lib",
"//tensorflow/core/common_runtime/eager:context",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "tf_saved_model_impl",
srcs = [
@ -84,3 +128,47 @@ filegroup(
],
visibility = ["//tensorflow/core:__pkg__"],
)
tf_cc_test(
name = "constant_loading_test",
srcs = [
"constant_loading_test.cc",
],
deps = [
":saved_model_utils",
":test_utils",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime:core_cpu_lib",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:core",
],
)
tf_cc_test(
name = "saved_variable_loading_test",
srcs = [
"saved_variable_loading_test.cc",
],
deps = [
":saved_model_utils",
":test_utils",
"//tensorflow/c:tensor_interface",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
"//tensorflow/core:all_kernels",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime:core_cpu_lib",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:core",
],
)

View File

@ -0,0 +1,111 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <vector>
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
#include "tensorflow/c/experimental/saved_model/core/test_utils.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
class ConstantTest : public ::testing::TestWithParam<
std::tuple<DataType, std::vector<int64>, bool>> {
public:
ConstantTest()
: device_mgr_(testing::CreateTestingDeviceMgr()),
ctx_(testing::CreateTestingEagerContext(device_mgr_.get())) {}
EagerContext* context() { return ctx_.get(); }
private:
std::unique_ptr<StaticDeviceMgr> device_mgr_;
EagerContextPtr ctx_;
};
// Basic sanity check that roundtripping a Tensor->Tensorproto->Constant
// preserves values.
TEST_P(ConstantTest, CreateConstantSuccessful) {
// Get test parameters
auto& test_params = GetParam();
DataType dtype = std::get<0>(test_params);
TensorShape shape(std::get<1>(test_params));
bool tensorproto_use_tensor_content = std::get<2>(test_params);
// Construct a Tensor with the given dtype + shape
Tensor expected(dtype, shape);
testing::FillNumericTensorBuffer(expected.dtype(), expected.NumElements(),
expected.data(), 42);
// Serialize it to a Tensorproto
TensorProto proto;
if (tensorproto_use_tensor_content) {
expected.AsProtoTensorContent(&proto);
} else {
expected.AsProtoField(&proto);
}
// Revival should succeed w/o errors
std::unique_ptr<Constant> revived;
TF_EXPECT_OK(internal::TensorProtoToConstant(context(), proto, &revived));
// The revived tensorhandle should have the exact same dtype and shape as the
// original, and approximately equivalent data.
ImmediateExecutionTensorHandle* handle = revived->handle();
Status status;
AbstractTensorPtr revived_tensor(handle->Resolve(&status));
TF_EXPECT_OK(status) << "Failed to convert tensorhandle to tensor";
EXPECT_EQ(revived_tensor->Type(), expected.dtype());
EXPECT_EQ(revived_tensor->NumElements(), expected.NumElements());
EXPECT_EQ(revived_tensor->NumDims(), expected.dims());
for (int i = 0; i < expected.dims(); ++i) {
EXPECT_EQ(revived_tensor->Dim(i), expected.dim_size(i));
}
testing::CheckBufferDataIsEqual(expected.dtype(), expected.NumElements(),
revived_tensor->Data(), expected.data());
}
// Test against combinations of tensors that are
// 1. Varying dtypes
// 2. Varying shapes
// 3. TensorProto serialized using tensor_content vs repeated type
INSTANTIATE_TEST_SUITE_P(
ConstantIntegerDtypesTest, ConstantTest,
::testing::Combine(
::testing::ValuesIn(testing::DataTypeSetToVector(kDataTypeIsInteger)),
::testing::ValuesIn(testing::InterestingShapes()),
::testing::Values(false, true)));
INSTANTIATE_TEST_SUITE_P(
ConstantFloatingDtypesTest, ConstantTest,
::testing::Combine(::testing::Values(DT_FLOAT, DT_DOUBLE),
::testing::ValuesIn(testing::InterestingShapes()),
::testing::Values(false, true)));
} // namespace
} // namespace tensorflow

View File

@ -14,44 +14,6 @@ package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "owned_eager_op",
hdrs = [
"owned_eager_op.h",
],
deps = [
"//tensorflow/c/eager:immediate_execution_operation",
],
)
cc_library(
name = "owned_tensor_handle",
hdrs = [
"owned_tensor_handle.h",
],
deps = [
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core/common_runtime/eager:tensor_handle",
],
)
cc_library(
name = "owned_eager_context",
hdrs = ["owned_eager_context.h"],
deps = [
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/core/common_runtime/eager:context",
],
)
cc_library(
name = "owned_tensor",
hdrs = ["owned_tensor.h"],
deps = [
"//tensorflow/c:tensor_interface",
],
)
cc_library(
name = "variable_ops",
srcs = [
@ -61,14 +23,15 @@ cc_library(
"variable_ops.h",
],
deps = [
":owned_eager_op",
":owned_tensor_handle",
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/types:span",
],
)
@ -78,14 +41,13 @@ tf_cc_test(
srcs = [
"variable_ops_test.cc",
],
tags = [
"no_windows", # TODO(b/159210739): Remove this tag after fixing the bug.
],
deps = [
":owned_eager_context",
":owned_tensor",
":owned_tensor_handle",
":variable_ops",
"//tensorflow/c:tensor_interface",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:all_kernels",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",

View File

@ -1,54 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_EAGER_CONTEXT_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_EAGER_CONTEXT_H_
#include <memory>
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/core/common_runtime/eager/context.h"
namespace tensorflow {
namespace internal {
struct ImmediateExecutionContextDeleter {
void operator()(ImmediateExecutionContext* p) const {
if (p != nullptr) {
p->Release();
}
}
};
struct EagerContextDeleter {
void operator()(EagerContext* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using AbstractContextPtr =
std::unique_ptr<ImmediateExecutionContext,
internal::ImmediateExecutionContextDeleter>;
using EagerContextPtr =
std::unique_ptr<EagerContext, internal::EagerContextDeleter>;
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_EAGER_CONTEXT_H_

View File

@ -1,54 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_TENSOR_HANDLE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_TENSOR_HANDLE_H_
#include <memory>
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
namespace tensorflow {
namespace internal {
struct TensorHandleDeleter {
void operator()(TensorHandle* p) const {
if (p != nullptr) {
p->Release();
}
}
};
struct AbstractTensorHandleDeleter {
void operator()(ImmediateExecutionTensorHandle* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using TensorHandlePtr =
std::unique_ptr<TensorHandle, internal::TensorHandleDeleter>;
using AbstractTensorHandlePtr =
std::unique_ptr<ImmediateExecutionTensorHandle,
internal::AbstractTensorHandleDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_TENSOR_HANDLE_H_

View File

@ -16,13 +16,15 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/experimental/saved_model/core/ops/owned_eager_op.h"
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
@ -35,8 +37,8 @@ static const char kNoSharingResourceID[] =
Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx,
DataType dtype, TensorShape shape,
AbstractTensorHandlePtr* handle) {
AbstractOpPtr varhandle_op = AbstractOpPtr(ctx->CreateOperation());
ImmediateTensorHandlePtr* handle) {
ImmediateOpPtr varhandle_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(varhandle_op->Reset("VarHandleOp", nullptr));
TF_RETURN_IF_ERROR(varhandle_op->SetAttrType("dtype", dtype));
@ -55,17 +57,20 @@ Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx,
int num_retvals = 1;
TF_RETURN_IF_ERROR(varhandle_op->Execute(
absl::MakeSpan(&var_handle, num_retvals), &num_retvals));
if (var_handle->getKind() != ImmediateExecutionTensorHandle::kKind) {
AbstractTensorHandlePtr owned_var_handle(var_handle);
if (!tensorflow::isa<ImmediateExecutionTensorHandle>(
owned_var_handle.get())) {
return errors::Internal("Unexpected tensor handle kind.");
}
handle->reset(reinterpret_cast<ImmediateExecutionTensorHandle*>(var_handle));
handle->reset(reinterpret_cast<ImmediateExecutionTensorHandle*>(
owned_var_handle.release()));
return Status();
}
Status AssignVariable(ImmediateExecutionContext* ctx,
ImmediateExecutionTensorHandle* variable_handle,
DataType dtype, ImmediateExecutionTensorHandle* value) {
AbstractOpPtr assign_op(ctx->CreateOperation());
ImmediateOpPtr assign_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(assign_op->Reset("AssignVariableOp", nullptr));
TF_RETURN_IF_ERROR(assign_op->SetAttrType("dtype", dtype));
TF_RETURN_IF_ERROR(assign_op->AddInput(variable_handle));
@ -78,8 +83,8 @@ Status AssignVariable(ImmediateExecutionContext* ctx,
Status ReadVariable(ImmediateExecutionContext* ctx,
ImmediateExecutionTensorHandle* variable_handle,
DataType dtype, AbstractTensorHandlePtr* output) {
AbstractOpPtr read_op = AbstractOpPtr(ctx->CreateOperation());
DataType dtype, ImmediateTensorHandlePtr* output) {
ImmediateOpPtr read_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(read_op->Reset("ReadVariableOp", nullptr));
TF_RETURN_IF_ERROR(read_op->SetAttrType("dtype", dtype));
TF_RETURN_IF_ERROR(read_op->AddInput(variable_handle));
@ -88,16 +93,18 @@ Status ReadVariable(ImmediateExecutionContext* ctx,
int num_retvals = 1;
TF_RETURN_IF_ERROR(
read_op->Execute(absl::MakeSpan(&value, num_retvals), &num_retvals));
if (value->getKind() != ImmediateExecutionTensorHandle::kKind) {
AbstractTensorHandlePtr owned_value(value);
if (!tensorflow::isa<ImmediateExecutionTensorHandle>(owned_value.get())) {
return errors::Internal("Unexpected tensor handle kind.");
}
output->reset(reinterpret_cast<ImmediateExecutionTensorHandle*>(value));
output->reset(
reinterpret_cast<ImmediateExecutionTensorHandle*>(owned_value.release()));
return Status();
}
Status DestroyResource(ImmediateExecutionContext* ctx,
ImmediateExecutionTensorHandle* handle) {
AbstractOpPtr destroy_op = AbstractOpPtr(ctx->CreateOperation());
ImmediateOpPtr destroy_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(destroy_op->Reset("DestroyResourceOp", nullptr));
TF_RETURN_IF_ERROR(destroy_op->SetAttrBool("ignore_lookup_error", true));
TF_RETURN_IF_ERROR(destroy_op->AddInput(handle));

View File

@ -18,7 +18,6 @@ limitations under the License.
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
@ -32,7 +31,7 @@ namespace internal {
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L1867-L1872
Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx,
DataType dtype, TensorShape shape,
AbstractTensorHandlePtr* handle);
ImmediateTensorHandlePtr* handle);
// Executes an AssignVariableOp using `ctx`, assigning the variable associated
// with `variable_handle` with `value`. `dtype` must be the datatype of the
@ -48,7 +47,7 @@ Status AssignVariable(ImmediateExecutionContext* ctx,
// the dtype of the variable associated with `variable_handle`.
Status ReadVariable(ImmediateExecutionContext* ctx,
ImmediateExecutionTensorHandle* variable_handle,
DataType dtype, AbstractTensorHandlePtr* output);
DataType dtype, ImmediateTensorHandlePtr* output);
// Executes DestroyResourceOp on `handle`, using `ctx`. This is equivalent to
// the cleanup that occurs in a tf.Variable's EagerResourceDeleter:

View File

@ -17,9 +17,8 @@ limitations under the License.
#include <memory>
#include "tensorflow/c/experimental/saved_model/core/ops/owned_eager_context.h"
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor.h"
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/tensor.h"
@ -30,10 +29,10 @@ limitations under the License.
namespace tensorflow {
namespace {
AbstractTensorHandlePtr CreateScalarTensorHandle(EagerContext* context,
float value) {
ImmediateTensorHandlePtr CreateScalarTensorHandle(EagerContext* context,
float value) {
AbstractTensorPtr tensor(context->CreateFloatScalar(value));
AbstractTensorHandlePtr handle(context->CreateLocalHandle(tensor.get()));
ImmediateTensorHandlePtr handle(context->CreateLocalHandle(tensor.get()));
return handle;
}
@ -62,7 +61,7 @@ class VariableOpsTest : public ::testing::Test {
// Sanity check for variable creation
TEST_F(VariableOpsTest, CreateVariableSuccessful) {
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
AbstractTensorHandlePtr handle;
ImmediateTensorHandlePtr handle;
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
context(), DT_FLOAT, {}, &handle));
// The created TensorHandle should be a DT_Resource
@ -72,7 +71,7 @@ TEST_F(VariableOpsTest, CreateVariableSuccessful) {
// Sanity check for variable destruction
TEST_F(VariableOpsTest, DestroyVariableSuccessful) {
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
AbstractTensorHandlePtr handle;
ImmediateTensorHandlePtr handle;
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
context(), DT_FLOAT, {}, &handle));
@ -83,18 +82,18 @@ TEST_F(VariableOpsTest, DestroyVariableSuccessful) {
// Sanity check for handle assignment and reading
TEST_F(VariableOpsTest, AssignVariableAndReadSuccessful) {
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
AbstractTensorHandlePtr variable;
ImmediateTensorHandlePtr variable;
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
context(), DT_FLOAT, {}, &variable));
// Create a Scalar float TensorHandle with value 42, and assign it to
// the variable.
AbstractTensorHandlePtr my_value = CreateScalarTensorHandle(context(), 42.0);
ImmediateTensorHandlePtr my_value = CreateScalarTensorHandle(context(), 42.0);
TF_EXPECT_OK(internal::AssignVariable(context(), variable.get(), DT_FLOAT,
my_value.get()));
// Read back the value from the variable, and check that it is 42.
AbstractTensorHandlePtr read_value_handle;
ImmediateTensorHandlePtr read_value_handle;
TF_EXPECT_OK(internal::ReadVariable(context(), variable.get(), DT_FLOAT,
&read_value_handle));
Status status;

View File

@ -0,0 +1,60 @@
# This package contains classes corresponding to Revived SavedObjectGraph types
# used by SavedModel. See https://cs.opensource.google/tensorflow/tensorflow/+/c575e2ba93c442121d98d3f125d83fed1339924d:tensorflow/core/protobuf/saved_object_graph.proto;l=56-62
package(
default_visibility = [
# Restricting visibility for now
"//tensorflow/c/experimental/saved_model/core:__pkg__",
],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "constant",
srcs = [
"constant.cc",
],
hdrs = [
"constant.h",
],
deps = [
":tensorhandle_convertible",
"//tensorflow/c:tensor_interface",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:tensor_handle",
],
)
cc_library(
name = "variable",
srcs = [
"variable.cc",
],
hdrs = [
"variable.h",
],
deps = [
":tensorhandle_convertible",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/experimental/saved_model/core/ops:variable_ops",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:context",
"@com_google_absl//absl/types:optional",
],
)
cc_library(
name = "tensorhandle_convertible",
hdrs = [
"tensorhandle_convertible.h",
],
deps = [
"//tensorflow/c/eager:immediate_execution_tensor_handle",
],
)

View File

@ -0,0 +1,46 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include <memory>
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
Constant::Constant(ImmediateTensorHandlePtr handle)
: TensorHandleConvertible(std::move(handle)) {}
Status Constant::Create(ImmediateExecutionContext* ctx,
AbstractTensorInterface* tensor,
std::unique_ptr<Constant>* output) {
ImmediateExecutionTensorHandle* handle = ctx->CreateLocalHandle(tensor);
if (handle == nullptr) {
return errors::Internal("Failed to convert tensor to tensorhandle");
}
output->reset(new Constant(ImmediateTensorHandlePtr(handle)));
return Status();
}
} // namespace tensorflow

View File

@ -0,0 +1,55 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_CONSTANT_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_CONSTANT_H_
#include <memory>
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/tensor.pb.h"
namespace tensorflow {
// This class corresponds to python's tf.constant, which is effectively a
// TensorHandle explicitly initialized to some value.
// For now this doesn't do much beyond wrapping Context's CreateLocalHandle
// method and offering a subclass of TensorHandleConvertible. Note that,
// similar to Python's eager-mode logic, we bypass calling the "Const" op:
// https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/framework/constant_op.py#L301
class Constant : public TensorHandleConvertible {
public:
static Status Create(ImmediateExecutionContext* ctx,
AbstractTensorInterface* tensor,
std::unique_ptr<Constant>* output);
// Constant is movable, but not copyable.
Constant(Constant&& other) = default;
Constant& operator=(Constant&& other) = default;
~Constant() override = default;
private:
explicit Constant(ImmediateTensorHandlePtr handle);
Constant(const Constant&) = delete;
Constant& operator=(const Constant&) = delete;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_CONSTANT_H_
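As a usage sketch only (assumed calling code, not part of this change): given a valid ImmediateExecutionContext* `ctx`, building a Constant from a float scalar might look like the following, mirroring the test helpers elsewhere in this commit.

AbstractTensorPtr tensor(ctx->CreateFloatScalar(42.0f));
std::unique_ptr<Constant> constant;
Status s = Constant::Create(ctx, tensor.get(), &constant);
if (s.ok()) {
  // The revived constant exposes its backing handle via TensorHandleConvertible.
  ImmediateExecutionTensorHandle* handle = constant->handle();
  (void)handle;
}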

View File

@ -0,0 +1,49 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSORHANDLE_CONVERTIBLE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSORHANDLE_CONVERTIBLE_H_
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
namespace tensorflow {
// A common interface for objects that can be converted to a TensorHandle.
// Examples of objects that implement this include Variables, Constants, Assets,
// etc. This is used to convert captured objects into a ConcreteFunction's
// captured TensorHandles:
// https://github.com/tensorflow/tensorflow/blob/676a68963ea4b64fe479b9cede06aa8f5b290ab8/tensorflow/python/saved_model/load.py#L229-L240
class TensorHandleConvertible {
public:
explicit TensorHandleConvertible(ImmediateTensorHandlePtr handle)
: handle_(std::move(handle)) {}
ImmediateExecutionTensorHandle* handle() { return handle_.get(); }
// TensorHandleConvertible is movable, but not copyable.
TensorHandleConvertible(TensorHandleConvertible&& other) = default;
TensorHandleConvertible& operator=(TensorHandleConvertible&& other) = default;
virtual ~TensorHandleConvertible() = default;
protected:
TensorHandleConvertible(const TensorHandleConvertible&) = delete;
TensorHandleConvertible& operator=(const TensorHandleConvertible&) = delete;
ImmediateTensorHandlePtr handle_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSORHANDLE_CONVERTIBLE_H_
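A brief illustrative sketch (assumed caller-side code, not from this change): once revived objects such as a `Variable` and a `Constant` exist as unique_ptrs, a loader can gather their captured handles uniformly through this interface.

std::vector<TensorHandleConvertible*> captures = {variable.get(), constant.get()};
std::vector<ImmediateExecutionTensorHandle*> handles;
handles.reserve(captures.size());
for (TensorHandleConvertible* capture : captures) {
  // handle() is a borrowed pointer; ownership stays with the revived object.
  handles.push_back(capture->handle());
}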

View File

@ -0,0 +1,78 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
#include <memory>
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
Variable::Variable(ImmediateExecutionContext* ctx, DataType dtype,
TensorShape shape, absl::optional<std::string> name,
ImmediateTensorHandlePtr handle)
: TensorHandleConvertible(std::move(handle)),
name_(name.has_value() ? *name : "Variable"),
dtype_(dtype),
shape_(shape),
ctx_(ctx) {}
Variable::~Variable() {
// If the handle is null (perhaps because this variable was moved from), then
// we don't have to do anything.
if (handle_ == nullptr) {
return;
}
Status status = internal::DestroyResource(ctx_, handle_.get());
if (!status.ok()) {
LOG(ERROR) << "Error destroying variable: " << name_
           << " due to: " << status;
}
}
DataType Variable::dtype() { return dtype_; }
TensorShape Variable::shape() { return shape_; }
Status Variable::Assign(ImmediateExecutionTensorHandle* handle) {
return internal::AssignVariable(ctx_, handle_.get(), dtype_, handle);
}
Status Variable::ReadValue(ImmediateTensorHandlePtr* out) {
return internal::ReadVariable(ctx_, handle_.get(), dtype_, out);
}
Status Variable::CreateUninitialized(ImmediateExecutionContext* ctx,
DataType dtype, TensorShape shape,
absl::optional<std::string> name,
std::unique_ptr<Variable>* output) {
ImmediateTensorHandlePtr handle;
TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
ctx, dtype, shape, &handle));
output->reset(
new Variable(ctx, dtype, shape, std::move(name), std::move(handle)));
return Status();
}
} // namespace tensorflow

View File

@ -0,0 +1,76 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_VARIABLE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_VARIABLE_H_
#include <memory>
#include "absl/types/optional.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
namespace tensorflow {
class Variable : public TensorHandleConvertible {
public:
// Creates an uninitialized resource variable. Note that the caller must
// call `Assign` to associate a value with the variable.
static Status CreateUninitialized(ImmediateExecutionContext* ctx,
DataType dtype, TensorShape shape,
absl::optional<std::string> name,
std::unique_ptr<Variable>* output);
// The dtype of the underlying variable.
DataType dtype();
// The shape of the underlying variable.
TensorShape shape();
// Updates the variable's contents with `handle`.
Status Assign(ImmediateExecutionTensorHandle* handle);
// Reads the value of the variable, and stores it in `out`
Status ReadValue(ImmediateTensorHandlePtr* out);
// Variable is movable, but not copyable.
Variable(Variable&& other) = default;
Variable& operator=(Variable&& other) = default;
~Variable() override;
private:
Variable(ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape,
absl::optional<std::string> name, ImmediateTensorHandlePtr handle);
Variable(const Variable& variable) = delete;
Variable& operator=(const Variable&) = delete;
std::string name_;
DataType dtype_;
TensorShape shape_;
// ctx_ must outlive Variable.
ImmediateExecutionContext* ctx_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_VARIABLE_H_
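As a hedged sketch of the intended lifecycle (inside a function returning Status; `ctx` and `value_handle` are assumed to exist), mirroring the variable ops wrapped above:

std::unique_ptr<Variable> var;
TF_RETURN_IF_ERROR(Variable::CreateUninitialized(
    ctx, DT_FLOAT, TensorShape({2}), absl::nullopt, &var));
// The variable is uninitialized until a value is assigned.
TF_RETURN_IF_ERROR(var->Assign(value_handle));
ImmediateTensorHandlePtr read_value;
TF_RETURN_IF_ERROR(var->ReadValue(&read_value));
// DestroyResource runs in ~Variable(), as shown in variable.cc above.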

View File

@ -0,0 +1,58 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
#include <memory>
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
namespace tensorflow {
namespace internal {
Status TensorProtoToConstant(ImmediateExecutionContext* ctx,
const TensorProto& proto,
std::unique_ptr<Constant>* output) {
tensorflow::Tensor tensor;
bool parse_result = tensor.FromProto(proto);
if (!parse_result) {
return errors::Internal("Failed to parse tensor from tensorproto");
}
TensorInterface tensor_interface(std::move(tensor));
return Constant::Create(ctx, &tensor_interface, output);
}
// This follows the python variable restoration logic:
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/saved_model/load.py#L407
Status LoadSavedVariable(ImmediateExecutionContext* ctx,
const SavedVariable& variable,
std::unique_ptr<Variable>* output) {
const std::string& name = variable.name();
tensorflow::TensorShape shape(variable.shape());
tensorflow::DataType dtype = variable.dtype();
TF_RETURN_IF_ERROR(
Variable::CreateUninitialized(ctx, dtype, shape, name, output));
return Status();
}
} // namespace internal
} // namespace tensorflow

View File

@ -0,0 +1,49 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_
// Some internal utility functions for the SavedModelAPI, factored out into a
// separately unit-testable header.
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
namespace tensorflow {
namespace internal {
// Load a TensorProto into a tensorflow::Constant. This is similar to the
// constant loading logic in python:
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/saved_model/load.py#L437
Status TensorProtoToConstant(ImmediateExecutionContext* ctx,
const TensorProto& proto,
std::unique_ptr<Constant>* output);
// Creates a tensorflow::Variable from a SavedVariable. This is similar to the
// logic in:
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/saved_model/load.py#L407
// Note that the caller **must assign a value** to the loaded variable.
Status LoadSavedVariable(ImmediateExecutionContext* ctx,
const SavedVariable& variable,
std::unique_ptr<Variable>* output);
} // namespace internal
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_
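A minimal sketch of restoring a Constant via the helper above (assumed calling code inside a function returning Status, with a valid ImmediateExecutionContext* `ctx`; the tensor contents are purely illustrative):

tensorflow::Tensor t(DT_FLOAT, TensorShape({2, 2}));
t.flat<float>().setZero();
TensorProto proto;
t.AsProtoTensorContent(&proto);  // serialize the tensor into a TensorProto
std::unique_ptr<Constant> constant;
TF_RETURN_IF_ERROR(internal::TensorProtoToConstant(ctx, proto, &constant));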

View File

@ -0,0 +1,122 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <vector>
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
#include "tensorflow/c/experimental/saved_model/core/test_utils.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
namespace tensorflow {
namespace {
class SavedVariableLoadingTest : public ::testing::TestWithParam<
std::tuple<DataType, std::vector<int64>>> {
public:
SavedVariableLoadingTest()
: device_mgr_(testing::CreateTestingDeviceMgr()),
ctx_(testing::CreateTestingEagerContext(device_mgr_.get())) {}
EagerContext* context() { return ctx_.get(); }
private:
std::unique_ptr<StaticDeviceMgr> device_mgr_;
EagerContextPtr ctx_;
};
// Sanity check that constructing a tensorflow::Variable from a SavedVariable
// 1. does not cause an error
// 2. preserves dtype and shape.
TEST_P(SavedVariableLoadingTest, LoadSavedVariableSuccessful) {
auto& test_params = GetParam();
DataType dtype = std::get<0>(test_params);
TensorShape shape(std::get<1>(test_params));
SavedVariable saved_variable;
saved_variable.set_dtype(dtype);
shape.AsProto(saved_variable.mutable_shape());
std::unique_ptr<Variable> var;
TF_EXPECT_OK(internal::LoadSavedVariable(context(), saved_variable, &var));
EXPECT_EQ(var->dtype(), dtype);
EXPECT_EQ(var->shape(), shape);
}
// Assigning and reading values should yield
// consistent results.
TEST_P(SavedVariableLoadingTest, AssignAndReadVariableSuccessful) {
auto& test_params = GetParam();
DataType dtype = std::get<0>(test_params);
std::vector<int64> shape_vector = std::get<1>(test_params);
TensorShape shape(shape_vector);
// Create the variable.
Status status;
std::unique_ptr<Variable> var;
TF_EXPECT_OK(Variable::CreateUninitialized(context(), dtype, shape,
absl::nullopt, &var));
// Create a TensorHandle
ImmediateTensorHandlePtr expected_handle =
testing::CreateTensorHandle(context(), dtype, shape_vector, 42);
AbstractTensorPtr expected_tensor(expected_handle->Resolve(&status));
TF_EXPECT_OK(status) << status.error_message();
// Assign the tensorhandle to the variable.
TF_EXPECT_OK(var->Assign(expected_handle.get()));
// Read back the value from the variable
ImmediateTensorHandlePtr output_handle;
TF_EXPECT_OK(var->ReadValue(&output_handle));
AbstractTensorPtr output_tensor(output_handle->Resolve(&status));
TF_EXPECT_OK(status) << status.error_message();
// Check that output_tensor == expected_tensor
EXPECT_EQ(output_tensor->Type(), expected_tensor->Type());
EXPECT_EQ(output_tensor->NumElements(), expected_tensor->NumElements());
testing::CheckBufferDataIsEqual(
output_tensor->Type(), output_tensor->NumElements(),
output_tensor->Data(), expected_tensor->Data());
}
// Test against combinations of SavedVariables of
// 1. Varying dtypes
// 2. Varying shapes
INSTANTIATE_TEST_SUITE_P(
SavedVariableIntegerDtypesTest, SavedVariableLoadingTest,
::testing::Combine(
::testing::ValuesIn(testing::DataTypeSetToVector(kDataTypeIsInteger)),
::testing::ValuesIn(testing::InterestingShapes())));
INSTANTIATE_TEST_SUITE_P(
SavedVariableFloatingDtypesTest, SavedVariableLoadingTest,
::testing::Combine(::testing::Values(DT_FLOAT, DT_DOUBLE),
::testing::ValuesIn(testing::InterestingShapes())));
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,143 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/test_utils.h"
#include <memory>
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace testing {
std::unique_ptr<StaticDeviceMgr> CreateTestingDeviceMgr() {
return std::make_unique<StaticDeviceMgr>(
DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:0"));
}
EagerContextPtr CreateTestingEagerContext(DeviceMgr* device_mgr) {
return EagerContextPtr(new EagerContext(
SessionOptions(),
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
tensorflow::ContextMirroringPolicy::MIRRORING_NONE,
/* async= */ false,
/* lazy_copy_function_remote_inputs= */ false, device_mgr,
/* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
/* custom_kernel_creator= */ nullptr,
/* cluster_flr= */ nullptr));
}
std::vector<DataType> DataTypeSetToVector(DataTypeSet set) {
std::vector<DataType> result;
result.reserve(set.size());
for (DataType dt : set) {
result.push_back(dt);
}
return result;
}
std::vector<std::vector<int64>> InterestingShapes() {
std::vector<std::vector<int64>> interesting_shapes;
interesting_shapes.push_back({}); // Scalar
interesting_shapes.push_back({10}); // 1D Vector
interesting_shapes.push_back({3, 3}); // 2D Matrix
interesting_shapes.push_back({1, 4, 6, 10}); // Higher Dimension Tensor
return interesting_shapes;
}
ImmediateTensorHandlePtr CreateTensorHandle(ImmediateExecutionContext* ctx,
DataType dtype,
absl::Span<const int64> shape,
int8 value) {
AbstractTensorPtr tensor(ctx->CreateTensor(dtype, shape));
CHECK_NE(tensor.get(), nullptr)
<< "Tensor creation failed for tensor of dtype: "
<< DataTypeString(dtype);
CHECK_EQ(tensor->Type(), dtype);
for (int i = 0; i < shape.size(); ++i) {
CHECK_EQ(tensor->Dim(i), shape[i]);
}
FillNumericTensorBuffer(tensor->Type(), tensor->NumElements(), tensor->Data(),
value);
ImmediateTensorHandlePtr handle(ctx->CreateLocalHandle(tensor.get()));
CHECK_NE(handle.get(), nullptr);
return handle;
}
void FillNumericTensorBuffer(DataType dtype, size_t num_elements, void* buffer,
int8 value) {
switch (dtype) {
#define CASE(type) \
case DataTypeToEnum<type>::value: { \
type* typed_buffer = static_cast<type*>(buffer); \
for (size_t i = 0; i < num_elements; ++i) { \
typed_buffer[i] = value; \
} \
break; \
}
TF_CALL_INTEGRAL_TYPES(CASE);
TF_CALL_double(CASE);
TF_CALL_float(CASE);
#undef CASE
default:
CHECK(false) << "Unsupported data type: " << DataTypeString(dtype);
break;
}
}
// Checks that the underlying data is equal for the buffers of two numeric
// tensors. Note: the caller must ensure that the dtypes and sizes of the
// underlying buffers are the same before calling this.
void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a,
void* b) {
switch (dtype) {
#define CASE(type) \
case DataTypeToEnum<type>::value: { \
type* typed_a = static_cast<type*>(a); \
type* typed_b = static_cast<type*>(b); \
for (int64 i = 0; i < num_elements; ++i) { \
if (DataTypeIsFloating(dtype)) { \
EXPECT_FLOAT_EQ(typed_a[i], typed_b[i]); \
} else { \
EXPECT_EQ(typed_a[i], typed_b[i]); \
} \
} \
break; \
}
TF_CALL_INTEGRAL_TYPES(CASE);
TF_CALL_double(CASE);
TF_CALL_float(CASE);
#undef CASE
default:
CHECK(false) << "Unsupported data type: " << DataTypeString(dtype);
}
}
} // namespace testing
} // namespace tensorflow

View File

@ -0,0 +1,75 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TEST_UTILS_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TEST_UTILS_H_
#include <memory>
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace testing {
// Creates a DeviceMgr suitable for local tests.
std::unique_ptr<StaticDeviceMgr> CreateTestingDeviceMgr();
// Creates an EagerContext suitable for local tests. Does not take ownership
// of `device_mgr`.
EagerContextPtr CreateTestingEagerContext(DeviceMgr* device_mgr);
// Converts a tensorflow::DataTypeSet to std::vector<DataType>.
// This is useful for tests using GTest's ::testing::ValuesIn, since
// DataTypeSet doesn't fulfill all the constraints of an STL-like iterable.
std::vector<DataType> DataTypeSetToVector(DataTypeSet set);
// Returns a vector of shapes intended to be "interesting" test cases.
// Currently, this returns scalar, 1D vector, 2D matrix, and 4D tensor shapes.
std::vector<std::vector<int64>> InterestingShapes();
// Returns a TensorHandle of `dtype` and `shape`, filled with `value`.
// `dtype` must be an integer dtype, float, or double.
// If a TensorHandle cannot be created successfully, this function will
// CHECK fail. This should only be used for testing purposes.
ImmediateTensorHandlePtr CreateTensorHandle(ImmediateExecutionContext* ctx,
DataType dtype,
absl::Span<const int64> shape,
int8 value);
// Fills a numeric tensor's buffer with `value`.
// dtype must be any integer dtype, float or double.
void FillNumericTensorBuffer(DataType dtype, size_t num_elements, void* buffer,
int8 value);
// Checks that the underlying data is equal for the buffers of two numeric
// tensors. Note: the caller must ensure that the dtypes and sizes of the
// underlying buffers are the same before calling this.
// dtype must be any integer dtype, float, or double.
void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a,
void* b);
} // namespace testing
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TEST_UTILS_H_
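For example (a sketch under the assumption that `context()` yields a valid ImmediateExecutionContext*), a test can combine these helpers to create two identically filled handles and compare their buffers:

Status status;
ImmediateTensorHandlePtr a =
    testing::CreateTensorHandle(context(), DT_INT32, {2, 2}, 7);
ImmediateTensorHandlePtr b =
    testing::CreateTensorHandle(context(), DT_INT32, {2, 2}, 7);
AbstractTensorPtr a_tensor(a->Resolve(&status));
AbstractTensorPtr b_tensor(b->Resolve(&status));
testing::CheckBufferDataIsEqual(DT_INT32, a_tensor->NumElements(),
                                a_tensor->Data(), b_tensor->Data());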

View File

@ -54,6 +54,20 @@ class AbstractTensorInterface {
virtual ~AbstractTensorInterface() {}
};
namespace internal {
struct AbstractTensorInterfaceDeleter {
void operator()(AbstractTensorInterface* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using AbstractTensorPtr =
std::unique_ptr<AbstractTensorInterface,
internal::AbstractTensorInterfaceDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_TENSOR_INTERFACE_H_

View File

@ -259,9 +259,6 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) {
RunTest(x, x_init_value, y, y_shape);
}
// TODO(rocm):
// Re-enable this test once 3D pooling is supported on ROCm platform
#ifndef TENSORFLOW_USE_ROCM
TEST_F(NNGradTest, MaxPool3DGradHelper) {
TensorShape x_shape({1, 3, 3, 3, 1});
TensorShape y_shape({1, 1, 1, 1, 1});
@ -274,7 +271,6 @@ TEST_F(NNGradTest, MaxPool3DGradHelper) {
SetRandomValuesForMaxPooling<float>(&x_init_value);
RunTest(x, x_init_value, y, y_shape);
}
#endif
TEST_F(NNGradTest, AvgPoolGradHelper) {
TensorShape x_shape({1, 2, 2, 1});
@ -287,9 +283,6 @@ TEST_F(NNGradTest, AvgPoolGradHelper) {
RunTest(x, x_shape, y, y_shape);
}
// TODO(rocm):
// Re-enable this test once 3D pooling is supported on ROCm platform
#ifndef TENSORFLOW_USE_ROCM
TEST_F(NNGradTest, AvgPool3DGradHelper) {
TensorShape x_shape({1, 3, 3, 3, 1});
TensorShape y_shape({1, 1, 1, 1, 1});
@ -300,7 +293,6 @@ TEST_F(NNGradTest, AvgPool3DGradHelper) {
auto y = AvgPool3D(scope_, x, ksize, strides, "SAME");
RunTest(x, x_shape, y, y_shape);
}
#endif
TEST_F(NNGradTest, LRN) {
TensorShape x_shape({1, 1, 2, 1});

View File

@ -10,6 +10,9 @@ tf_cc_test(
srcs = [
"saved_model_api_test.cc",
],
data = [
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
],
deps = [
"//tensorflow/cc/experimental/base/public:runtime",
"//tensorflow/cc/experimental/base/public:runtime_builder",

View File

@ -90,7 +90,7 @@ TEST_P(CPPSavedModelAPITest, LoadsSavedModel) {
// That unblocks writing other tests that require a TF_SavedModel*,
// like loading a ConcreteFunction. This test at least checks that the
// C API builds and can be minimally run.
EXPECT_EQ(status.code(), TF_UNIMPLEMENTED);
EXPECT_EQ(status.code(), TF_UNIMPLEMENTED) << status.message();
}
INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticCPPSavedModelTests,

View File

@ -69,6 +69,7 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"@llvm-project//llvm:ARMCodeGen", # fixdeps: keep
"@llvm-project//llvm:PowerPCCodeGen", # fixdeps: keep
"@llvm-project//llvm:Support",
"@llvm-project//llvm:Target",
"@llvm-project//llvm:X86CodeGen", # fixdeps: keep
"//tensorflow/core:regexp_internal",

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "absl/base/call_once.h"
#include "llvm-c/Target.h"
#include "llvm/Support/ManagedStatic.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/aot/quantize.h"

View File

@ -108,8 +108,7 @@ class XlaExecutableClosure {
explicit XlaExecutableClosure(
xla::LocalClient* client, xla::LocalExecutable* executable,
const XlaCompiler::CompilationResult* compilation_result,
std::map<int, OptionalTensor> resource_var_snapshots,
int num_constant_args)
ResourceVarsSnapshot resource_var_snapshots, int num_constant_args)
: client_(client),
executable_(executable),
compilation_result_(compilation_result),
@ -124,7 +123,7 @@ class XlaExecutableClosure {
const XlaCompiler::CompilationResult* compilation_result() const {
return compilation_result_;
}
const std::map<int, OptionalTensor>& resource_var_snapshots() const {
const ResourceVarsSnapshot& resource_var_snapshots() const {
return resource_var_snapshots_;
}
int num_constant_args() const { return num_constant_args_; }
@ -133,7 +132,7 @@ class XlaExecutableClosure {
xla::LocalClient* client_;
xla::LocalExecutable* executable_;
const XlaCompiler::CompilationResult* compilation_result_;
std::map<int, OptionalTensor> resource_var_snapshots_;
ResourceVarsSnapshot resource_var_snapshots_;
int num_constant_args_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosure);
@ -276,10 +275,10 @@ static Status BuildCompilationCache(OpKernelContext* ctx,
static Status CompileToLocalExecutable(
OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars,
const XlaPlatformInfo& platform_info, absl::Span<const int> resources,
const XlaPlatformInfo& platform_info,
absl::Span<VariableInfo const> variable_infos,
absl::Span<const int> constants, bool lazy, xla::LocalClient** client,
std::map<int, OptionalTensor>* variables,
const XlaCompiler::CompilationResult** kernel,
const XlaCompiler::CompilationResult** compilation_result,
xla::LocalExecutable** executable) {
// We store information about the JIT-compiled XLA computation
// in the ResourceMgr.
@ -299,7 +298,6 @@ static Status CompileToLocalExecutable(
// this is more obviously correct.)
core::ScopedUnref cache_ref(cache);
TF_RETURN_IF_ERROR(SnapshotResourceVariables(ctx, resources, variables));
*client = static_cast<xla::LocalClient*>(cache->client());
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
@ -337,11 +335,11 @@ static Status CompileToLocalExecutable(
std::vector<XlaCompiler::Argument> args;
TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
constant_args, *variables, ctx, &args));
constant_args, variable_infos, ctx, &args));
return cache->Compile(options, function, args, compile_options,
lazy ? XlaCompilationCache::CompileMode::kLazy
: XlaCompilationCache::CompileMode::kStrict,
kernel, executable);
compilation_result, executable);
}
void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
@ -349,16 +347,22 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
<< Canonicalize(function_.name(), AttrSlice(&function_.attr()));
xla::LocalClient* client;
const XlaCompiler::CompilationResult* kernel;
const XlaCompiler::CompilationResult* compilation_result;
xla::LocalExecutable* executable;
std::map<int, OptionalTensor> variables;
ResourceVarsSnapshot variables;
{
std::vector<VariableInfo> variable_infos;
OP_REQUIRES_OK(
ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos));
OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
Status s = CompileToLocalExecutable(
ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_,
resources_, constants_, /*lazy=*/false, &client, &variables, &kernel,
&executable);
variable_infos, constants_, /*lazy=*/false, &client,
&compilation_result, &executable);
OP_REQUIRES_OK(ctx, s);
OP_REQUIRES_OK(ctx, SnapshotResourceVariables(ctx, resources_,
variable_infos, &variables));
}
se::Stream* stream =
@ -373,7 +377,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
client, allocator,
/*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
platform_info_.UseMultipleStreams());
launch_context.PopulateInputs(ctx, kernel, variables,
launch_context.PopulateInputs(ctx, compilation_result, variables,
/*missing_ctx_input_prefix=*/0);
// Execute the computation.
@ -413,7 +417,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
executable->executable()->module().input_output_alias_config();
OP_REQUIRES_OK(
ctx, launch_context.PopulateOutputs(
ctx, kernel, run_result.ConsumeValueOrDie(),
ctx, compilation_result, run_result.ConsumeValueOrDie(),
/*missing_ctx_input_prefix=*/0, input_output_alias, variables));
VLOG(1) << "Done";
}
@ -494,7 +498,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
xla::LocalClient* client;
const XlaCompiler::CompilationResult* kernel;
xla::LocalExecutable* executable;
std::map<int, OptionalTensor> variables;
ResourceVarsSnapshot variables;
bool cannot_compile_cluster;
{
@ -506,9 +510,16 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
cannot_compile_cluster) {
executable = nullptr;
} else {
std::vector<VariableInfo> variable_infos;
OP_REQUIRES_OK(
ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos));
OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
Status status = CompileToLocalExecutable(
ctx, function_, has_ref_vars_, platform_info_, resources_, constants_,
/*lazy=*/!must_compile_, &client, &variables, &kernel, &executable);
ctx, function_, has_ref_vars_, platform_info_, variable_infos,
constants_,
/*lazy=*/!must_compile_, &client, &kernel, &executable);
OP_REQUIRES_OK(ctx, SnapshotResourceVariables(ctx, resources_,
variable_infos, &variables));
if (must_compile_ || status.code() != error::UNIMPLEMENTED) {
OP_REQUIRES_OK(ctx, status);
}

View File

@ -1837,7 +1837,7 @@ absl::flat_hash_map<string, std::vector<string>>* GetWhitelistTable() {
"ConcatOffset", "Const", "MirrorPad", "Pack", "Pad", "PadV2", "Reverse",
"ReverseV2", "ReverseSequence", "Slice", "Split", "SplitV",
"StridedSlice", "StridedSliceGrad", "ResourceStridedSliceAssign",
"Tile", "Transpose", "InvertPermutation", "Unpack"}}};
"Tile", "Transpose", "InvertPermutation", "Unpack", "DeviceIndex"}}};
// clang-format on
return result;
}

View File

@ -28,32 +28,23 @@ limitations under the License.
namespace tensorflow {
namespace {
std::map<int, OptionalTensor> GetVariables(OpKernelContext* ctx) {
std::map<int, OptionalTensor> variables;
for (int64 i = 0; i < ctx->num_inputs(); ++i) {
// Returns argument indices corresponding to the resource variable inputs of
// kernel context `ctx`.
static std::vector<int> GetResourceVariableIndices(OpKernelContext* ctx) {
std::vector<int> out;
for (int64 i = 0; i < ctx->num_inputs(); i++) {
if (ctx->input(i).dtype() == DT_RESOURCE) {
core::RefCountPtr<Var> variable;
ResourceHandle handle = HandleFromInput(ctx, i);
OptionalTensor& optional = variables[i];
optional.name = handle.name();
if (LookupResource(ctx, handle, &variable).ok()) {
tf_shared_lock lock(*variable->mu());
optional.present = true;
optional.value = *variable->tensor();
}
out.push_back(i);
}
}
return variables;
return out;
}
} // namespace
Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
const XlaDevice::Metadata& metadata,
const XlaCompiler::CompilationResult* result,
xla::LocalExecutable* executable) {
std::map<int, OptionalTensor> variables = GetVariables(ctx);
xla::LocalExecutable* executable,
const ResourceVarsSnapshot& variable_args) {
xla::LocalClient* client = metadata.client();
// Builds an XLA allocator for the device.
@ -62,7 +53,7 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
/*allocate_xla_tensors=*/true,
/*use_multiple_streams=*/metadata.UseMultipleStreams());
launch_context.PopulateInputs(ctx, result, variables,
launch_context.PopulateInputs(ctx, result, variable_args,
/*missing_ctx_input_prefix=*/0);
se::Stream* stream =
@ -87,7 +78,7 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
executable->executable()->module().input_output_alias_config();
TF_RETURN_IF_ERROR(launch_context.PopulateOutputs(
ctx, result, run_result.ConsumeValueOrDie(),
/*missing_ctx_input_prefix=*/0, input_output_alias, variables));
/*missing_ctx_input_prefix=*/0, input_output_alias, variable_args));
return Status::OK();
}
@ -115,7 +106,7 @@ Status XlaCompileOnDemandOp::ShouldArgumentBeConstant(
Status XlaCompileOnDemandOp::Compile(
OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
const XlaCompiler::CompilationResult** result,
xla::LocalExecutable** executable) {
ResourceVarsSnapshot* variable_args, xla::LocalExecutable** executable) {
std::map<int, Tensor> constant_arguments;
for (int64 i = 0; i < ctx->num_inputs(); ++i) {
const Tensor& device_tensor = ctx->input(i);
@ -190,12 +181,18 @@ Status XlaCompileOnDemandOp::Compile(
// rather than a one-element tuple.
compile_options.always_return_tuple = false;
std::map<int, OptionalTensor> variable_args = GetVariables(ctx);
std::vector<int> variables_indices = GetResourceVariableIndices(ctx);
std::vector<XlaCompiler::Argument> args;
TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
constant_arguments, variable_args, ctx, &args));
{
std::vector<VariableInfo> variable_infos;
TF_RETURN_IF_ERROR(
GetVariableInfosFromCtxInputs(ctx, variables_indices, &variable_infos));
TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
TF_RETURN_IF_ERROR(SnapshotResourceVariables(
ctx, variables_indices, variable_infos, variable_args));
TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
constant_arguments, variable_infos, ctx, &args));
}
return cache->CompileSingleOp(options, args, ctx, compile_options, result,
executable);
@ -206,8 +203,10 @@ void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) {
xla::LocalExecutable* executable;
const XlaDevice::Metadata* metadata;
OP_REQUIRES_OK(ctx, XlaDevice::GetMetadata(ctx, &metadata));
OP_REQUIRES_OK(ctx, Compile(ctx, *metadata, &result, &executable));
OP_REQUIRES_OK(ctx, Run(ctx, *metadata, result, executable));
ResourceVarsSnapshot variable_args;
OP_REQUIRES_OK(ctx,
Compile(ctx, *metadata, &result, &variable_args, &executable));
OP_REQUIRES_OK(ctx, Run(ctx, *metadata, result, executable, variable_args));
}
} // namespace tensorflow

View File

@ -20,6 +20,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_JIT_XLA_COMPILE_ON_DEMAND_OP_H_
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/framework/function.h"
@ -47,10 +48,12 @@ class XlaCompileOnDemandOp : public OpKernel {
bool* result);
Status Compile(OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
const XlaCompiler::CompilationResult** result,
ResourceVarsSnapshot* variable_args,
xla::LocalExecutable** executable);
Status Run(OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
const XlaCompiler::CompilationResult* result,
xla::LocalExecutable* executable);
xla::LocalExecutable* executable,
const ResourceVarsSnapshot& variable_args);
};
} // namespace tensorflow

View File

@ -52,7 +52,8 @@ const char kPossibleNonVariableResourceHintMessage[] =
"resource inputs to XLA.";
} // anonymous namespace
VariableInfo::VariableInfo(int index, Var* var) : index_(index), var_(var) {}
VariableInfo::VariableInfo(int index, absl::string_view name, Var* var)
: index_(index), name_(name), var_(var) {}
VariableInfo::VariableInfo(VariableInfo&& other)
: index_(other.index_), var_(other.var_), lock_held_(other.lock_held_) {
other.index_ = -1;
@ -87,16 +88,15 @@ VariableInfo::~VariableInfo() {
// Returns a vector of VariableInfo instances for the resource variable inputs
// to the kernel with context `ctx`. The input indices for the resource
// variable inputs are in `variable_indices`.
static Status GetVariableInfosFromCtxInputs(
OpKernelContext* ctx, absl::Span<const int> variable_indices,
std::vector<VariableInfo>* result) {
Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx,
absl::Span<const int> variable_indices,
std::vector<VariableInfo>* result) {
std::vector<const ResourceHandle*> resource_handles;
absl::c_transform(
variable_indices, std::back_inserter(resource_handles),
[&](int variable_idx) { return &HandleFromInput(ctx, variable_idx); });
std::vector<core::RefCountPtr<Var>> variables;
Status s = LookupResources(ctx, resource_handles, &variables);
if (!s.ok()) {
errors::AppendToMessage(&s, kPossibleNonVariableResourceHintMessage);
@ -109,7 +109,9 @@ static Status GetVariableInfosFromCtxInputs(
// *Release* the variable because we're going to unref it later in
// ~VariableInfo.
Var* variable = variables[i].release();
result->emplace_back(variable_indices[i], variable);
int input_idx = variable_indices[i];
std::string var_name = HandleFromInput(ctx, input_idx).name();
result->emplace_back(input_idx, var_name, variable);
}
return Status::OK();
@ -162,21 +164,12 @@ Status LockVariables(absl::Span<VariableInfo> variables) {
Status SnapshotResourceVariables(OpKernelContext* ctx,
absl::Span<const int> variable_indices,
std::map<int, OptionalTensor>* result) {
std::vector<VariableInfo> variable_infos;
TF_RETURN_IF_ERROR(
GetVariableInfosFromCtxInputs(ctx, variable_indices, &variable_infos));
TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
absl::Span<VariableInfo const> variable_infos,
ResourceVarsSnapshot* result) {
for (int i = 0; i < variable_indices.size(); i++) {
if (variable_infos[i].var()) {
OptionalTensor& tensor = (*result)[variable_indices[i]];
tensor.name = HandleFromInput(ctx, variable_indices[i]).name();
tensor.present = true;
tensor.value = *variable_infos[i].var()->tensor();
} else {
(*result)[variable_indices[i]] = OptionalTensor();
}
Var* var = variable_infos[i].var();
(*result)[variable_indices[i]] =
var ? absl::make_optional(*var->tensor()) : absl::nullopt;
}
return Status::OK();
}
@ -197,8 +190,7 @@ XlaComputationLaunchContext::XlaComputationLaunchContext(
void XlaComputationLaunchContext::PopulateInputs(
OpKernelContext* ctx,
const XlaCompiler::CompilationResult* compilation_result,
const std::map<int, OptionalTensor>& variables,
int missing_ctx_input_prefix) {
const ResourceVarsSnapshot& variables, int missing_ctx_input_prefix) {
// Build ShapedBuffers that point directly to the Tensor buffers.
arg_ptrs_ =
std::vector<ShapedBuffer*>(compilation_result->xla_input_shapes.size());
@ -210,7 +202,7 @@ void XlaComputationLaunchContext::PopulateInputs(
CHECK_GE(arg_num, missing_ctx_input_prefix);
const xla::Shape& shape = compilation_result->xla_input_shapes[i];
const Tensor* t = variables.count(arg_num)
? &(variables.at(arg_num).value)
? &(variables.at(arg_num).value())
: &(ctx->input(arg_num - missing_ctx_input_prefix));
CHECK(t);
@ -262,7 +254,7 @@ static const Tensor* FindAliasedTensorForOutput(
int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix,
const xla::HloInputOutputAliasConfig& input_output_alias,
absl::Span<const int> input_mapping,
const std::map<int, OptionalTensor>& resource_var_snapshots) {
const ResourceVarsSnapshot& resource_var_snapshots) {
if (MustAliasOutput(input_output_alias, output_num)) {
int xla_param = input_output_alias.GetAliasedParameter({output_num})
.value()
@ -274,8 +266,8 @@ static const Tensor* FindAliasedTensorForOutput(
// entry time.
if (input_tensor->dtype() == DT_RESOURCE) {
auto& v = resource_var_snapshots.at(missing_ctx_input_prefix + tf_param);
CHECK(v.present);
return &v.value;
CHECK(v.has_value());
return &v.value();
}
return input_tensor;
}
@ -298,9 +290,9 @@ static Tensor GetOrCreateTensorForOutput(
int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix,
const xla::HloInputOutputAliasConfig& input_output_alias,
absl::Span<const int> input_mapping,
const std::map<int, OptionalTensor>& resource_var_snapshots,
DataType output_dtype, const TensorShape& output_shape,
se::DeviceMemoryBase output_buffer, Allocator* output_allocator) {
const ResourceVarsSnapshot& resource_var_snapshots, DataType output_dtype,
const TensorShape& output_shape, se::DeviceMemoryBase output_buffer,
Allocator* output_allocator) {
if (const Tensor* aliased_tensor = FindAliasedTensorForOutput(
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
input_mapping, resource_var_snapshots)) {
@ -431,13 +423,13 @@ static xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
// TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
// not a Tensor.
Var* variable = nullptr;
TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
ctx, HandleFromInput(ctx, actual_input_index), &variable,
[&write](Var** ptr) {
*ptr = new Var(write.type);
return Status::OK();
}));
variable_infos.emplace_back(actual_input_index, variable);
const ResourceHandle handle = HandleFromInput(ctx, actual_input_index);
TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(ctx, handle, &variable,
[&write](Var** ptr) {
*ptr = new Var(write.type);
return Status::OK();
}));
variable_infos.emplace_back(actual_input_index, handle.name(), variable);
}
return variable_infos;
}
@ -447,7 +439,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
const XlaCompiler::CompilationResult* compilation_result,
ScopedShapedBuffer output, int missing_ctx_input_prefix,
const xla::HloInputOutputAliasConfig& input_output_alias,
const std::map<int, OptionalTensor>& resource_var_snapshots) {
const ResourceVarsSnapshot& resource_var_snapshots) {
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
Allocator* allocator = ctx->device()->GetAllocator({});
@ -484,10 +476,36 @@ Status XlaComputationLaunchContext::PopulateOutputs(
stream->ThenRecordEvent(definition_event.get());
}
std::vector<TensorShape> output_tensor_shapes;
output_tensor_shapes.reserve(ctx->num_outputs());
if (output.on_host_shape().is_dynamic()) {
TF_ASSIGN_OR_RETURN(
auto transfer_manager,
xla::TransferManager::GetForPlatform(stream->parent()->platform()));
xla::Shape output_host_shape = output.on_host_shape();
xla::Shape output_device_shape = output.on_device_shape();
TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
stream, &output, &output_host_shape, &output_device_shape));
output.set_shapes(output_host_shape, output_device_shape);
for (int i = 0; i < ctx->num_outputs(); ++i) {
const xla::Shape& subshape =
xla::ShapeUtil::GetSubshape(output_host_shape, {i});
TensorShape shape;
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
output_tensor_shapes.push_back(shape);
}
} else {
for (int i = 0; i < ctx->num_outputs(); ++i) {
output_tensor_shapes.push_back(compilation_result->outputs[i].shape);
}
}
// Copy XLA results to the OpOutputList.
int output_num = 0;
for (int i = 0; i < ctx->num_outputs(); ++i) {
const TensorShape& shape = compilation_result->outputs[i].shape;
const TensorShape& shape = output_tensor_shapes[i];
const DataType& type = compilation_result->outputs[i].type;
VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
<< DataTypeString(type);
@ -564,12 +582,21 @@ Status XlaComputationLaunchContext::PopulateOutputs(
Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
const std::map<int, Tensor>& constant_args,
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
absl::Span<VariableInfo const> variable_args, OpKernelContext* ctx,
std::vector<XlaCompiler::Argument>* args) {
args->resize(ctx->num_inputs());
absl::flat_hash_map<int, const VariableInfo*> variable_info_lookup;
for (const VariableInfo& info : variable_args) {
CHECK(!info.var() || info.lock_held())
<< "Need to hold the lock on resource variables "
"before calling BuildXlaCompilerArguments";
variable_info_lookup.emplace(info.index(), &info);
}
for (int64 input_num = 0; input_num < ctx->num_inputs(); ++input_num) {
XlaCompiler::Argument& arg = (*args)[input_num];
if (constant_args.count(input_num) > 0) {
// Handles compile-time constants.
const Tensor& input = constant_args.at(input_num);
@ -578,7 +605,7 @@ Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
arg.type = input.dtype();
arg.shape = input.shape();
arg.constant_value = input;
} else if (variable_args.count(input_num) == 0) {
} else if (variable_info_lookup.count(input_num) == 0) {
// Handles the non-constant arguments.
const Tensor& input = ctx->input(input_num);
TF_RET_CHECK(input.dtype() != DT_RESOURCE);
@ -594,14 +621,14 @@ Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
// Handles resource variables.
const Tensor& input = ctx->input(input_num);
TF_RET_CHECK(input.dtype() == DT_RESOURCE);
const OptionalTensor& variable = variable_args.at(input_num);
arg.name = variable.name;
const VariableInfo& variable = *variable_info_lookup[input_num];
arg.name = std::string(variable.name());
arg.kind = XlaCompiler::Argument::kResource;
arg.resource_kind = XlaResource::kVariable;
if (variable.present) {
const Tensor& value = variable.value;
arg.type = value.dtype();
arg.shape = value.shape();
if (variable.var()) {
const Tensor* value = variable.var()->tensor();
arg.type = value->dtype();
arg.shape = value->shape();
arg.initialized = true;
} else {
// The values of uninitialized variables are not passed as inputs, since

View File

@ -34,36 +34,17 @@ limitations under the License.
namespace tensorflow {
// Struct that represents a possibly-absent Tensor.
struct OptionalTensor {
string name; // A descriptive name
bool present = false; // Is the tensor present?
Tensor value; // If present, what is the Tensor's value?
};
// Takes a snapshot of the values of resource variable arguments, whose indices
// are specified in `variable_indices` argument. We snapshot tensors that back
// resource variables since concurrent updates may modify the shape, and it is
// important that the shapes used for compilation match the true shapes of the
// buffers.
//
// We snapshot the entire set of resource variables as one atomic operation.
// This models Read->* dependencies between resource variable operations. See
// jit/resource_operation_safety_analysis for details.
//
// Returns a map of TensorFlow argument index to resource variable. If a
// resource variable is not initialized, the corresponding OptionalTensor
// will have its `present` field set to false.
Status SnapshotResourceVariables(OpKernelContext* ctx,
absl::Span<const int> variable_indices,
std::map<int, OptionalTensor>* result);
// Snapshot of resource variables for a TF kernel invocation, mapping from
// parameter number to values at execution time. If the resource variable is not
// initialized, the value will not be present.
using ResourceVarsSnapshot = absl::flat_hash_map<int, absl::optional<Tensor>>;
// Information about the state of a variable passed as input to the _XlaCompile
// and _XlaRun operators. Unlocks the resource variable and decrements its
// refcount on destruction.
class VariableInfo {
public:
explicit VariableInfo(int index, Var* var);
explicit VariableInfo(int index, absl::string_view name, Var* var);
VariableInfo(VariableInfo&& other);
VariableInfo& operator=(VariableInfo&& other);
@ -79,6 +60,9 @@ class VariableInfo {
// "empty", i.e. it does not track a resource variable.
Var* var() const { return var_; }
// Returns the variable name.
absl::string_view name() const { return name_; }
// Returns true if the resource variable lock was successfully acquired by
// this thread.
bool lock_held() const { return lock_held_; }
@ -88,6 +72,7 @@ class VariableInfo {
private:
int index_;
std::string name_;
Var* var_;
// We can't use a optional<mutex_lock> here because it confuses the compiler's
@ -96,6 +81,20 @@ class VariableInfo {
bool lock_held_ = false;
};
// Takes a snapshot of the values of resource variable arguments, whose indices
// are specified in `variable_indices` argument. We snapshot tensors that back
// resource variables since concurrent updates may modify the shape, and it is
// important that the shapes used for compilation match the true shapes of the
// buffers.
//
// We snapshot the entire set of resource variables as one atomic operation.
// This models Read->* dependencies between resource variable operations. See
// jit/resource_operation_safety_analysis for details.
Status SnapshotResourceVariables(OpKernelContext* ctx,
absl::Span<const int> variable_indices,
absl::Span<VariableInfo const> variable_infos,
ResourceVarsSnapshot* result);
// Acquires the mutexes for all the variables in `variables` using a
// deadlock-safe protocol (acquire the mutexes in increasing-address order).
//
@ -104,6 +103,13 @@ class VariableInfo {
Status LockVariables(absl::Span<VariableInfo> variables)
TF_EXCLUSIVE_LOCK_FUNCTION();
// Returns a vector of VariableInfo instances for the resource variable inputs
// to the kernel with context `ctx`. The input indices for the resource
// variable inputs are in `variable_indices`.
Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx,
absl::Span<const int> variable_indices,
std::vector<VariableInfo>* result);
// Helper class to perform the marshalling of TensorFlow inputs and outputs to
// ShapedBuffers suitable for passing to an XLA computation.
class XlaComputationLaunchContext {
@ -123,9 +129,10 @@ class XlaComputationLaunchContext {
// Builds a XlaCompiler::Argument vector from the arguments to an XlaLaunch
// op.
// Precondition: variables in `variable_args` are locked.
static Status BuildXlaCompilerArguments(
const std::map<int, Tensor>& constant_args,
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
absl::Span<VariableInfo const> variable_args, OpKernelContext* ctx,
std::vector<XlaCompiler::Argument>* args);
// Add all inputs within `ctx` as XLA arguments (returned by arguments()).
@ -137,7 +144,7 @@ class XlaComputationLaunchContext {
// (in other words, no inputs actually required by the kernel can be missing).
void PopulateInputs(OpKernelContext* ctx,
const XlaCompiler::CompilationResult* compilation_result,
const std::map<int, OptionalTensor>& variables,
const ResourceVarsSnapshot& variables,
int missing_ctx_input_prefix);
// Given the XLA output in `output`, populate all outputs of `ctx`. Also
@ -155,7 +162,7 @@ class XlaComputationLaunchContext {
const XlaCompiler::CompilationResult* compilation_result,
xla::ScopedShapedBuffer output, int missing_ctx_input_prefix,
const xla::HloInputOutputAliasConfig& input_output_alias,
const std::map<int, OptionalTensor>& resource_var_snapshots);
const ResourceVarsSnapshot& resource_var_snapshots);
// Return the argument list. Only valid after PopulateInputs() has been
// called.

View File

@ -0,0 +1,265 @@
# MLIR CodeGen for XLA
<!--*
# Document freshness: For more information, see go/fresh-source.
freshness: { owner: 'timshen' reviewed: '2020-06-16' }
*-->
XLA operates on `HloInstruction` and performs many optimizations on this
representation, sharing a lot of these between targeted devices. At some point a
linear schedule is computed and a memory buffer is assigned to each value
statically. The device-specific codegen operates by traversing this sequence and
calling "emitters" to generate a representation suitable for the device (for
example a single LLVM function per XLA computation on CPU, or a sequence of
"thunks" encapsulating GPU operations and possibly generated PTX when targeting
GPU).
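For orientation, the schedule traversal can be pictured with a short C++ sketch; `ScheduledOp`, `Emitter`, and `EmitComputation` are hypothetical names here rather than the actual XLA classes:
```c++
#include <string>
#include <vector>

// Hypothetical stand-ins; the real types are XLA's scheduled HLO ops and the
// per-backend emitters.
struct ScheduledOp {
  std::string name;       // e.g. "add", "reduce"
  int result_buffer = 0;  // statically assigned buffer for the op's result
};

class Emitter {
 public:
  virtual ~Emitter() = default;
  // Emits device code for one op: LLVM IR on CPU, a kernel/"thunk" on GPU.
  virtual void EmitOp(const ScheduledOp& op) = 0;
};

// Device-specific codegen: walk the linear schedule and emit each op in order.
void EmitComputation(const std::vector<ScheduledOp>& schedule,
                     Emitter& emitter) {
  for (const ScheduledOp& op : schedule) emitter.EmitOp(op);
}
```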
As a staging step, we're currently intercepting the pipeline right after XLA
completes the buffer-assignment phase and emitting an MLIR module in the `lhlo`
dialect instead. From there we perform the codegen using MLIR components
(Linalg, affine, and GPU dialect mainly) depending on the device.
Below is the plan of record to incrementally migrate XLA/GPU by using `lhlo` as
the codegen input.
## Tasks
| | Host | Device
| ------------- | ------------------------ | ------------------------
| Input format | HloInstruction* (Task 1) | HloInstruction* (Task 1)
| Output format | xla::Thunk (Task 2) | LLVM IR (Task 3)
* **Task 1** changes both host and device input format from HloInstruction* to
LHLO.
* **Task 2** changes output format of host from thunks to "some landing pad
for host" (see below).
* **Task 3** migrates device output from LLVM IR to some form of MLIR. It's
  optional to this project; see the section "Migrating Device LLVM IR" for
  details.
This project prioritizes having end-to-end runnable models with LHLO-emitters
enabled as much as possible. This implies the following ordered list of
objectives by priority:
* Make XLA/GPU runnable with LHLO emitters, with existing Thunks and emitters
unmodified.
* Eliminate the references to HloInstruction\* in LHLO, case by case:
* Switch a legacy emitter to an MLIR-based emitter (e.g. Linalg), or
* Mechanically translate the existing emitter to take MLIR representation
(migrate to Standard with GPU Dialect).
## Migrating Thunks (Task 2)
xla::gpu::Thunk is a data structure that (sketched below):
* Can be called into from the host (xla::gpu::Thunk::ExecuteOnStream()).
* Carries various data in its subclasses.
* Interacts with BufferAllocation::Slice and StreamExecutor.
* Launches kernels.
* Calls into all runtime libraries.
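As a rough illustration only (this is a hedged sketch with made-up names such as `FakeStream` and `BufferSlice`, not the real `xla::gpu` API), the shape of such a data structure is:
```c++
#include <utility>
#include <vector>

// Hypothetical stand-ins for StreamExecutor streams and buffer slices.
struct FakeStream {};
struct BufferSlice {
  int allocation_index = 0;
  int offset = 0;
  int size = 0;
};

// A Thunk-like host-side node: each subclass carries its op-specific data and
// knows how to launch work (a kernel or a library call) on a stream.
class ThunkLike {
 public:
  virtual ~ThunkLike() = default;
  virtual void ExecuteOnStream(FakeStream* stream) = 0;
};

class KernelThunkLike : public ThunkLike {
 public:
  explicit KernelThunkLike(std::vector<BufferSlice> args)
      : args_(std::move(args)) {}
  void ExecuteOnStream(FakeStream* /*stream*/) override {
    // A real KernelThunk would resolve args_ to device pointers and enqueue a
    // kernel launch on the stream; elided in this sketch.
  }

 private:
  std::vector<BufferSlice> args_;
};
```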
The cost of that includes:
* Representing op-specific configuration data (e.g. convolution configs).
* Migrating op shape and operand shapes.
* Representing a tree of thunks (while, condition, etc).
The migration work is independent of the LHLO / emitter migration. Under limited
resources, it's prioritized behind the LHLO / emitter migration.
We have several choices on how to lower the host-side part from LHLO:
* TFRT
* (Pro) great CUDA and HIP wrappers for use.
* (Pro) easy to implement library calls (cuDNN, cuBLAS, cuFFT, etc), as
TFRT ops are interpreted by C++ code.
* (Con) host side is under development and not tested.
* (Con) the JAX integration isn't clear from a runtime point of view
* Jitted CPU code
* (Pro) great lower-ability. Create a few loops and conditions and it's
done.
* (Con) GPUDialect doesn't yet model chains/streams/asynchronicity/device
allocation.
* (Con) CUDA / HIP runtime support is minimal (toolkit path, version,
dynamic loading, etc).
* Existing (interpreting) XLA runtime
Tentative conclusion: Use jitted CPU code during the transition, and optionally
adopt TFRT in the end.
## Migrating Device LLVM IR (Task 3)
An elemental emitter generates the target op by filling it in element by element.
Each output element depends on a set of elements from the operands. All elements
are described by combining the buffer with dynamic indices. This is sufficient to
describe almost all "math" ops, but for performance reasons only a large subset
of "math" ops are implemented directly in (Cpu|Gpu)ElementalIrEmitter.
ElementalIrEmitter is unique in that:
* A large portion of the code is shared between XLA/GPU and CPU.
* It represents a large portion of ops seen in models, including all
element-wise ops.
* Most fusions solely depend on ElementalIrEmitter.
* It's structurally simple, as it describes a data dependency DAG between op
elements and operand elements.
* It's mostly portable and high-level (e.g. unlike GPU kReduce and GPU kCopy).
* Dynamic shape support is easy for at least element-wise ops.
Now, for all ops, elementally-emitted or not, there are several flavors of the
end state of each XLA op:
1. Device code stays as LLVM IR.
1. Refactor the old emitter to be like LHLO -> MLIR LLVM Dialect:
* (Cost) Will be throw-away work if we want to ultimately migrate to
Standard.
* (Benefit) It is easy and mechanical. Can be done in a short period.
 * (Benefit) It doesn't provide more benefit than option 1 (keeping the
   device code as LLVM IR).
1. Refactor old emitters to be like LHLO -> MLIR GPU + Standard + Loops:
* (Cost) Lifting existing emitters to Standard introduces some challenges.
Pointers and GEPs need to be converted to MemRefs and SubViews. Ensuring
amdgpu completeness is another one.
* (Cost) XLA/GPU heavily relies on LLVM metadata:
* `range` for block/thread indices.
* `align`, `dereferenceable`, `invariant.load`, `alias.scope`,
`noalias` for load/stores.
* `llvm.loop.unroll.disable`, `llvm.loop.unroll.full`,
`llvm.loop.vectorize.enable` for sequential loops.
* (Benefit) Can be long-term. More portable.
1. Refactor old emitters to be LHLO -> Linalg, and write new Linalg emitters
* (Cost) This is case by case. Compared to previous options, a new
implementation that matches XLA's performance needs to go through the
benchmark <-> optimize workflow, which can be a significant cost for
some ops.
* (Benefit) unified stack; community support; portability; more
optimization potentials.
## Prioritization
While all three tasks mentioned above are parallelizable, under limited
resources they have to be serialized. The prioritization focuses on visible
results for completion of each task.
The prioritization is: Task1 (LHLO for legacy emitters) > Task 2 (Thunks) > Task
3 (MLIR emitters).
By the end of Task 1, users of XLA (e.g. the kernel generator) can generate LHLO
and execute it. The compilation format will not be serializable MLIR.
By the end of Task 2, LHLO lowers to proper, serializable MLIR. This enables
offline compilation.
By the end of Task 3, all XLA emitters are MLIR-based in their implementation.
## Detailed Design
### Step 1: (Task 1) Complete LHLO and Make Legacy Emitters Take LHLO
This step makes all existing XLA/GPU emitters interact with MLIR ops. This step
is pure refactoring and NFC.
This step is mostly mechanical, but it's worth noticing the following
discrepancies between an unnested HloComputation and LHLO:
* Each HloInstruction has direct access to its operands (a data-flow DAG). On
  the contrary, each LHLO op only has access to its operand buffers (a bipartite
  graph between ops and buffers). LHLO ops have to go through use-def chains to
  access their operand ops (see the sketch after this list).
* Unnested legacy emitters empirically almost never access their operands. The
only exception is kReduce.
* Unnested legacy emitters access BufferAssignment only for getting slices,
not for accessing aux data structures like dataflow\_analysis() or
alias\_analysis(). llvm\_ir builds its own alias\_analysis() based on slice
information.
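To make the first discrepancy concrete, here is a rough sketch with made-up `FakeHloInstruction` / `FakeLhloOp` types (not the real HLO or MLIR APIs) of how the producing op of an operand is reached in each representation:
```c++
#include <vector>

// HLO-style: an instruction points directly at its operand instructions.
struct FakeHloInstruction {
  std::vector<FakeHloInstruction*> operands;
  FakeHloInstruction* operand(int i) const { return operands[i]; }
};

// LHLO-style: an op only sees buffers; the producing op has to be found by
// walking the buffer's use-def information.
struct FakeLhloOp;
struct FakeBuffer {
  FakeLhloOp* defining_op = nullptr;  // filled in by whoever wrote the buffer
};
struct FakeLhloOp {
  std::vector<FakeBuffer*> operand_buffers;
  FakeLhloOp* operand_op(int i) const {
    return operand_buffers[i]->defining_op;  // one extra indirection
  }
};
```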
The conclusion is that LHLO should fit right in without major hassle.
### Step 2: (Optional) Profiling Support
**This step is only needed if we start to discard some of the XLA Thunk logic
(see the next step).**
Before actually turning on any MLIR-based emitters, we need profiling support
for them.
Currently XLA performs its own profiling by calling into StreamExecutor's timer.
The timer under the hood inserts two events before and after a kernel launch,
and measures the sync time between these two events.
There are roughly two approaches to supporting profiling in MLIR:
* Run a profiler end-to-end.
* Add a profile op for each op in LHLO, using an injected profiler.
The "end-to-end" approach is transparent to MLIR, but suffers the same problem
that makes XLA not use it in the first place: library calls collected by a
profiler (nvprof/...) can't easily relate to HLO ops. For example, cuDNN
launches multiple kernels for each HLO, and it's hard to tell which kernels
correspond to which HLO.
The "injected profiler" approach requires:
* LHLO to take a profiler as a parameter.
* inserting profile.start / profile.end before and after each op.
* a pass that lowers profile.{start,end} to a C++ implementation (a rough sketch
  of such hooks follows this list).
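Purely as an illustration (hypothetical hook names, not an existing XLA or MLIR API), the lowering of `profile.start` / `profile.end` could bottom out in a pair of C-callable hooks like these:
```c++
#include <chrono>
#include <cstdio>

// Hypothetical hooks a lowering pass could emit calls to around each op.
// Single-threaded, non-nesting sketch: a real implementation would key timers
// by op and synchronize with the device before reading the clock.
static std::chrono::steady_clock::time_point g_start;

extern "C" void profile_start(const char* /*op_name*/) {
  g_start = std::chrono::steady_clock::now();
}

extern "C" void profile_end(const char* op_name) {
  auto ns = std::chrono::duration_cast<std::chrono::nanoseconds>(
                std::chrono::steady_clock::now() - g_start)
                .count();
  std::printf("%s: %lld ns\n", op_name, static_cast<long long>(ns));
}
```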
Exact profiling can't easily be done for MLIR-generated ops, since:
* MLIR doesn't have a timer, nor does it depend on TFRT / StreamExecutor.
* MLIR doesn't easily call into C functions with complicated parameters.
### Step 3: (Task 2) Migrating Thunks
This step migrates all host ops and library calls. This step will eliminate most
of the thunks and produce serializable MLIR instead.
There are roughly three kinds of thunks:
* KernelThunk, which launches a kernel.
* Control flow thunks, which carry host control flow logic (conditional, while,
  for, sequence) and launch body kernels.
* Library thunks: cuDNN, cuBLAS, cuFFT, NCCL, etc.
The **bottom line** is to:
* Create a Thunk dialect that provides (de)serialize logic for all existing
C++-based Thunks.
* Change emitters to emit a graph of Thunk dialect.
**Optionally**, we can relieve some thunks of their C++ implementations. KernelThunk
can lower to the GPU LaunchKernelOp. Control flow thunks can leverage the CFG
Dialect for loops and conditions, combined with LaunchKernelOp. This optional
step requires profiling and stream support.
### Step 4: (Task 3) Migrated ElementalIrEmitter
Once profiling is ready, we can complete and tune all ElementalIrEmitter-based
emitters in MLIR. Then we turn them on by default, assuming that all of these
MLIR-based emitters use a single stream.
Notice that it's beneficial to migrate XLA/CPU's ElementalIrEmitter as well,
since they share a large portion of the code.
With all benchmarking and performance hunting done (TODO: define performance
parity), we turn on the new MLIR-based elemental emitter, and delete the legacy
ElementalIrEmitter.
This step also provides easy fusion transitions (nested ops) for the later
migration.
### Step 5: Multi-Stream Support or Drop
We can't delete
[some of the emitters](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/service/gpu/stream_assignment.cc#L140)
until we support multi-stream in MLIR, or we drop the feature. It's a relatively
large amount of work in MLIR for a small amount of gain in XLA. We should
investigate the current users of multi-stream XLA/GPU, and try to delete this
feature if reasonable.
### Step 6: (Task 3) Migrated Device Ops
This step migrates all unnested ops, then we can delete all unnested emitters.
This calls for a rewrite/refactor of kCopy and kReduce. kReduce has already
received plenty of work, so the actual amount of work that remains to be done is
yet to be seen.

View File

@ -150,6 +150,7 @@ gentbl(
td_srcs = [
":tensorflow_lite_ops_td_files",
"@llvm-project//mlir:StdOpsTdFiles",
"utils/utils.td",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
],
)
@ -314,7 +315,6 @@ tf_cc_test(
cc_library(
name = "tensorflow_lite_legalize_tf",
srcs = [
"transforms/device_index_selector.cc",
"transforms/dilated_conv.cc",
"transforms/generated_legalize_tf.inc",
"transforms/generated_lower_static_tensor_list.inc",
@ -380,6 +380,7 @@ cc_library(
"transforms/passes.h",
],
deps = [
"convert_type",
":tensorflow_lite",
":validators",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
@ -410,6 +411,7 @@ cc_library(
"transforms/passes.h",
],
deps = [
"convert_type",
":tensorflow_lite",
":validators",
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
@ -672,6 +674,7 @@ cc_library(
"utils/convert_type.h",
],
deps = [
":tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:protos_all_cc",
@ -787,7 +790,9 @@ tf_cc_binary(
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"//tensorflow/cc/saved_model:loader",
"//tensorflow/compiler/mlir:init_mlir",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:errors",
@ -840,7 +845,10 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
"//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"//tensorflow/core:core_cpu_base",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
@ -867,6 +875,7 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:decode_constant_pass",
"//tensorflow/compiler/mlir/tensorflow:error_util",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",

View File

@ -446,7 +446,7 @@ static void GenOperandResultVerifier(raw_ostream &os,
auto desc =
definit->getDef()->getValueAsString("tflRuntimeTypeDescription");
// Emit a loop to check all the dynamic values in the pack.
// Emit a loop to check all operands.
os << formatv(" for (Value v : top.getODS{0}{1}s({2})) {{\n",
// Capitalize the first letter to match the function name
valueKind.substr(0, 1).upper(), valueKind.substr(1),
@ -455,14 +455,10 @@ static void GenOperandResultVerifier(raw_ostream &os,
os << " (void)v;\n"
<< " if (!("
<< tgfmt(pred.getCondition(), &fctx.withSelf("v.getType()")) << ")) {\n"
<< " if (failure_on_operand_type_mismatch) {\n"
<< formatv(
" return op->emitOpError(\"{0} #\") << index "
" return op->emitOpError(\"{0} #\") << index "
"<< \" must be {1}, but got \" << v.getType();\n",
valueKind, desc)
<< " } else {\n"
<< " return ::mlir::LogicalResult::Failure;\n"
<< " }\n"
<< " }\n" // if
<< " ++index;\n"
<< " }\n"; // for
@ -487,8 +483,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
mlir::tblgen::FmtContext verify_ctx;
os << "::mlir::LogicalResult " << op.getCppClassName()
<< "::VerifyTflRuntimeConstraints(::mlir::Operation *op, bool "
"failure_on_operand_type_mismatch) {\n";
<< "::VerifyTflRuntimeConstraints(::mlir::Operation *op) {\n";
os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
verify_ctx.withOp("top");
@ -529,11 +524,8 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue()));
os << tgfmt(
" if (!($0)) {\n "
" if (failure_on_operand_type_mismatch) {\n"
" return top.emitOpError(\"failed to verify that $1\");\n"
" } else {\n"
" return ::mlir::LogicalResult::Failure;\n }\n }\n",
" if (!($0))\n"
" return top.emitOpError(\"failed to verify that $1\");\n",
&verify_ctx, tgfmt(pred.getCondition(), &verify_ctx), desc);
}
os << " return top.verify();\n}\n";

View File

@ -190,9 +190,8 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
}
static bool IsConst(Operation* op) {
return isa<mlir::ConstantOp>(op) || isa<mlir::TF::ConstOp>(op) ||
isa<tfl::ConstOp>(op) || isa<tfl::QConstOp>(op) ||
isa<tfl::SparseConstOp>(op) || isa<tfl::SparseQConstOp>(op);
return isa<mlir::ConstantOp, mlir::TF::ConstOp, tfl::ConstOp, tfl::QConstOp,
tfl::SparseConstOp, tfl::SparseQConstOp>(op);
}
template <typename T>
@ -240,10 +239,10 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) {
}
for (auto fn : module.getOps<FuncOp>()) {
if (fn.getBlocks().size() != 1) {
if (!llvm::hasSingleElement(fn)) {
return fn.emitError("should have exactly one basic block"), false;
}
auto& bb = fn.getBlocks().front();
auto& bb = fn.front();
for (auto arg : bb.getArguments()) {
if (!HasValidTFLiteType(arg, fn))
@ -1089,7 +1088,7 @@ void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) {
dict_attr.get("outputs").dyn_cast_or_null<mlir::StringAttr>()) {
str.getValue().split(output_names, ',', /*MaxSplit=*/-1,
/*KeepEmpty=*/false);
auto term = fn.getBlocks().back().getTerminator();
auto term = fn.back().getTerminator();
if (output_names.size() != term->getNumOperands()) {
fn.emitWarning() << "output names (" << output_names.size()
<< ") != terminator operands (" << term->getNumOperands()

View File

@ -94,8 +94,7 @@ def TFL_RuntimeVerification : OpInterface<"TflRuntimeVerifyOpInterface"> {
let methods = [
StaticInterfaceMethod<
[{Returns whether the op's operands/results are supported by runtime.}],
"LogicalResult", "VerifyTflRuntimeConstraints",
(ins "Operation*":$op, "bool":$failure_on_operand_type_mismatch)
"LogicalResult", "VerifyTflRuntimeConstraints", (ins "Operation*":$op)
>,
];
}

View File

@ -138,6 +138,11 @@ bool IsI32Type(Type element_type) {
return element_type.isInteger(32) && !element_type.isUnsignedInteger();
}
// Return true when the given element_type is I64.
bool IsI64Type(Type element_type) {
return element_type.isInteger(64) && !element_type.isUnsignedInteger();
}
// Return true if the given Add operation has the CPU kernel supported shapes.
bool VerifyAddOpShapeConstraints(AddOp op) {
auto element_type = getElementTypeOrSelf(op.output().getType());
@ -174,7 +179,8 @@ bool VerifySubOpShapeConstraints(SubOp op) {
// Allows F32, QUI8, and QI16 outputs when the operands have valid shapes,
// which are broadcastable shapes up to five dimension or have same shapes.
if (element_type.isF32() || IsI32Type(element_type) ||
IsQUI8Type(element_type) || IsQI16Type(element_type)) {
IsI64Type(element_type) || IsQUI8Type(element_type) ||
IsQI16Type(element_type)) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/5);
@ -758,6 +764,22 @@ OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
return new_concat.getResult();
}
//===----------------------------------------------------------------------===//
// CustomOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(CustomOp op) {
OpaqueElementsAttr opaque_attr =
op.custom_option().cast<OpaqueElementsAttr>();
if (!opaque_attr.getType().hasStaticShape())
return op.emitOpError("custom_option should have a static shape.");
if (opaque_attr.getValue().size() !=
opaque_attr.getType().cast<ShapedType>().getDimSize(0))
return op.emitOpError(
"custom_option should have the same length of content with shape.");
return success();
}
//===----------------------------------------------------------------------===//
// FullyConnectedOp
//===----------------------------------------------------------------------===//
@ -2169,6 +2191,10 @@ static LogicalResult Verify(TransposeOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// WhileOp
//===----------------------------------------------------------------------===//
LogicalResult Verify(WhileOp op) {
if (op.getNumOperands() != op.getNumResults())
return op.emitOpError(llvm::formatv(
@ -2178,18 +2204,6 @@ LogicalResult Verify(WhileOp op) {
return success();
}
static LogicalResult Verify(CustomOp op) {
OpaqueElementsAttr opaque_attr =
op.custom_option().cast<OpaqueElementsAttr>();
if (!opaque_attr.getType().hasStaticShape())
return op.emitOpError("custom_option should have a static shape.");
if (opaque_attr.getValue().size() !=
opaque_attr.getType().cast<ShapedType>().getDimSize(0))
return op.emitOpError(
"custom_option should have the same length of content with shape.");
return success();
}
namespace {
// Canonicalize While op so that results and operands match and external values
// are via implicit capture rather than via block args.

View File

@ -571,6 +571,8 @@ def TFL_TransposeConvOp: TFL_Op<"transpose_conv", [
TFL_OperandHasRank<2, 4>,
PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 2>>,
AccumulatorUniformScale<3, 1, 2>,
TFL_ChannelDimIndexInterface, AffineOpCoefficient<0, 2>,
TFL_GpuTargetOp,
TFL_SparseOp]> {
let summary = "Transpose convolution operator";
@ -596,6 +598,8 @@ def TFL_TransposeConvOp: TFL_Op<"transpose_conv", [
let verifier = [{ return Verify(*this); }];
let extraClassDeclaration = [{
// ChannelDimIndexInterface:
int GetChannelDimIndex() { return 0; }
// SparseOpInterface:
std::vector<int> GetSparseOperands() { return {1}; }
std::vector<std::vector<int>> GetFloatBlockSize() { return {}; }
@ -953,14 +957,14 @@ in the batch dimensions and broadcasting.
}];
let arguments = (ins
TFL_TensorOf<[F32]>:$x,
TFL_TensorOf<[F32]>:$y,
TFL_TensorOf<[F32, QI8]>:$x,
TFL_TensorOf<[F32, QI8]>:$y,
DefaultValuedAttr<BoolAttr, "false">:$adj_x,
DefaultValuedAttr<BoolAttr, "false">:$adj_y
);
let results = (outs
TFL_TensorOf<[F32]>:$output
TFL_TensorOf<[F32, QI8]>:$output
);
let hasOptions = 1;
@ -2860,11 +2864,11 @@ def TFL_SubOp : TFL_Op<"sub", [
}];
let arguments = (
ins TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$lhs,
TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$rhs,
ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$lhs,
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$rhs,
TFL_AFAttr:$fused_activation_function);
let results = (outs TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$output);
let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$output);
let hasFolder = 1;

View File

@ -21,6 +21,7 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/lite/toco:model_flags_proto_cc",

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <ostream>
#include <utility>
#include "llvm/ADT/None.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
@ -90,7 +91,8 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
pass_config.lower_tensor_list_ops = true;
return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module),
pass_config, result);
pass_config, result,
/*session=*/llvm::None);
}
} // namespace tensorflow

View File

@ -16,6 +16,7 @@ limitations under the License.
#include <utility>
#include "llvm/ADT/None.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project
@ -165,8 +166,10 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
pass_config.lower_tensor_list_ops = true;
// TODO(b/153507667): Pass the session object when importing logic is removed.
auto status = internal::ConvertMLIRToTFLiteFlatBuffer(
toco_flags, std::move(module), pass_config, result);
toco_flags, std::move(module), pass_config, result,
/*session=*/llvm::None);
return status;
}

View File

@ -269,10 +269,10 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
return Status::OK();
}
Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
mlir::OwningModuleRef module,
const mlir::TFL::PassConfig& pass_config,
string* result) {
Status ConvertMLIRToTFLiteFlatBuffer(
const toco::TocoFlags& toco_flags, mlir::OwningModuleRef module,
const mlir::TFL::PassConfig& pass_config, string* result,
llvm::Optional<tensorflow::Session*> session) {
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
bool emit_custom_ops = toco_flags.allow_custom_ops();
@ -286,7 +286,7 @@ Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
mlir::PassManager pm(module->getContext());
tensorflow::AddTFToTFLConversionPasses(pass_config, &pm);
tensorflow::AddTFToTFLConversionPasses(pass_config, &pm, session);
// Convert back to outlined while format for export back to flatbuffer.
if (pass_config.legalize_tf_while) {
pm.addPass(mlir::TFL::CreateWhileOutlinePass());

View File

@ -18,9 +18,11 @@ limitations under the License.
#include <ostream>
#include <utility>
#include "llvm/ADT/Optional.h"
#include "mlir/IR/Module.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
#include "tensorflow/lite/toco/types.pb.h"
@ -44,10 +46,10 @@ Status PopulateQuantizationSpecs(
// Convert imported MLIR file to TfLite flatbuffer.
// This will also run relevant passes as well.
Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
mlir::OwningModuleRef module,
const mlir::TFL::PassConfig& pass_config,
string* result);
Status ConvertMLIRToTFLiteFlatBuffer(
const toco::TocoFlags& toco_flags, mlir::OwningModuleRef module,
const mlir::TFL::PassConfig& pass_config, string* result,
llvm::Optional<tensorflow::Session*> session);
// Give a warning for any unused flags that have been specified.
void WarningUnusedFlags(const toco::ModelFlags& model_flags,

View File

@ -289,8 +289,8 @@ class QuantizationDriver {
llvm::errs() << "\n\n\n" << current_op->getName() << "\n";
}
fn_.walk([&](Operation *op) {
if (llvm::isa<quant::QuantizeCastOp>(op) ||
llvm::isa<quant::DequantizeCastOp>(op) || llvm::isa<ConstantOp>(op))
if (llvm::isa<quant::QuantizeCastOp, quant::DequantizeCastOp, ConstantOp>(
op))
return;
if (current_op == op) llvm::errs() << "===>>>";
llvm::errs() << op->getName() << " : (";

View File

@ -172,7 +172,7 @@ struct QuantizationPattern : public RewritePattern {
Value quantized_value = op->getResult(0);
for (Operation* quantized_op : quantized_value.getUsers()) {
// If it is requantize op, we shouldn't rewrite this op.
if (llvm::isa<Q>(quantized_op) || llvm::isa<DQ>(quantized_op)) {
if (llvm::isa<Q, DQ>(quantized_op)) {
return failure();
}
@ -180,8 +180,8 @@ struct QuantizationPattern : public RewritePattern {
// ops dialect, we shouldn't rewrite.
if (quantized_op->isKnownTerminator() ||
quantized_op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
llvm::isa<quant::QuantizeCastOp>(quantized_op) ||
llvm::isa<quant::DequantizeCastOp>(quantized_op)) {
llvm::isa<quant::QuantizeCastOp, quant::DequantizeCastOp>(
quantized_op)) {
return failure();
}

View File

@ -42,7 +42,6 @@ class TestGraphDebugInfo(object):
func = model.get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions([func])
converter.experimental_new_converter = True
converter.convert()
# pylint: disable=line-too-long

View File

@ -51,7 +51,6 @@ class TestGraphDebugInfo(object):
# load the model and convert
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_path)
converter.experimental_new_converter = True
converter.convert()
# pylint: disable=line-too-long

View File

@ -26,6 +26,7 @@ filegroup(
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
"//tensorflow/compiler/mlir/lite:tf_tfl_translate",
"@llvm-project//llvm:FileCheck",
"@llvm-project//llvm:not",
],
)

View File

@ -0,0 +1,257 @@
# RUN: not tf_tfl_translate -tf-upgrade-legacy=false -tf-input-arrays=Placeholder,Placeholder_1 -tf-input-shapes=1,2:1 -tf-output-arrays=cond/Merge -tf-enable-shape-inference-on-import=false -mlir-print-debuginfo -output-mlir %s -o - 2>&1 | FileCheck %s
# CHECK: error: The graph has Control Flow V1 ops. TFLite converter doesn't support Control Flow V1 ops. Consider using Control Flow V2 ops instead.
node {
name: "Const"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
dim {
size: 2
}
dim {
size: 2
}
}
tensor_content: "\315\314\314=\315\314L>\232\231\231>\315\314\314>"
}
}
}
}
node {
name: "Placeholder"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: -1
}
dim {
size: 2
}
}
}
}
}
node {
name: "Placeholder_1"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_BOOL
}
}
attr {
key: "shape"
value {
shape {
}
}
}
}
node {
name: "cond/Switch"
op: "Switch"
input: "Placeholder_1"
input: "Placeholder_1"
attr {
key: "T"
value {
type: DT_BOOL
}
}
}
node {
name: "cond/switch_t"
op: "Identity"
input: "cond/Switch:1"
attr {
key: "T"
value {
type: DT_BOOL
}
}
}
node {
name: "cond/switch_f"
op: "Identity"
input: "cond/Switch"
attr {
key: "T"
value {
type: DT_BOOL
}
}
}
node {
name: "cond/pred_id"
op: "Identity"
input: "Placeholder_1"
attr {
key: "T"
value {
type: DT_BOOL
}
}
}
node {
name: "cond/MatMul"
op: "MatMul"
input: "cond/MatMul/Switch:1"
input: "cond/MatMul/Switch_1:1"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "transpose_a"
value {
b: false
}
}
attr {
key: "transpose_b"
value {
b: false
}
}
}
node {
name: "cond/MatMul/Switch"
op: "Switch"
input: "Placeholder"
input: "cond/pred_id"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@Placeholder"
}
}
}
}
node {
name: "cond/MatMul/Switch_1"
op: "Switch"
input: "Const"
input: "cond/pred_id"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@Const"
}
}
}
}
node {
name: "cond/Add"
op: "Add"
input: "cond/Add/Switch"
input: "cond/Add/Switch_1"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "cond/Add/Switch"
op: "Switch"
input: "Placeholder"
input: "cond/pred_id"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@Placeholder"
}
}
}
}
node {
name: "cond/Add/Switch_1"
op: "Switch"
input: "Const"
input: "cond/pred_id"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@Const"
}
}
}
}
node {
name: "cond/Merge"
op: "Merge"
input: "cond/Add"
input: "cond/MatMul"
attr {
key: "N"
value {
i: 2
}
}
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "init"
op: "NoOp"
}
versions {
producer: 134
}

View File

@ -9,6 +9,15 @@ func @add(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
// CHECK: return
}
func @sub(%arg0: tensor<1xi64>, %arg1: tensor<1xi64>) -> tensor<1xi64> {
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64>
return %0: tensor<1xi64>
// CHECK-LABEL: sub
// CHECK: tfl.sub %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xi64>
// CHECK: return
}
// CHECK-LABEL: testAddHighDimsHaveSameShape
func @testAddHighDimsHaveSameShape(%arg0: tensor<1x2x3x4x5x6x7x8xi32>, %arg1: tensor<1x2x3x4x5x6x7x8xi32>) -> tensor<1x2x3x4x5x6x7x8xi32> {
// CHECK: tfl.add %arg0, %arg1 {fused_activation_function = "NONE"}
@ -990,6 +999,13 @@ func @batch_to_space_nd(%arg0: tensor<4x2x2x3xf32>, %arg1: tensor<2xi32>, %arg2:
// CHECK: "tfl.batch_to_space_nd"(%arg0, %arg1, %arg2) : (tensor<4x2x2x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<?xf32>
}
func @batch_to_space_nd_unsupported(%arg0: tensor<?x1x1x1x4xf32>, %arg1: tensor<3xi32>, %arg2: tensor<3x2xi32>) -> tensor<?x3x3x3x4xf32> {
%0 = "tf.BatchToSpaceND"(%arg0, %arg1, %arg2) : (tensor<?x1x1x1x4xf32>, tensor<3xi32>, tensor<3x2xi32>) -> tensor<?x3x3x3x4xf32>
return %0 : tensor<?x3x3x3x4xf32>
// CHECK-LABEL: batch_to_space_nd_unsupported
// CHECK: "tf.BatchToSpaceND"
}
func @space_to_batch_nd(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<2x2xi32>) -> tensor<?xf32> {
%0 = "tf.SpaceToBatchND"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>

View File

@ -269,6 +269,14 @@ func @testSub(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
return %0#0 : tensor<? x i32>
}
// CHECK-LABEL: testSubInt64
func @testSubInt64(tensor<? x i64>, tensor<? x i64>) -> tensor<? x i64> {
^bb0(%arg0: tensor<? x i64>, %arg1: tensor<? x i64>):
// CHECK: tfl.sub %arg0, %arg1 {fused_activation_function = "RELU6"}
%0 = tfl.sub %arg0, %arg1 {fused_activation_function = "RELU6"} : tensor<? x i64>
return %0#0 : tensor<? x i64>
}
// CHECK-LABEL: testMul
func @testMul(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
^bb0(%arg0: tensor<? x i32>, %arg1: tensor<? x i32>):

View File

@ -984,3 +984,11 @@ func @ReorderAddWithConstant(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[RESULT:.*]] = tfl.add %arg0, %[[CONST]] {fused_activation_function = "NONE"} : tensor<2x2xf32>
}
func @RemoveCast(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
%1 = "tfl.cast"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %1 : tensor<2x2xf32>
// CHECK-LABEL: RemoveCast
// CHECK: return %arg0
}

View File

@ -70,6 +70,7 @@ func @prepareAdd(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
}
// CHECK-LABEL: prepareConv2DSplat
// PerTensor-LABEL: prepareConv2DSplat
func @prepareConv2DSplat(%arg0: tensor<1x5x5x3xf32>) -> tensor<1x5x5x3xf32> {
%w = constant dense<127.0> : tensor<3x3x3x3xf32>
%b = constant dense<0.0> : tensor<3xf32>
@ -89,6 +90,7 @@ func @prepareConv2DSplat(%arg0: tensor<1x5x5x3xf32>) -> tensor<1x5x5x3xf32> {
}
// CHECK-LABEL: prepareConv2D
// PerTensor-LABEL: prepareConv2D
func @prepareConv2D(%arg0: tensor<1x5x5x1xf32>) -> tensor<1x5x5x3xf32> {
%w = constant dense<[[[[0.0]]], [[[127.0]]], [[[-127.0]]]]> : tensor<3x1x1x1xf32>
%b = constant dense<0.0> : tensor<3xf32>
@ -108,6 +110,7 @@ func @prepareConv2D(%arg0: tensor<1x5x5x1xf32>) -> tensor<1x5x5x3xf32> {
}
// CHECK-LABEL: prepareDepthwiseConv2D
// PerTensor-LABEL: prepareDepthwiseConv2D
func @prepareDepthwiseConv2D(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> {
%w = constant dense<127.0> : tensor<32x3x3x3xf32>
%b = constant dense<0.0> : tensor<32xf32>
@ -127,6 +130,7 @@ func @prepareDepthwiseConv2D(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112
}
// CHECK-LABEL: QuantizeFullyConnected
// PerTensor-LABEL: QuantizeFullyConnected
func @QuantizeFullyConnected(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> {
%w = constant dense<127.0> : tensor<32x12xf32>
%b = constant dense<0.0> : tensor<32xf32>
@ -143,3 +147,22 @@ func @QuantizeFullyConnected(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112
// PerTensor: %[[dq:.*]] = "tfl.dequantize"(%0) : (tensor<32x12x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>) -> tensor<32x12xf32>
// PerTensor: "tfl.fully_connected"(%arg0, %[[dq]]
}
// CHECK-LABEL: QuantizeTransposeConv
// PerTensor-LABEL: QuantizeTransposeConv
func @QuantizeTransposeConv(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<4xi32>) -> tensor<1x32x42x128xf32> {
%w = constant dense<127.0> : tensor<1x32x42x128xf32>
%b = constant dense<0.0> : tensor<1x32x42x128xf32>
%tc = "tfl.transpose_conv"(%arg1, %arg0, %w, %b) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<1x32x42x128xf32>) -> tensor<1x32x42x128xf32>
return %tc : tensor<1x32x42x128xf32>
// CHECK: %[[CST:.*]] = constant dense<1.270000e+02> : tensor<1x32x42x128xf32>
// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CST]]) {qtype = tensor<1x32x42x128x!quant.uniform<i8<-127:127>:f32:0, {1.000000e+00}>>, volatile}
// CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) : (tensor<1x32x42x128x!quant.uniform<i8<-127:127>:f32:0, {1.000000e+00}>>) -> tensor<1x32x42x128xf32>
// CHECK: "tfl.transpose_conv"(%arg1, %arg0, %[[DEQUANTIZE]]
// PerTensor: %[[CST:.*]] = constant dense<1.270000e+02> : tensor<1x32x42x128xf32>
// PerTensor: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CST]]) {qtype = tensor<1x32x42x128x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>, volatile}
// PerTensor: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) : (tensor<1x32x42x128x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>) -> tensor<1x32x42x128xf32>
// PerTensor: "tfl.transpose_conv"(%arg1, %arg0, %[[DEQUANTIZE]]
}

View File

@ -528,6 +528,26 @@ func @PadStridedSliceNewAxisMask2(%arg0: tensor<4x64x64x1xf32>) -> tensor<1x4x64
return %1 : tensor<1x4x64x64xf32>
}
// CHECK-LABEL: @StridedSliceRewriteMasks
func @StridedSliceRewriteMasks(%arg0: tensor<8x4x16x2xf32>) -> tensor<8x4x16x1xf32> {
%cst = "tf.Const"() {device = "", value = dense<[1, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
%cst_0 = "tf.Const"() {device = "", value = dense<[1, 0, 0]> : tensor<3xi32>} : () -> tensor<3xi32>
%cst_1 = "tf.Const"() {device = "", value = dense<1> : tensor<3xi32>} : () -> tensor<3xi32>
// CHECK: %[[CST:.*]] = constant dense<[1, 0, 0, 1]> : tensor<4xi32>
// CHECK: %[[CST0:.*]] = constant dense<[1, 0, 0, 0]> : tensor<4xi32>
// CHECK: %[[CST1:.*]] = constant dense<1> : tensor<4xi32>
// CHECK: %[[RESULT:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST0]], %[[CST1]])
// CHECK-SAME: begin_mask = 7 : i64
// CHECK-SAME: ellipsis_mask = 0 : i64
// CHECK-SAME: end_mask = 14 : i64
// CHECK-SAME: new_axis_mask = 0 : i64
// CHECK-SAME: shrink_axis_mask = 0 : i64
%0 = "tf.StridedSlice"(%arg0, %cst, %cst_0, %cst_1) {begin_mask = 1 : i64, device = "", ellipsis_mask = 2 : i64, end_mask = 4 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<8x4x16x2xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x4x16x1xf32>
return %0 : tensor<8x4x16x1xf32>
}
// CHECK-LABEL: @MatrixSetDiagV2Conversion
func @MatrixSetDiagV2Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
%cst = constant dense<0> : tensor<i32>

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "llvm/ADT/Optional.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
@ -26,6 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
namespace mlir {
@ -39,35 +41,48 @@ namespace tensorflow {
void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
mlir::OpPassManager* pass_manager) {
pass_manager->addPass(mlir::TFL::CreatePrepareQuantizePass(quant_specs));
pass_manager->addPass(mlir::TFL::CreateQuantizePass());
bool emit_quant_adaptor_ops =
quant_specs.inference_type != quant_specs.inference_input_type;
pass_manager->addPass(
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
if (quant_specs.default_ranges.first.hasValue() ||
quant_specs.default_ranges.second.hasValue()) {
pass_manager->addPass(mlir::TFL::CreateDefaultQuantParamsPass(
quant_specs.default_ranges.first.getValueOr(0.0),
quant_specs.default_ranges.second.getValueOr(0.0),
quant_specs.IsSignedInferenceType()));
pass_manager->addPass(mlir::TFL::CreateQuantizePass());
pass_manager->addPass(
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
}
pass_manager->addPass(mlir::TFL::CreateQuantizePass());
bool emit_quant_adaptor_ops =
quant_specs.inference_type != quant_specs.inference_input_type;
pass_manager->addPass(
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
}
void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
mlir::OpPassManager* pass_manager) {
mlir::OpPassManager* pass_manager,
llvm::Optional<tensorflow::Session*> session) {
mlir::TF::StandardPipelineOptions standard_pipeline_options;
standard_pipeline_options.enable_inliner = false;
standard_pipeline_options.form_clusters = pass_config.form_clusters;
mlir::TF::CreateTFStandardPipeline(*pass_manager, standard_pipeline_options);
pass_manager->addPass(mlir::TFL::CreateDeviceIndexSelectorPass());
pass_manager->addPass(mlir::TF::CreateDeviceIndexSelectorPass());
if (pass_config.shape_inference) {
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
}
if (session.hasValue()) {
// Add a pass that converts reference variables to resource variables.
pass_manager->addPass(
mlir::TF::
CreateConvertReadonlyReferenceVariablesToResourceVariablesPass());
// Add a pass that promotes resource variable to the function arguments.
pass_manager->addPass(mlir::TF::CreatePromoteVarHandlesToArgsPass());
// Add a pass that creates global tensors and converts the function
// arguments to the tf_saved_model.bound_input arguments.
pass_manager->addPass(
mlir::tf_saved_model::CreateLiftVariablesPass(session.getValue()));
}
// Keep this pass after the shape inference pass, which couldn't do shape
// inference for non-tf ops.
if (!pass_config.quant_specs.serialized_quant_stats.empty()) {
@ -212,9 +227,6 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
// Saved model pass to mark global tensors immutable.
pm.addPass(mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());
// Used to mark non-exported functions in saved model private.
pm.addPass(mlir::tf_saved_model::
CreateMarkFunctionVisibilityUsingSavedModelLinkagePass());
// Op fusion pass.
pm.addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());

View File

@ -16,16 +16,21 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_PASSES_H_
#include "llvm/ADT/Optional.h"
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/core/public/session.h"
namespace tensorflow {
// Add the TF to TFLite passes, specified in the pass_config, into a
// pass_manager.
// pass_manager. The session object will be provided when the TF MLIR is
// imported from saved model version one and utilized for capturing resource
// variables.
void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
mlir::OpPassManager* pass_manager);
mlir::OpPassManager* pass_manager,
llvm::Optional<tensorflow::Session*> session);
// Add the Quantization passes, specified in the quant_specs, into a pass
// manager.

View File

@ -16,6 +16,7 @@ limitations under the License.
#include <iostream>
#include "absl/strings/str_split.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
@ -29,6 +30,7 @@ limitations under the License.
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/FileUtilities.h" // from @llvm-project
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/compiler/mlir/init_mlir.h"
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
@ -37,6 +39,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
@ -168,11 +171,13 @@ int main(int argc, char **argv) {
module = tensorflow::ImportSavedModel(input_file_name, saved_model_version,
tags, exported_names, &context);
} else {
tensorflow::GraphImportConfig specs;
specs.upgrade_legacy = upgrade_legacy;
specs.prune_unused_nodes = true;
module = tensorflow::LoadFromGraphdefOrMlirSource(
input_file_name, input_mlir, use_splatted_constant, custom_opdefs,
debug_info_file, input_arrays, input_dtypes, input_shapes,
output_arrays,
/*prune_unused_nodes=*/true, &source_mgr, &context);
specs, debug_info_file, input_arrays, input_dtypes, input_shapes,
output_arrays, &source_mgr, &context);
}
// If errors occur, the library call in the above already logged the error
@ -220,7 +225,9 @@ int main(int argc, char **argv) {
pass_config.lower_tensor_list_ops = lower_tensor_list_ops;
pass_config.legalize_tf_while = convert_tf_while_to_tfl_while;
tensorflow::AddTFToTFLConversionPasses(pass_config, &pm);
// TODO(b/153507667): Pass the session object when importing logic is removed.
tensorflow::AddTFToTFLConversionPasses(pass_config, &pm,
/*session=*/llvm::None);
// TODO(b/150901738): Move those into tf_tfl_translate.cc.
// Convert back to outlined while format for export back to flatbuffer.
if (pass_config.legalize_tf_while) {

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/Parser.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/FileUtilities.h" // from @llvm-project
@ -28,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
@ -39,18 +41,43 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {
namespace {
using mlir::MLIRContext;
using mlir::ModuleOp;
using mlir::Operation;
using mlir::OwningModuleRef;
using stream_executor::port::StatusOr;
bool IsControlFlowV1Op(Operation* op) {
return mlir::isa<mlir::tf_executor::SwitchOp, mlir::tf_executor::MergeOp,
mlir::tf_executor::EnterOp, mlir::tf_executor::ExitOp,
mlir::tf_executor::NextIterationSinkOp,
mlir::tf_executor::NextIterationSourceOp>(op);
}
mlir::LogicalResult IsValidGraph(mlir::ModuleOp module) {
auto result = module.walk([&](Operation* op) {
return IsControlFlowV1Op(op) ? mlir::WalkResult::interrupt()
: mlir::WalkResult::advance();
});
if (result.wasInterrupted()) {
module.emitError(
"The graph has Control Flow V1 ops. TFLite converter doesn't support "
"Control Flow V1 ops. Consider using Control Flow V2 ops instead. See "
"https://www.tensorflow.org/api_docs/python/tf/compat/v1/"
"enable_control_flow_v2.");
return mlir::failure();
}
return mlir::success();
}
} // namespace
StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
const std::string& input_filename, bool input_mlir,
bool use_splatted_constant, const std::vector<std::string>& extra_tf_opdefs,
absl::string_view debug_info_file, absl::string_view input_arrays,
absl::string_view input_dtypes, absl::string_view input_shapes,
absl::string_view output_arrays, bool prune_unused_nodes,
const GraphImportConfig& specs, absl::string_view debug_info_file,
absl::string_view input_arrays, absl::string_view input_dtypes,
absl::string_view input_shapes, absl::string_view output_arrays,
llvm::SourceMgr* source_mgr, MLIRContext* context) {
// Set up the input file.
std::string error_message;
@ -85,15 +112,15 @@ StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
return tensorflow::GraphdefToSplattedMlirTranslateFunction(
file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
input_shapes, output_arrays, /*control_output_arrays=*/"",
prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
/*graph_as_function=*/false, /*upgrade_legacy=*/true,
specs.prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
/*graph_as_function=*/false, specs.upgrade_legacy,
/*enable_shape_inference=*/false, context);
}
return tensorflow::GraphdefToMlirTranslateFunction(
file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
input_shapes, output_arrays, /*control_output_arrays=*/"",
prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
/*graph_as_function=*/false, /*upgrade_legacy=*/true,
specs.prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
/*graph_as_function=*/false, specs.upgrade_legacy,
/*enable_shape_inference=*/false, context);
}
@ -104,7 +131,8 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
mlir::PassManager* pass_manager) {
mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(),
/*propagate=*/true);
if (failed(pass_manager->run(module))) {
if (failed(IsValidGraph(module)) || failed(pass_manager->run(module))) {
return statusHandler.ConsumeStatus();
}

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {
@ -38,9 +39,9 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef>
LoadFromGraphdefOrMlirSource(
const std::string& input_filename, bool input_mlir,
bool use_splatted_constant, const std::vector<std::string>& extra_tf_opdefs,
absl::string_view debug_info_file, absl::string_view input_arrays,
absl::string_view input_dtypes, absl::string_view input_shapes,
absl::string_view output_arrays, bool prune_unused_nodes,
const GraphImportConfig& specs, absl::string_view debug_info_file,
absl::string_view input_arrays, absl::string_view input_dtypes,
absl::string_view input_shapes, absl::string_view output_arrays,
llvm::SourceMgr* source_mgr, mlir::MLIRContext* context);
// Load Saved model (either v1 or v2) into MLIR.

View File

@ -110,8 +110,7 @@ void DefaultQuantParamsPass::runOnFunction() {
func.walk([&](Operation *op) {
if (op->isKnownTerminator() ||
op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
llvm::isa<quant::QuantizeCastOp>(op) ||
llvm::isa<quant::DequantizeCastOp>(op))
llvm::isa<quant::QuantizeCastOp, quant::DequantizeCastOp>(op))
return;
for (auto res : op->getResults()) {

View File

@ -28,9 +28,11 @@ limitations under the License.
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Threading.h"
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
#include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
@ -767,13 +769,26 @@ void LegalizeTF::runOnFunction() {
[](Operation* op) {
auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
if (!tfl_op) return false;
return succeeded(tfl_op.VerifyTflRuntimeConstraints(
tfl_op.getOperation(),
/*failure_on_operand_type_mismatch=*/false));
return succeeded(tfl_op.VerifyTflRuntimeConstraints(op));
}));
} else {
target.addLegalDialect<TensorFlowLiteDialect>();
}
// Ignore transient errors by registering a no-op handler.
// Applying legalization patterns will emit unwanted, transient errors when
// the replaced TFLite ops do not meet the sanity checks. In order to ignore
// the transient errors, the following lines override a diagnostic handler
// with a no-op handler only while this pass runs.
uint64_t current_thread_id = llvm::get_threadid();
ScopedDiagnosticHandler scoped_diag_handler(
context, [&current_thread_id](Diagnostic&) -> LogicalResult {
// Consume only errors that are coming from the same thread in order not
// to ignore errors from other passes that are running. Things running
// in the pass manager can be multi-threaded.
return success(current_thread_id == llvm::get_threadid());
});
// Keep trying to convert.
// TODO(karimnosseir): This is similar to what apply greedy patterns does.
// Look if there is a function that tries until it converge.

View File

@ -36,12 +36,14 @@ limitations under the License.
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
@ -216,18 +218,6 @@ static Type GetShapeStrippedType(TypeAttr type_attr) {
}
}
bool NotFromQuantOpDifferentQuant(Value val, TypeAttr qtype_attr) {
auto val_defn_op = val.getDefiningOp();
TFL::QuantizeOp q_op = llvm::dyn_cast_or_null<TFL::QuantizeOp>(val_defn_op);
if (!q_op) return true;
// Ignore shape details - we're really only trying to
// check whether the quantization is the same.
auto stripped_src_qtype = GetShapeStrippedType(q_op.qtypeAttr());
auto stripped_qtype = GetShapeStrippedType(qtype_attr);
return stripped_src_qtype == stripped_qtype;
}
#include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc"
// Fuse Add with the preceding FullyConnected.

View File

@ -18,6 +18,7 @@ limitations under the License.
include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/IR/Ops.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
include "tensorflow/compiler/mlir/lite/utils/utils.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
def F32ElementsAttr : ElementsAttrBase<
@ -26,17 +27,10 @@ def F32ElementsAttr : ElementsAttrBase<
def ExtractSingleElementAsFloat : NativeCodeCall<
"ExtractSingleElementAsFloat($_self.cast<ElementsAttr>())">;
// Checks if the value has only one user.
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
// Checks if the value has rank at most 'n'.
class HasRankAtMost<int n> : Constraint<
CPred<"$0.getType().cast<ShapedType>().getRank() <= " # n>>;
// Checks that the value is not produced by a TFL quantize op with a
// different quantization attribute.
def NotFromQuantOpDifferentQuant : Constraint<CPred<"NotFromQuantOpDifferentQuant($0,$1)">>;
//===----------------------------------------------------------------------===//
// Ternary ops patterns.
//===----------------------------------------------------------------------===//
@ -169,9 +163,9 @@ foreach BinaryOp = [TFL_DivOp, TFL_MulOp] in
// with the same scale. We want to remove the redundancy.
// TODO(fengliuai): move this to the sanity check of pre-quantize pass.
def eliminate_dq_q_pairs : Pat<
(TFL_QuantizeOp (TFL_DequantizeOp $in), $qt),
(replaceWithValue $in),
[(NotFromQuantOpDifferentQuant $in, $qt)]>;
(TFL_QuantizeOp (TFL_DequantizeOp $in), $qt),
(replaceWithValue $in),
[(NotFromQuantOpOrSameQuantType $in, $qt)]>;
// Constraint that makes sure both operands are the same value.
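Expressed as an imperative rewrite, the eliminate_dq_q_pairs rule above roughly does the following. This is a hedged sketch only: the input() accessors are assumed from the ODS definitions, and a plain type-equality guard stands in for the shape-stripped NotFromQuantOpOrSameQuantType check.

#include "llvm/Support/Casting.h"
#include "mlir/IR/PatternMatch.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

struct EliminateDqQPair : public mlir::OpRewritePattern<mlir::TFL::QuantizeOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      mlir::TFL::QuantizeOp q, mlir::PatternRewriter &rewriter) const override {
    // Match quantize(dequantize(x)).
    auto dq = llvm::dyn_cast_or_null<mlir::TFL::DequantizeOp>(
        q.input().getDefiningOp());
    if (!dq) return mlir::failure();
    // Simplified guard: only fold when forwarding the dequantize's input
    // preserves the quantized result type exactly.
    if (dq.input().getType() != q.getType()) return mlir::failure();
    rewriter.replaceOp(q, dq.input());
    return mlir::success();
  }
};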
@ -446,6 +440,10 @@ def : Pat<(TFL_MaximumOp (TFL_MulOp:$mul_out $input1,
(EqualOperands $input1, $input2),
(HasOneUse $mul_out)]>;
def RemoveTrivialCast : Pat<(TFL_CastOp:$output $input),
(replaceWithValue $input),
[(SameElementType $input, $output)]>;
// Checks if the operand0's rank is one less than operand1's rank.
def PReluAlphaRankCheck : Constraint<
CPred<"$0.getType().cast<ShapedType>().getRank() == "

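The new RemoveTrivialCast rule drops a tfl.cast whose input and output element types already agree. A rough C++ equivalent of that DRR pattern, as a sketch only; the input() accessor name is assumed, and TFLite registers the generated DRR form, not this class:

#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

struct FoldTrivialCast : public mlir::OpRewritePattern<mlir::TFL::CastOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      mlir::TFL::CastOp cast_op,
      mlir::PatternRewriter &rewriter) const override {
    // Bail out unless the cast leaves the element type unchanged.
    if (mlir::getElementTypeOrSelf(cast_op.input().getType()) !=
        mlir::getElementTypeOrSelf(cast_op.getType()))
      return mlir::failure();
    // The cast is a no-op; forward the input value directly.
    rewriter.replaceOp(cast_op, cast_op.input());
    return mlir::success();
  }
};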
View File

@ -91,9 +91,6 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass();
// Verifies runtime constraints.
std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass();
// Creates a function pass to select device index / fold tf.DeviceIndex.
std::unique_ptr<OperationPass<FuncOp>> CreateDeviceIndexSelectorPass();
} // namespace TFL
} // namespace mlir

View File

@ -52,7 +52,7 @@ class PostQuantizePass : public PassWrapper<PostQuantizePass, FunctionPass> {
void RemoveQuantizationAdaptorOps(FuncOp func) {
mlir::OpBuilder builder(func.getBody());
auto& bb = func.getBlocks().front();
auto& bb = func.front();
auto* terminator = bb.getTerminator();
int num_args = bb.getNumArguments();

View File

@ -584,46 +584,50 @@ struct ConvertTFStridedSlice : public RewritePattern {
const int ellipsis_filled_dim_size = input_size - begin_shape[0] + 1;
llvm::APInt new_begin_mask = strided_slice_op.begin_mask();
llvm::APInt new_end_mask = strided_slice_op.end_mask();
int64_t begin_mask = strided_slice_op.begin_mask().getSExtValue();
int64_t end_mask = strided_slice_op.end_mask().getSExtValue();
int64_t new_begin_mask = 0;
int64_t new_end_mask = 0;
SmallVector<int32_t, 4> padded_begin;
SmallVector<int32_t, 4> padded_end;
SmallVector<int32_t, 4> padded_stride;
// Before the ellipsis.
uint64_t index = 1;
int count = 0;
while (index < ellipsis_mask) {
padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(count));
padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(count));
padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(count));
index <<= 1;
count++;
int index = 0;
int new_index = 0;
while (((ellipsis_mask >> index) & 1) == 0) {
padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(index));
padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(index));
padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(index));
if ((begin_mask >> index) & 1) new_begin_mask |= (1 << new_index);
if ((end_mask >> index) & 1) new_end_mask |= (1 << new_index);
++index;
++new_index;
}
// Ellipsis.
for (int i = 0; i < ellipsis_filled_dim_size; ++i) {
new_begin_mask |= ellipsis_mask;
new_end_mask |= ellipsis_mask;
for (; new_index < index + ellipsis_filled_dim_size; ++new_index) {
new_begin_mask |= (1 << new_index);
new_end_mask |= (1 << new_index);
// Mimic the begin/end/strides mask behavior.
padded_begin.push_back(0);
padded_end.push_back(0);
padded_stride.push_back(1);
ellipsis_mask <<= 1;
}
// Account for ellipsis mask.
count++;
++index;
// After the ellipsis.
for (; count < begin_shape[0]; ++count) {
padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(count));
padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(count));
padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(count));
for (; index < begin_shape[0]; ++index) {
padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(index));
padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(index));
padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(index));
if ((begin_mask >> index) & 1) new_begin_mask |= (1 << new_index);
if ((end_mask >> index) & 1) new_end_mask |= (1 << new_index);
}
auto attribute_type = rewriter.getIntegerType(64);
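A worked example of the mask remapping performed by the loops above (the numbers are illustrative and the snippet is a simplified, self-contained reconstruction rather than the pass code): with a 4-D input and a 3-entry slice spec whose second entry is the ellipsis, the ellipsis fills 4 - 3 + 1 = 2 dimensions, so a begin_mask bit set on the entry after the ellipsis shifts up by one position (old bit 2 becomes new bit 3), while the ellipsis-filled dimensions get their bits forced on.

#include <cstdint>
#include <cstdio>

int main() {
  const int input_rank = 4, spec_size = 3;
  const int64_t ellipsis_mask = 0b010;  // ellipsis in slot 1 of the spec
  const int64_t begin_mask = 0b100;     // begin_mask set on the slot after it
  const int ellipsis_filled = input_rank - spec_size + 1;  // == 2

  int64_t new_begin_mask = 0;
  int index = 0, new_index = 0;
  // Dims before the ellipsis keep their bit positions.
  while (((ellipsis_mask >> index) & 1) == 0) {
    if ((begin_mask >> index) & 1) new_begin_mask |= (1 << new_index);
    ++index;
    ++new_index;
  }
  // Dims filled in by the ellipsis get their bits forced on.
  for (; new_index < index + ellipsis_filled; ++new_index)
    new_begin_mask |= (1 << new_index);
  ++index;  // skip the ellipsis slot in the original spec
  // Dims after the ellipsis move to the remapped bit positions.
  for (; index < spec_size; ++index, ++new_index)
    if ((begin_mask >> index) & 1) new_begin_mask |= (1 << new_index);

  // Prints 0xe: bits 1-2 from the ellipsis, bit 3 from the old bit 2.
  std::printf("new_begin_mask = %#llx\n",
              static_cast<long long>(new_begin_mask));
  return 0;
}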
@ -645,7 +649,7 @@ struct ConvertTFStridedSlice : public RewritePattern {
end_op.getResult(), stride_op.getResult(),
rewriter.getIntegerAttr(attribute_type, new_begin_mask),
rewriter.getIntegerAttr(attribute_type, new_end_mask),
rewriter.getI64IntegerAttr(0),
/*ellipsis_mask=*/rewriter.getI64IntegerAttr(0),
rewriter.getIntegerAttr(attribute_type,
strided_slice_op.new_axis_mask()),
rewriter.getIntegerAttr(attribute_type,
@ -655,10 +659,12 @@ struct ConvertTFStridedSlice : public RewritePattern {
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
// TODO(renjieliu): Consider expanding the transformation for shrink
// mask as well.
TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
// TODO(renjieliu): Consider expanding the transformation for shrink mask as
// well.
if (strided_slice_op.shrink_axis_mask().getZExtValue()) return failure();
// Handle new axis mask.
uint64_t new_axis_mask = strided_slice_op.new_axis_mask().getZExtValue();
if (new_axis_mask != 0) {

View File

@ -34,8 +34,7 @@ class RuntimeVerifyPass
void RuntimeVerifyPass::runOnFunction() {
getFunction().walk([&](TflRuntimeVerifyOpInterface op) {
if (failed(op.VerifyTflRuntimeConstraints(
op.getOperation(), /*failure_on_operand_type_mismatch=*/true)))
if (failed(op.VerifyTflRuntimeConstraints(op.getOperation())))
signalPassFailure();
});
}

Some files were not shown because too many files have changed in this diff.