Merge branch 'master' into sign-compare-warning-fixes-batch-1-fix2

tg-at-google 2020-06-24 15:08:13 -04:00 committed by GitHub
commit 41707f74dc
1442 changed files with 47843 additions and 21655 deletions

View File

@ -39,6 +39,7 @@
#
# Feature and Third party library support options:
# xla: Build TF with XLA
# tpu: Build TF with TPU support
# using_cuda: CUDA is available to build system.
# cuda: Build with full cuda support.
# rocm: Build with AMD GPU support (rocm).
@ -57,13 +58,12 @@
#
#
# Remote build execution options (only configured to work with TF team projects for now).
# rbe: General RBE options shared by all flavors.
# rbe_linux: General RBE options used on all linux builds.
# rbe_win: General RBE options used on all windows builds.
# rbe: General RBE options shared by all flavors.
# rbe_linux: General RBE options used on all linux builds.
# rbe_win: General RBE options used on all windows builds.
#
# rbe_cpu_linux: RBE options to build with only CPU support.
# rbe_linux_cuda_nvcc: RBE options to build with GPU support using nvcc.
# rbe_gpu_linux: An alias for rbe_linux_cuda_nvcc
# rbe_cpu_linux: RBE options to build with only CPU support.
# rbe_linux_cuda_nvcc_py*: RBE options to build with GPU support using nvcc.
#
# rbe_linux_py2: Linux Python 2 RBE config.
# rbe_linux_py3: Linux Python 3 RBE config
@ -180,6 +180,9 @@ build:dbg --cxxopt -DTF_LITE_DISABLE_X86_NEON
# AWS SDK must be compiled in release mode. see: https://github.com/tensorflow/tensorflow/issues/37498
build:dbg --copt -DDEBUG_BUILD
# Config to build TPU backend
build:tpu --define=with_tpu_support=true
build:tensorrt --action_env TF_NEED_TENSORRT=1
build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
@ -396,33 +399,48 @@ build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1
build:rbe_linux_cuda_base --repo_env=TF_NEED_CUDA=1
test:rbe_linux_cuda_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --host_platform="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --platforms="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_nvcc --define=using_cuda_nvcc=true
test:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
build:rbe_linux_cuda10.1_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda10.1_nvcc_base --define=using_cuda_nvcc=true
build:rbe_linux_cuda10.1_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda10.1_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda10.1_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda10.1_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda10.1_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda10.1_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda10.1_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda10.1_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda10.1_nvcc_py2.7 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
build:rbe_linux_cuda10.1_nvcc_py3.5 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
build:rbe_linux_cuda10.1_nvcc_py3.6 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
build:rbe_linux_cuda10.1_nvcc_py3.7 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
build:rbe_linux_cuda10.1_nvcc_py3.8 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
build:rbe_linux_cuda_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_nvcc_base --define=using_cuda_nvcc=true
build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
build:rbe_linux_cuda11.0_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda11.0_nvcc_base --define=using_cuda_nvcc=true
build:rbe_linux_cuda11.0_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.0_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda11.0_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform"
build:rbe_linux_cuda11.0_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform"
build:rbe_linux_cuda11.0_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform"
build:rbe_linux_cuda11.0_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda"
build:rbe_linux_cuda11.0_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_tensorrt"
build:rbe_linux_cuda11.0_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_nccl"
build:rbe_linux_cuda11.0_nvcc_py2.7 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python2.7"
build:rbe_linux_cuda11.0_nvcc_py3.5 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.5"
build:rbe_linux_cuda11.0_nvcc_py3.6 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.6"
build:rbe_linux_cuda11.0_nvcc_py3.7 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.7"
build:rbe_linux_cuda11.0_nvcc_py3.8 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.8"
# Map default to CUDA 10.1.
build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda10.1_nvcc_py2.7
build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda10.1_nvcc_py3.5
build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda10.1_nvcc_py3.6
build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda10.1_nvcc_py3.7
build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda10.1_nvcc_py3.8
# Deprecated configs that people might still use.
build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_nvcc_py36
build:rbe_gpu_linux --config=rbe_linux_cuda_nvcc
build:rbe_linux_cuda_clang_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda_clang_base --crosstool_top="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
@ -440,8 +458,6 @@ build:rbe_linux_cuda_clang_py36 --config=rbe_linux_cuda_clang_base --repo_env=TF
build:rbe_linux_cuda_clang_py37 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
build:rbe_linux_cuda_clang_py38 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
common:rbe_gpu_linux --config=rbe_linux_cuda_nvcc
build:rbe_linux_py2 --config=rbe_linux
build:rbe_linux_py2 --repo_env=PYTHON_BIN_PATH="/usr/bin/python2"
build:rbe_linux_py2 --python_path="/usr/bin/python2"

View File

@ -95,6 +95,7 @@ for general questions and discussion, and please direct specific questions to
The TensorFlow project strives to abide by generally accepted best practices in
open-source software development:
[![Fuzzing Status](https://oss-fuzz-build-logs.storage.googleapis.com/badges/tensorflow.svg)](https://bugs.chromium.org/p/oss-fuzz/issues/list?sort=-opened&can=1&q=proj:tensorflow)
[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1486/badge)](https://bestpractices.coreinfrastructure.org/projects/1486)
[![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-v1.4%20adopted-ff69b4.svg)](CODE_OF_CONDUCT.md)

View File

@ -1,3 +1,57 @@
# Release 2.4.0
<INSERT SMALL BLURB ABOUT RELEASE FOCUS AREA AND POTENTIAL TOOLCHAIN CHANGES>
## Breaking Changes
* <DOCUMENT BREAKING CHANGES HERE>
* <THIS SECTION SHOULD CONTAIN API, ABI AND BEHAVIORAL BREAKING CHANGES>
## Known Caveats
* <CAVEATS REGARDING THE RELEASE (BUT NOT BREAKING CHANGES). E.G. ADDING A NEW DEPENDENCY, BUMPING A DEPENDENCY NUMBER, LACK OF SUPPORT ON SOME PLATFORM, ETC>
## Major Features and Improvements
* <INSERT MAJOR FEATURE HERE, USING MARKDOWN SYNTAX>
* <IF RELEASE CONTAINS MULTIPLE FEATURES FROM SAME AREA, GROUP THEM TOGETHER>
## Bug Fixes and Other Changes
* <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
* <NOTES SHOULD BE GROUPED PER AREA>
* TF Core:
* <ADD RELEASE NOTES HERE>
* `tf.data`:
* <ADD RELEASE NOTES HERE>
* `tf.distribute`:
* <ADD RELEASE NOTES HERE>
* `tf.keras`:
* <ADD RELEASE NOTES HERE>
* `tf.function`/AutoGraph:
* <ADD RELEASE NOTES HERE>
* `tf.lite`:
* <ADD RELEASE NOTES HERE>
* `tf.random`:
* <ADD RELEASE NOTES HERE>
* Math and Linear Algebra:
* <ADD RELEASE NOTES HERE>
* TPU Enhancements:
* <ADD RELEASE NOTES HERE>
* XLA Support:
* <ADD RELEASE NOTES HERE>
* Tracing and Debugging:
* <ADD RELEASE NOTES HERE>
* Other:
* <ADD RELEASE NOTES HERE>
## Thanks to our Contributors
This release contains contributions from many people at Google, as well as:
<INSERT>, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
# Release 2.3.0
## Breaking Changes

View File

@ -467,6 +467,13 @@ config_setting(
visibility = ["//visibility:public"],
)
# This flag enables experimental TPU support
config_setting(
name = "with_tpu_support",
values = {"define": "with_tpu_support=true"},
visibility = ["//visibility:public"],
)
# Specifies via a config setting if this is a mobile build or not; this makes
# it easier to combine settings later.
selects.config_setting_group(

View File

@ -624,7 +624,7 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
const int num_inputs = input_shapes->num_items;
NodeDef node_def;
tensorflow::AbstractOperationInterface* op = tensorflow::unwrap(tfe_op);
tensorflow::ImmediateExecutionOperation* op = tensorflow::unwrap(tfe_op);
node_def.set_name(op->Name());
node_def.set_op(op->Name());
for (int i = 0; i < num_inputs; ++i) {

View File

@ -54,7 +54,7 @@ Status ProcessInputs(
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
input_tensors->reserve(ninputs);
for (int i = 0; i < ninputs; ++i) {
Node* node = &inputs[i].oper->node;
Node* node = inputs[i].oper ? &inputs[i].oper->node : nullptr;
int idx = inputs[i].index;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
@ -90,7 +90,7 @@ Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
output_tensors->reserve(noutputs);
for (int i = 0; i < noutputs; ++i) {
Node* node = &outputs[i].oper->node;
Node* node = outputs[i].oper ? &outputs[i].oper->node : nullptr;
int idx = outputs[i].index;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
fn_body->graph.IsValidOutputTensor(node, idx),

View File

@ -38,9 +38,10 @@ tf_cuda_library(
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
":context_interface",
":operation_interface",
":tensor_handle_interface",
":immediate_execution_context",
":immediate_execution_operation",
":immediate_execution_tensor_handle",
":abstract_tensor_handle",
":tfe_context_internal",
":tfe_cancellation_manager_internal",
":tfe_executor_internal",
@ -101,13 +102,17 @@ tf_cuda_library(
filegroup(
name = "pywrap_required_hdrs",
srcs = [
"abstract_context.h",
"abstract_function.h",
"abstract_operation.h",
"abstract_tensor_handle.h",
"c_api_experimental.h",
"c_api_internal.h",
"c_api_unified_experimental.h",
"context_interface.h",
"dlpack.h",
"operation_interface.h",
"tensor_handle_interface.h",
"immediate_execution_context.h",
"immediate_execution_operation.h",
"immediate_execution_tensor_handle.h",
"tfe_cancellation_manager_internal.h",
"tfe_executor_internal.h",
"tfe_monitoring_internal.h",
@ -163,12 +168,22 @@ cc_library(
)
cc_library(
name = "tensor_handle_interface",
hdrs = ["tensor_handle_interface.h"],
name = "abstract_tensor_handle",
hdrs = ["abstract_tensor_handle.h"],
visibility = [
"//tensorflow:internal",
],
deps = [],
)
cc_library(
name = "immediate_execution_tensor_handle",
hdrs = ["immediate_execution_tensor_handle.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
":abstract_tensor_handle",
"//tensorflow/c:tensor_interface",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@ -177,13 +192,13 @@ cc_library(
)
cc_library(
name = "operation_interface",
hdrs = ["operation_interface.h"],
name = "abstract_operation",
hdrs = ["abstract_operation.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
":tensor_handle_interface",
":abstract_tensor_handle",
"//tensorflow/c:tensor_interface",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@ -193,14 +208,58 @@ cc_library(
)
cc_library(
name = "context_interface",
hdrs = ["context_interface.h"],
name = "immediate_execution_operation",
hdrs = ["immediate_execution_operation.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
":operation_interface",
":tensor_handle_interface",
":abstract_operation",
":abstract_tensor_handle",
":immediate_execution_tensor_handle",
"//tensorflow/c:tensor_interface",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "abstract_context",
hdrs = ["abstract_context.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
":abstract_function",
":abstract_operation",
],
)
cc_library(
name = "abstract_function",
hdrs = ["abstract_function.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:status",
],
)
cc_library(
name = "immediate_execution_context",
hdrs = ["immediate_execution_context.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
":abstract_context",
":immediate_execution_operation",
":immediate_execution_tensor_handle",
"//tensorflow/c:tensor_interface",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@ -217,7 +276,7 @@ cc_library(
"//tensorflow:internal",
],
deps = [
":context_interface",
":immediate_execution_context",
"//tensorflow/c:conversion_macros",
],
)
@ -277,7 +336,7 @@ cc_library(
"//tensorflow:internal",
],
deps = [
":operation_interface",
":immediate_execution_operation",
"//tensorflow/c:conversion_macros",
],
)
@ -300,7 +359,7 @@ cc_library(
"//tensorflow:internal",
],
deps = [
":tensor_handle_interface",
":immediate_execution_tensor_handle",
"//tensorflow/c:conversion_macros",
],
)
@ -480,6 +539,9 @@ tf_cuda_library(
":tfe_context_internal",
":tfe_op_internal",
":tfe_tensorhandle_internal",
":abstract_operation",
":abstract_context",
":abstract_tensor_handle",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_internal",
"//tensorflow/core:core_cpu",

View File

@ -0,0 +1,83 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_
#define TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_
#include <memory>
#include <vector>
#include "tensorflow/c/eager/abstract_function.h"
#include "tensorflow/c/eager/abstract_operation.h"
namespace tensorflow {
// Abstract interface to a context.
//
// This serves as a factory for creating `AbstractOperation`s and for
// registering traced functions.
// Operations created within a context can only be executed in that context
// (for now at least).
// Implementations of the context may contain some state, e.g. an execution
// environment, a traced representation, etc.
class AbstractContext {
protected:
enum AbstractContextKind { kTracing, kImmediateExecution };
explicit AbstractContext(AbstractContextKind kind) : kind_(kind) {}
virtual ~AbstractContext() {}
public:
AbstractContextKind getKind() const { return kind_; }
// Release any underlying resources, including the interface object.
//
// WARNING: The destructor of this class is marked as protected to disallow
// clients from directly destroying this object since it may manage its own
// lifetime through ref counting. Thus clients MUST call Release() in order to
// destroy an instance of this class.
virtual void Release() = 0;
// Creates an operation builder and ties it to this context.
// The returned object can be used for setting the operation's attributes,
// adding inputs and finally executing (immediately or lazily as in tracing)
// it in this context.
virtual AbstractOperation* CreateOperation() = 0;
// Registers a function with this context; after this the function is
// available to be called/referenced by its name in this context.
virtual Status RegisterFunction(AbstractFunction*) = 0;
// Remove a function. The 'func' argument is the name of a previously added
// FunctionDef. The name is in fdef.signature.name.
virtual Status RemoveFunction(const string& func) = 0;
private:
const AbstractContextKind kind_;
};
namespace internal {
struct AbstractContextDeleter {
void operator()(AbstractContext* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using AbstractContextPtr =
std::unique_ptr<AbstractContext, internal::AbstractContextDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_
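The Release()-based ownership contract above is exactly what the AbstractContextPtr alias encodes. A minimal usage sketch, assuming some factory MakeTracingContext() that hands back an owned AbstractContext* (the factory name is hypothetical, not part of this header):
// Sketch only: MakeTracingContext() is a stand-in for whatever code
// produces an owned AbstractContext*.
tensorflow::AbstractContextPtr ctx(MakeTracingContext());
// Operations are created by, and can only be executed in, this context.
tensorflow::AbstractOpPtr op(ctx->CreateOperation());
// ... set attributes, add inputs and execute via the AbstractOperation API ...
// Both smart pointers call Release() on destruction rather than `delete`.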

View File

@ -0,0 +1,46 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_
#define TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_
#include "tensorflow/c/c_api.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
// A traced function: this hides the complexity of converting the serialized
// representation between various supported formats, e.g. FunctionDef and MLIR
// function.
class AbstractFunction {
protected:
enum AbstractFunctionKind { kGraphFunc, kMlirFunc };
explicit AbstractFunction(AbstractFunctionKind kind) : kind_(kind) {}
public:
// Returns which subclass this is an instance of.
AbstractFunctionKind getKind() const { return kind_; }
virtual ~AbstractFunction() = default;
// Returns the AbstractFunction as a FunctionDef.
virtual Status GetFunctionDef(FunctionDef**) = 0;
private:
const AbstractFunctionKind kind_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_
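To make the intent concrete, a minimal sketch of a FunctionDef-backed subclass could look as follows (illustrative only; the real graph-function implementation lives elsewhere in the tree):
namespace tensorflow {
// Hypothetical subclass that simply wraps an already-built FunctionDef.
class ExampleGraphFunction : public AbstractFunction {
 public:
  explicit ExampleGraphFunction(FunctionDef fdef)
      : AbstractFunction(kGraphFunc), fdef_(std::move(fdef)) {}
  Status GetFunctionDef(FunctionDef** fdef) override {
    *fdef = &fdef_;  // no conversion needed; the proto is stored directly
    return Status::OK();
  }

 private:
  FunctionDef fdef_;
};
}  // namespace tensorflow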

View File

@ -12,24 +12,31 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
#define TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
#define TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
struct TFE_Op;
namespace tensorflow {
// Abstract interface to an operation.
class AbstractOperationInterface {
// This interface allows building and executing an operation in either
// tracing or immediate execution mode.
class AbstractOperation {
protected:
enum AbstractOperationKind { kTracing, kImmediateExecution };
explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {}
virtual ~AbstractOperation() {}
public:
AbstractOperationKind getKind() const { return kind_; }
// Release any underlying resources, including the interface object.
//
// WARNING: The destructor of this class is marked as protected to disallow
@ -38,7 +45,6 @@ class AbstractOperationInterface {
// clients MUST call Release() in order to destroy an instance of this class.
virtual void Release() = 0;
virtual void Clear() = 0;
virtual Status Reset(const char* op, const char* raw_device_name) = 0;
virtual const string& Name() const = 0;
@ -66,12 +72,10 @@ class AbstractOperationInterface {
// existing and given constraints will be performed.
virtual Status SetDeviceName(const char* name) = 0;
virtual Status AddInput(AbstractTensorHandleInterface* input) = 0;
virtual Status AddInputList(
absl::Span<AbstractTensorHandleInterface*> inputs) = 0;
virtual Status Execute(absl::Span<AbstractTensorHandleInterface*> retvals,
virtual Status AddInput(AbstractTensorHandle* input) = 0;
virtual Status AddInputList(absl::Span<AbstractTensorHandle*> inputs) = 0;
virtual Status Execute(absl::Span<AbstractTensorHandle*> retvals,
int* num_retvals) = 0;
virtual const tensorflow::OpDef* OpDef() const = 0;
virtual Status SetAttrString(const char* attr_name, const char* data,
size_t length) = 0;
@ -82,7 +86,7 @@ class AbstractOperationInterface {
virtual Status SetAttrShape(const char* attr_name, const int64_t* dims,
const int num_dims) = 0;
virtual Status SetAttrFunction(const char* attr_name,
const AbstractOperationInterface* value) = 0;
const AbstractOperation* value) = 0;
virtual Status SetAttrFunctionName(const char* attr_name, const char* value,
size_t length) = 0;
virtual Status SetAttrTensor(const char* attr_name,
@ -102,19 +106,25 @@ class AbstractOperationInterface {
virtual Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
const int* num_dims, int num_values) = 0;
virtual Status SetAttrFunctionList(
const char* attr_name,
absl::Span<const AbstractOperationInterface*> values) = 0;
const char* attr_name, absl::Span<const AbstractOperation*> values) = 0;
virtual Status InputLength(const char* input_name, int* length) = 0;
virtual Status OutputLength(const char* output_name, int* length) = 0;
// Experimental
virtual Status SetUseXla(bool enable) = 0;
protected:
virtual ~AbstractOperationInterface() {}
private:
const AbstractOperationKind kind_;
};
namespace internal {
struct AbstractOperationDeleter {
void operator()(AbstractOperation* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using AbstractOpPtr =
std::unique_ptr<AbstractOperation, internal::AbstractOperationDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
#endif // TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
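Putting the pieces together, the intended call pattern is roughly the following sketch ("Relu" is only an example op name; ctx and x stand for an existing context and input handle):
// Sketch: build and run a single op through the abstract interfaces.
tensorflow::Status RunRelu(tensorflow::AbstractContext* ctx,
                           tensorflow::AbstractTensorHandle* x,
                           tensorflow::AbstractTensorHandle** result) {
  tensorflow::AbstractOpPtr op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(op->Reset("Relu", /*raw_device_name=*/nullptr));
  TF_RETURN_IF_ERROR(op->AddInput(x));
  int num_retvals = 1;
  // In tracing mode this records the op; in immediate mode it executes it.
  return op->Execute(absl::MakeSpan(result, 1), &num_retvals);
}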

View File

@ -0,0 +1,61 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_
#define TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_
#include <memory>
namespace tensorflow {
// Abstract interface to a Tensor handle in either tracing or immediate
// execution mode.
class AbstractTensorHandle {
protected:
enum AbstractTensorHandleKind { kTracing, kImmediateExecution };
explicit AbstractTensorHandle(AbstractTensorHandleKind kind) : kind_(kind) {}
virtual ~AbstractTensorHandle() {}
public:
AbstractTensorHandleKind getKind() const { return kind_; }
// Release any underlying resources, including the interface object.
//
// WARNING: The destructor of this class is marked as protected to disallow
// clients from directly destroying this object since it may manage its own
// lifetime through ref counting. Thus this must be allocated on the heap and
// clients MUST call Release() in order to destroy an instance of this class.
virtual void Release() = 0;
private:
const AbstractTensorHandleKind kind_;
};
namespace internal {
struct AbstractTensorHandleDeleter {
void operator()(AbstractTensorHandle* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using AbstractTensorHandlePtr =
std::unique_ptr<AbstractTensorHandle,
internal::AbstractTensorHandleDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_

View File

@ -21,6 +21,8 @@ limitations under the License.
#include <string>
#include <vector>
#include "tensorflow/c/eager/abstract_tensor_handle.h"
// clang-format off
#include "tensorflow/core/platform/platform.h"
// clang-format on
@ -31,8 +33,8 @@ limitations under the License.
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
@ -1119,7 +1121,7 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status) {
tensorflow::AbstractOperationInterface* new_op =
tensorflow::ImmediateExecutionOperation* new_op =
tensorflow::unwrap(ctx)->CreateOperation();
status->status = new_op->Reset(op_or_function_name, nullptr);
if (!status->status.ok()) {
@ -1164,7 +1166,9 @@ void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
TF_Status* status) {
status->status = tensorflow::unwrap(op)->AddInputList(
{tensorflow::unwrap(inputs), static_cast<size_t>(num_inputs)});
{reinterpret_cast<tensorflow::AbstractTensorHandle**>(
tensorflow::unwrap(inputs)),
static_cast<size_t>(num_inputs)});
}
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
@ -1324,7 +1328,9 @@ void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
const TFE_Op** value, int num_values) {
auto s = tensorflow::unwrap(op)->SetAttrFunctionList(
attr_name, {tensorflow::unwrap(value), static_cast<size_t>(num_values)});
attr_name, {reinterpret_cast<const tensorflow::AbstractOperation**>(
tensorflow::unwrap(value)),
static_cast<size_t>(num_values)});
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
@ -1368,7 +1374,10 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
status->status = tensorflow::unwrap(op)->Execute(
absl::MakeSpan(tensorflow::unwrap(retvals), *num_retvals), num_retvals);
absl::MakeSpan(reinterpret_cast<tensorflow::AbstractTensorHandle**>(
tensorflow::unwrap(retvals)),
*num_retvals),
num_retvals);
}
TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,

View File

@ -38,7 +38,7 @@ using tensorflow::string;
void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status) {
if (op_to_reset) {
tensorflow::AbstractOperationInterface* op =
tensorflow::ImmediateExecutionOperation* op =
tensorflow::unwrap(op_to_reset);
op->Clear();
status->status = op->Reset(op_or_function_name, raw_device_name);
@ -60,6 +60,12 @@ void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
context->SetShouldStoreGraphs(false);
}
uint64_t TFE_GetContextId(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
return context->GetContextId();
}
void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
int64_t value) {
cell->cell.IncrementBy(value);

View File

@ -300,6 +300,14 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetLazyRemoteInputsCopy(
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetTfrt(TFE_ContextOptions*,
bool use_tfrt);
// Returns the context_id from the EagerContext which is used by the
// EagerService to maintain consistency between client and worker. The
// context_id is initialized with a dummy value and is later set when the worker
// is initialized (either locally or remotely). The context_id can change during
// the process lifetime, although such a change should also cause the worker to
// be reinitialized (e.g. caches cleared).
TF_CAPI_EXPORT extern uint64_t TFE_GetContextId(TFE_Context* ctx);
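// Usage sketch (illustrative only; `ctx` is an existing TFE_Context*):
//
//   uint64_t id_before = TFE_GetContextId(ctx);
//   /* ... work that may reinitialize the worker, e.g. a server_def update ... */
//   uint64_t id_after = TFE_GetContextId(ctx);
//   if (id_before != id_after) {
//     /* The context was reinitialized; caches keyed on the old id are stale. */
//   }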
// -----------------------------------------------------------------------------
// Cancellation APIs.

View File

@ -12,15 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
#define TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
#include <memory>
#include <vector>
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/numeric_types.h"
@ -34,16 +36,9 @@ namespace tensorflow {
//
// A context is responsible for creating key objects such as Tensors,
// TensorHandles & Operations.
class AbstractContextInterface {
class ImmediateExecutionContext : public AbstractContext {
public:
// Release any underlying resources, including the interface object.
//
// WARNING: The destructor of this class is marked as protected to disallow
// clients from directly destroying this object since it may manage it's own
// lifetime through ref counting. Thus clients MUST call Release() in order to
// destroy an instance of this class.
virtual void Release() = 0;
static constexpr AbstractContextKind kKind = kImmediateExecution;
// Optimized scalar creation functions
virtual AbstractTensorInterface* CreateInt64Scalar(int64 value) = 0;
virtual AbstractTensorInterface* CreateUint64Scalar(uint64 value) = 0;
@ -74,15 +69,15 @@ class AbstractContextInterface {
void* memory_releaser_arg) = 0;
// Create a handle to wrap and manage a Tensor
virtual AbstractTensorHandleInterface* CreateLocalHandle(
virtual ImmediateExecutionTensorHandle* CreateLocalHandle(
AbstractTensorInterface* t) = 0;
// Copy the handle to another device.
virtual AbstractTensorHandleInterface* CopyTensorHandleToDevice(
AbstractTensorHandleInterface* handle, const char* device_name,
virtual ImmediateExecutionTensorHandle* CopyTensorHandleToDevice(
ImmediateExecutionTensorHandle* handle, const char* device_name,
Status* status) = 0;
// Create an operation to perform op execution
virtual AbstractOperationInterface* CreateOperation() = 0;
ImmediateExecutionOperation* CreateOperation() override = 0;
// Returns whether the runtime is backed by TFRT or the legacy TF Eager
// Runtime. This is necessary to decouple runtime-dependent
@ -107,14 +102,26 @@ class AbstractContextInterface {
// be executed as an op. Return error if the function with the same name
// already exists.
virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
// Remove a function. 'func' argument is the name of a previously added
// FunctionDef. The name is in fdef.signature.name.
virtual Status RemoveFunction(const string& func) = 0;
protected:
virtual ~AbstractContextInterface() {}
ImmediateExecutionContext() : AbstractContext(kKind) {}
~ImmediateExecutionContext() override {}
};
namespace internal {
struct ImmediateExecutionContextDeleter {
void operator()(ImmediateExecutionContext* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using ImmediateContextPtr =
std::unique_ptr<ImmediateExecutionContext,
internal::ImmediateExecutionContextDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
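A short sketch of how this immediate-execution layer is meant to be driven, assuming ctx is an ImmediateExecutionContext* obtained elsewhere (for example by unwrapping a TFE_Context):
// Sketch: create a scalar tensor, wrap it in a handle, and start a new op.
tensorflow::AbstractTensorInterface* scalar = ctx->CreateInt64Scalar(42);
// Ownership/ref-count semantics of `scalar` follow the concrete implementation.
tensorflow::ImmediateTensorHandlePtr handle(ctx->CreateLocalHandle(scalar));
// CreateOperation() now returns the narrower ImmediateExecutionOperation type.
tensorflow::ImmediateOpPtr op(ctx->CreateOperation());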

View File

@ -0,0 +1,69 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_
#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/status.h"
struct TFE_Op;
namespace tensorflow {
// Abstract interface to an operation.
class ImmediateExecutionOperation : public AbstractOperation {
public:
static constexpr AbstractOperationKind kKind = kImmediateExecution;
virtual void Clear() = 0;
virtual const tensorflow::OpDef* OpDef() const = 0;
virtual Status InputLength(const char* input_name, int* length) = 0;
virtual Status OutputLength(const char* output_name, int* length) = 0;
// Experimental
virtual Status SetUseXla(bool enable) = 0;
protected:
ImmediateExecutionOperation() : AbstractOperation(kKind) {}
~ImmediateExecutionOperation() override {}
};
namespace internal {
struct ImmediateExecutionOperationDeleter {
void operator()(ImmediateExecutionOperation* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using ImmediateOpPtr =
std::unique_ptr<ImmediateExecutionOperation,
internal::ImmediateExecutionOperationDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_

View File

@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
#define TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_
#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
@ -30,15 +31,9 @@ namespace tensorflow {
// files. The interface lists the common functionality that must be provided by
// any concrete implementation. However, in cases where the true concrete class
// is needed, a static_cast can be applied.
class AbstractTensorHandleInterface {
class ImmediateExecutionTensorHandle : public AbstractTensorHandle {
public:
// Release any underlying resources, including the interface object.
//
// WARNING: The destructor of this class is marked as protected to disallow
// clients from directly destroying this object since it may manage it's own
// lifetime through ref counting. Thus this must be allocated on the heap and
// clients MUST call Release() in order to destroy an instance of this class.
virtual void Release() = 0;
static constexpr AbstractTensorHandleKind kKind = kImmediateExecution;
// Returns tensor dtype.
virtual tensorflow::DataType DataType() const = 0;
@ -57,12 +52,27 @@ class AbstractTensorHandleInterface {
virtual AbstractTensorInterface* Resolve(Status* status) = 0;
// Return a copy of the handle.
virtual AbstractTensorHandleInterface* Copy() = 0;
virtual ImmediateExecutionTensorHandle* Copy() = 0;
protected:
virtual ~AbstractTensorHandleInterface() {}
ImmediateExecutionTensorHandle() : AbstractTensorHandle(kKind) {}
~ImmediateExecutionTensorHandle() override {}
};
namespace internal {
struct ImmediateExecutionTensorHandleDeleter {
void operator()(ImmediateExecutionTensorHandle* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using ImmediateTensorHandlePtr =
std::unique_ptr<ImmediateExecutionTensorHandle,
internal::ImmediateExecutionTensorHandleDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_
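A brief usage sketch for the handle side, assuming h is an ImmediateExecutionTensorHandle* owned elsewhere (error handling abbreviated):
// Sketch: inspect a handle and materialize its underlying tensor.
tensorflow::DataType dtype = h->DataType();
tensorflow::Status status;
tensorflow::AbstractTensorInterface* value = h->Resolve(&status);
if (status.ok() && value != nullptr) {
  // ... read the tensor contents through the AbstractTensorInterface API ...
}
// Copy() hands back another handle to the same underlying tensor.
tensorflow::ImmediateTensorHandlePtr copy(h->Copy());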

View File

@ -262,14 +262,14 @@ std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
components.reserve(underlying_devices_.size());
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
int64_t* device_id = new int64_t;
int32_t* device_id = new int32_t;
*device_id = device_index;
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(
TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
sizeof(int64_t),
TF_INT32, /*dims=*/nullptr, /*num_dims=*/0, device_id,
sizeof(int32_t),
[](void* data, size_t, void* arg) {
delete reinterpret_cast<int64_t*>(data);
delete reinterpret_cast<int32_t*>(data);
},
nullptr),
TF_DeleteTensor);
@ -283,7 +283,7 @@ std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT32);
TFE_TensorHandle* device_handle;
int num_outputs = 1;
TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);

View File

@ -296,8 +296,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
TFE_DeleteTensorHandle(result_handle);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<int64_t>(components[0].get(), 0);
ExpectScalarEq<int64_t>(components[1].get(), 1);
ExpectScalarEq<int32_t>(components[0].get(), 0);
ExpectScalarEq<int32_t>(components[1].get(), 1);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);

View File

@ -16,7 +16,7 @@ limitations under the License.
#define TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/context_interface.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
// Wraps a pointer to a context implementation.
//
@ -28,7 +28,7 @@ typedef struct TFE_Context TFE_Context;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractContextInterface, TFE_Context);
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionContext, TFE_Context);
} // namespace tensorflow

View File

@ -16,7 +16,7 @@ limitations under the License.
#define TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
// Wraps a pointer to an operation implementation.
//
@ -28,8 +28,8 @@ typedef struct TFE_Op TFE_Op;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface, TFE_Op);
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface*, TFE_Op*);
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionOperation, TFE_Op);
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionOperation*, TFE_Op*);
} // namespace tensorflow

View File

@ -16,7 +16,7 @@ limitations under the License.
#define TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
// Wraps a pointer to a tensor handle implementation.
//
@ -28,9 +28,9 @@ typedef struct TFE_TensorHandle TFE_TensorHandle;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface,
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionTensorHandle,
TFE_TensorHandle);
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface*,
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionTensorHandle*,
TFE_TensorHandle*);
} // namespace tensorflow

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/types.h"
struct TF_StringStream {
@ -146,6 +147,10 @@ TF_StringStream* TF_GetLocalTempDirectories() {
return list;
}
char* TF_GetTempFileName(const char* extension) {
return strdup(::tensorflow::io::GetTempFilename(extension).c_str());
}
TF_CAPI_EXPORT extern uint64_t TF_NowNanos(void) {
return ::tensorflow::Env::Default()->NowNanos();
}

View File

@ -152,6 +152,10 @@ TF_CAPI_EXPORT extern TF_StringStream* TF_GetChildren(const char* filename,
// The caller is responsible for freeing the list (see TF_StringStreamDone).
TF_CAPI_EXPORT extern TF_StringStream* TF_GetLocalTempDirectories(void);
// Creates a temporary file name with an extension.
// The caller is responsible for freeing the returned pointer.
TF_CAPI_EXPORT extern char* TF_GetTempFileName(const char* extension);
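// Usage sketch for TF_GetTempFileName (illustrative only; the GCS plugin
// passes an empty extension, which this sketch mirrors):
//
//   char* temp_name = TF_GetTempFileName("");
//   /* ... create and use a file at temp_name ... */
//   free(temp_name);  /* the name is strdup()'d, so plain free() releases it */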
// Returns the number of nanoseconds since the Unix epoch.
TF_CAPI_EXPORT extern uint64_t TF_NowNanos(void);

View File

@ -1,5 +1,5 @@
# Experimental gcs filesystem plugin.
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object")
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object", "tf_cc_test")
package(
licenses = ["notice"], # Apache 2.0
@ -19,14 +19,46 @@ tf_cc_shared_object(
cc_library(
name = "gcs_filesystem_impl",
srcs = ["gcs_filesystem.cc"],
hdrs = ["gcs_filesystem.h"],
copts = select({
"//conditions:default": [],
"//tensorflow:windows": get_win_copts(),
}),
deps = [
":gcs_helper",
"//tensorflow/c:env",
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "gcs_helper",
srcs = ["gcs_helper.cc"],
hdrs = ["gcs_helper.h"],
linkstatic = 1,
deps = [
"//tensorflow/c:env",
],
)
tf_cc_test(
name = "gcs_filesystem_test",
srcs = [
"gcs_filesystem.cc",
"gcs_filesystem_test.cc",
],
local_defines = ["TF_GCS_FILESYSTEM_TEST"],
tags = [
"manual",
"notap",
],
deps = [
":gcs_filesystem_impl",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core/platform:stacktrace_handler",
"//tensorflow/core/platform:test",
],
)

View File

@ -12,12 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h"
#include <stdlib.h>
#include <string.h>
#include "absl/strings/string_view.h"
#include "google/cloud/storage/client.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/env.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
#include "tensorflow/c/tf_status.h"
// Implementation of a filesystem for GCS environments.
@ -36,8 +38,8 @@ static inline void TF_SetStatusFromGCSStatus(
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); }
static void ParseGCSPath(absl::string_view fname, bool object_empty_ok,
char** bucket, char** object, TF_Status* status) {
void ParseGCSPath(absl::string_view fname, bool object_empty_ok, char** bucket,
char** object, TF_Status* status) {
size_t scheme_end = fname.find("://") + 2;
if (fname.substr(0, scheme_end + 1) != "gs://") {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
@ -86,6 +88,20 @@ namespace tf_random_access_file {
// SECTION 2. Implementation for `TF_WritableFile`
// ----------------------------------------------------------------------------
namespace tf_writable_file {
typedef struct GCSFile {
const char* bucket;
const char* object;
gcs::Client* gcs_client; // not owned
TempFile outfile;
bool sync_need;
} GCSFile;
static void Cleanup(TF_WritableFile* file) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
plugin_memory_free(const_cast<char*>(gcs_file->bucket));
plugin_memory_free(const_cast<char*>(gcs_file->object));
delete gcs_file;
}
// TODO(vnvo2409): Implement later
@ -104,7 +120,7 @@ namespace tf_read_only_memory_region {
namespace tf_gcs_filesystem {
// TODO(vnvo2409): Add lazy-loading and customizing parameters.
static void Init(TF_Filesystem* filesystem, TF_Status* status) {
void Init(TF_Filesystem* filesystem, TF_Status* status) {
google::cloud::StatusOr<gcs::Client> client =
gcs::Client::CreateDefaultClient();
if (!client) {
@ -117,8 +133,54 @@ static void Init(TF_Filesystem* filesystem, TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
}
void Cleanup(TF_Filesystem* filesystem) {
plugin_memory_free(filesystem->plugin_filesystem);
}
// TODO(vnvo2409): Implement later
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status) {
char* bucket;
char* object;
ParseGCSPath(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_client = static_cast<gcs::Client*>(filesystem->plugin_filesystem);
char* temp_file_name = TF_GetTempFileName("");
file->plugin_file = new tf_writable_file::GCSFile(
{bucket, object, gcs_client,
TempFile(temp_file_name, std::ios::binary | std::ios::out), true});
// We are responsible for freeing the pointer returned by TF_GetTempFileName
free(temp_file_name);
TF_SetStatus(status, TF_OK, "");
}
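// For appendable files, any existing object is first downloaded into the local
// temp file so that new writes are appended after its current contents;
// sync_need is set when the object does not yet exist on the server.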
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status) {
char* bucket;
char* object;
ParseGCSPath(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_client = static_cast<gcs::Client*>(filesystem->plugin_filesystem);
char* temp_file_name = TF_GetTempFileName("");
auto gcs_status = gcs_client->DownloadToFile(bucket, object, temp_file_name);
TF_SetStatusFromGCSStatus(gcs_status, status);
auto status_code = TF_GetCode(status);
if (status_code != TF_OK && status_code != TF_NOT_FOUND) {
return;
}
// If this file does not exist on the server, we will need to sync it.
bool sync_need = (status_code == TF_NOT_FOUND);
file->plugin_file = new tf_writable_file::GCSFile(
{bucket, object, gcs_client,
TempFile(temp_file_name, std::ios::binary | std::ios::app), sync_need});
free(temp_file_name);
TF_SetStatus(status, TF_OK, "");
}
} // namespace tf_gcs_filesystem
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
@ -126,9 +188,17 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
TF_SetFilesystemVersionMetadata(ops);
ops->scheme = strdup(uri);
ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
ops->filesystem_ops->init = tf_gcs_filesystem::Init;
ops->filesystem_ops->cleanup = tf_gcs_filesystem::Cleanup;
ops->filesystem_ops->new_writable_file = tf_gcs_filesystem::NewWritableFile;
ops->filesystem_ops->new_appendable_file =
tf_gcs_filesystem::NewAppendableFile;
}
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {

View File

@ -0,0 +1,35 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_
#include "absl/strings/string_view.h"
#include "google/cloud/storage/client.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h"
void ParseGCSPath(absl::string_view fname, bool object_empty_ok, char** bucket,
char** object, TF_Status* status);
namespace tf_gcs_filesystem {
void Init(TF_Filesystem* filesystem, TF_Status* status);
void Cleanup(TF_Filesystem* filesystem);
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status);
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status);
} // namespace tf_gcs_filesystem
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_
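A hedged usage sketch for ParseGCSPath as declared above (the URI is illustrative; on success the callee allocates the bucket and object strings, which the caller then owns):
// Sketch: split a gs:// URI into its bucket and object components.
char* bucket = nullptr;
char* object = nullptr;
TF_Status* status = TF_NewStatus();
ParseGCSPath("gs://my-bucket/path/to/object", /*object_empty_ok=*/false,
             &bucket, &object, status);
if (TF_GetCode(status) == TF_OK) {
  // Expect bucket == "my-bucket" and object == "path/to/object".
}
TF_DeleteStatus(status);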

View File

@ -0,0 +1,57 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/platform/stacktrace_handler.h"
#include "tensorflow/core/platform/test.h"
#define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x))
namespace tensorflow {
namespace {
class GCSFilesystemTest : public ::testing::Test {
public:
void SetUp() override {
status_ = TF_NewStatus();
filesystem_ = new TF_Filesystem;
tf_gcs_filesystem::Init(filesystem_, status_);
ASSERT_TF_OK(status_) << "Can not initialize filesystem. "
<< TF_Message(status_);
}
void TearDown() override {
TF_DeleteStatus(status_);
tf_gcs_filesystem::Cleanup(filesystem_);
delete filesystem_;
}
protected:
TF_Filesystem* filesystem_;
TF_Status* status_;
};
// We have to add this test here because there must be at least one test.
// This test will be removed in the future.
TEST_F(GCSFilesystemTest, TestInit) { ASSERT_TF_OK(status_); }
} // namespace
} // namespace tensorflow
GTEST_API_ int main(int argc, char** argv) {
tensorflow::testing::InstallStacktraceHandler();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -0,0 +1,34 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
#include <stdio.h>
#include <fstream>
#include <string>
#include <utility>
TempFile::TempFile(const char* temp_file_name, std::ios::openmode mode)
: std::fstream(temp_file_name, mode), name_(temp_file_name) {}
TempFile::TempFile(TempFile&& rhs)
: std::fstream(std::move(rhs)), name_(std::move(rhs.name_)) {}
TempFile::~TempFile() {
std::fstream::close();
std::remove(name_.c_str());
}
const std::string TempFile::getName() const { return name_; }

View File

@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_
#include <fstream>
#include <string>
class TempFile : public std::fstream {
public:
  // The openmode must be specified explicitly each time a TempFile is constructed.
TempFile(const char* temp_file_name, std::ios::openmode mode);
TempFile(TempFile&& rhs);
~TempFile() override;
const std::string getName() const;
private:
const std::string name_;
};
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_
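A minimal usage sketch, not part of this diff (the path below is illustrative only): TempFile behaves like a std::fstream whose backing file is removed when the object goes out of scope.
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
void StageUploadExample() {
  TempFile staging("/tmp/gcs_staging_example", std::ios::out);
  staging << "bytes buffered locally before an upload";
}  // ~TempFile() closes the stream and std::remove()s the backing file.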

View File

@ -3,6 +3,10 @@
# Targets in this directory are pure C++ "Classes" underlying the C API types
# under tf/c/experimental/saved_model/public/. They are subject to change and
# have visibility limited to Tensorflow's implementation only.
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
)
package(
default_visibility = [
@ -23,8 +27,8 @@ cc_library(
],
deps = [
":function_metadata",
"//tensorflow/c/eager:operation_interface",
"//tensorflow/c/eager:tensor_handle_interface",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:protos_all_cc",
],
)
@ -47,6 +51,22 @@ cc_library(
],
)
cc_library(
name = "saved_model_utils",
srcs = [
"saved_model_utils.cc",
],
hdrs = [
"saved_model_utils.h",
],
deps = [
"//tensorflow/c:tf_tensor_internal",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "tf_saved_model_impl",
srcs = [
@ -84,3 +104,26 @@ filegroup(
],
visibility = ["//tensorflow/core:__pkg__"],
)
tf_cc_test(
name = "saved_model_utils_test",
srcs = [
"saved_model_utils_test.cc",
],
deps = [
":saved_model_utils",
"//tensorflow/c:tensor_interface",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime:core_cpu_lib",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:core",
],
)

View File

@ -15,12 +15,12 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
namespace tensorflow {
const std::vector<tensorflow::AbstractTensorHandleInterface*>&
const std::vector<tensorflow::ImmediateExecutionTensorHandle*>&
ConcreteFunction::GetCaptures() const {
return captures_;
}

View File

@ -18,8 +18,8 @@ limitations under the License.
#include <vector>
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/core/framework/function.pb.h"
@ -38,15 +38,15 @@ class ConcreteFunction {
virtual ~ConcreteFunction() = 0;
// This method returns the "Call" Op used to execute the function.
virtual AbstractOperationInterface* GetCallOp() = 0;
virtual ImmediateExecutionOperation* GetCallOp() = 0;
const std::vector<tensorflow::AbstractTensorHandleInterface*>& GetCaptures()
const std::vector<tensorflow::ImmediateExecutionTensorHandle*>& GetCaptures()
const;
const FunctionMetadata& GetFunctionMetadata() const;
private:
FunctionMetadata metadata_;
std::vector<tensorflow::AbstractTensorHandleInterface*> captures_;
std::vector<tensorflow::ImmediateExecutionTensorHandle*> captures_;
FunctionDef* function_;
};

View File

@ -14,44 +14,6 @@ package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "owned_eager_op",
hdrs = [
"owned_eager_op.h",
],
deps = [
"//tensorflow/c/eager:operation_interface",
],
)
cc_library(
name = "owned_tensor_handle",
hdrs = [
"owned_tensor_handle.h",
],
deps = [
"//tensorflow/c/eager:tensor_handle_interface",
"//tensorflow/core/common_runtime/eager:tensor_handle",
],
)
cc_library(
name = "owned_eager_context",
hdrs = ["owned_eager_context.h"],
deps = [
"//tensorflow/c/eager:context_interface",
"//tensorflow/core/common_runtime/eager:context",
],
)
cc_library(
name = "owned_tensor",
hdrs = ["owned_tensor.h"],
deps = [
"//tensorflow/c:tensor_interface",
],
)
cc_library(
name = "variable_ops",
srcs = [
@ -61,10 +23,11 @@ cc_library(
"variable_ops.h",
],
deps = [
":owned_eager_op",
":owned_tensor_handle",
"//tensorflow/c/eager:context_interface",
"//tensorflow/c/eager:tensor_handle_interface",
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@ -78,10 +41,11 @@ tf_cc_test(
"variable_ops_test.cc",
],
deps = [
":owned_eager_context",
":owned_tensor",
":owned_tensor_handle",
":variable_ops",
"//tensorflow/c:tensor_interface",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",

View File

@ -1,54 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_EAGER_CONTEXT_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_EAGER_CONTEXT_H_
#include <memory>
#include "tensorflow/c/eager/context_interface.h"
#include "tensorflow/core/common_runtime/eager/context.h"
namespace tensorflow {
namespace internal {
struct AbstractContextInterfaceDeleter {
void operator()(AbstractContextInterface* p) const {
if (p != nullptr) {
p->Release();
}
}
};
struct EagerContextDeleter {
void operator()(EagerContext* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using AbstractContextPtr =
std::unique_ptr<AbstractContextInterface,
internal::AbstractContextInterfaceDeleter>;
using EagerContextPtr =
std::unique_ptr<EagerContext, internal::EagerContextDeleter>;
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_EAGER_CONTEXT_H_

View File

@ -1,54 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_TENSOR_HANDLE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_TENSOR_HANDLE_H_
#include <memory>
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
namespace tensorflow {
namespace internal {
struct TensorHandleDeleter {
void operator()(TensorHandle* p) const {
if (p != nullptr) {
p->Release();
}
}
};
struct AbstractTensorHandleDeleter {
void operator()(AbstractTensorHandleInterface* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using TensorHandlePtr =
std::unique_ptr<TensorHandle, internal::TensorHandleDeleter>;
using AbstractTensorHandlePtr =
std::unique_ptr<AbstractTensorHandleInterface,
internal::AbstractTensorHandleDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_TENSOR_HANDLE_H_

View File

@ -16,9 +16,11 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/context_interface.h"
#include "tensorflow/c/experimental/saved_model/core/ops/owned_eager_op.h"
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@ -32,10 +34,10 @@ namespace internal {
static const char kNoSharingResourceID[] =
"cd2c89b7-88b7-44c8-ad83-06c2a9158347";
Status CreateUninitializedResourceVariable(AbstractContextInterface* ctx,
Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx,
DataType dtype, TensorShape shape,
AbstractTensorHandlePtr* handle) {
AbstractOpPtr varhandle_op = AbstractOpPtr(ctx->CreateOperation());
ImmediateTensorHandlePtr* handle) {
ImmediateOpPtr varhandle_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(varhandle_op->Reset("VarHandleOp", nullptr));
TF_RETURN_IF_ERROR(varhandle_op->SetAttrType("dtype", dtype));
@ -50,17 +52,57 @@ Status CreateUninitializedResourceVariable(AbstractContextInterface* ctx,
TF_RETURN_IF_ERROR(varhandle_op->SetAttrString(
"shared_name", kNoSharingResourceID, strlen(kNoSharingResourceID)));
AbstractTensorHandleInterface* var_handle = nullptr;
AbstractTensorHandle* var_handle = nullptr;
int num_retvals = 1;
TF_RETURN_IF_ERROR(varhandle_op->Execute(
absl::MakeSpan(&var_handle, num_retvals), &num_retvals));
handle->reset(var_handle);
AbstractTensorHandlePtr owned_var_handle(var_handle);
if (owned_var_handle->getKind() != ImmediateExecutionTensorHandle::kKind) {
return errors::Internal("Unexpected tensor handle kind.");
}
handle->reset(reinterpret_cast<ImmediateExecutionTensorHandle*>(
owned_var_handle.release()));
return Status();
}
Status DestroyResource(AbstractContextInterface* ctx,
AbstractTensorHandleInterface* handle) {
AbstractOpPtr destroy_op = AbstractOpPtr(ctx->CreateOperation());
Status AssignVariable(ImmediateExecutionContext* ctx,
ImmediateExecutionTensorHandle* variable_handle,
DataType dtype, ImmediateExecutionTensorHandle* value) {
ImmediateOpPtr assign_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(assign_op->Reset("AssignVariableOp", nullptr));
TF_RETURN_IF_ERROR(assign_op->SetAttrType("dtype", dtype));
TF_RETURN_IF_ERROR(assign_op->AddInput(variable_handle));
TF_RETURN_IF_ERROR(assign_op->AddInput(value));
int num_retvals = 0;
TF_RETURN_IF_ERROR(assign_op->Execute({}, &num_retvals));
return Status();
}
Status ReadVariable(ImmediateExecutionContext* ctx,
ImmediateExecutionTensorHandle* variable_handle,
DataType dtype, ImmediateTensorHandlePtr* output) {
ImmediateOpPtr read_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(read_op->Reset("ReadVariableOp", nullptr));
TF_RETURN_IF_ERROR(read_op->SetAttrType("dtype", dtype));
TF_RETURN_IF_ERROR(read_op->AddInput(variable_handle));
AbstractTensorHandle* value = nullptr;
int num_retvals = 1;
TF_RETURN_IF_ERROR(
read_op->Execute(absl::MakeSpan(&value, num_retvals), &num_retvals));
AbstractTensorHandlePtr owned_value(value);
if (owned_value->getKind() != ImmediateExecutionTensorHandle::kKind) {
return errors::Internal("Unexpected tensor handle kind.");
}
output->reset(
reinterpret_cast<ImmediateExecutionTensorHandle*>(owned_value.release()));
return Status();
}
Status DestroyResource(ImmediateExecutionContext* ctx,
ImmediateExecutionTensorHandle* handle) {
ImmediateOpPtr destroy_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(destroy_op->Reset("DestroyResourceOp", nullptr));
TF_RETURN_IF_ERROR(destroy_op->SetAttrBool("ignore_lookup_error", true));
TF_RETURN_IF_ERROR(destroy_op->AddInput(handle));

View File

@ -16,9 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H
#include "tensorflow/c/eager/context_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
@ -30,15 +29,31 @@ namespace internal {
// TensorHandle associated with the variable. This is equivalent to creating an
// uninitialized TF2 tf.Variable.
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L1867-L1872
Status CreateUninitializedResourceVariable(AbstractContextInterface* ctx,
Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx,
DataType dtype, TensorShape shape,
AbstractTensorHandlePtr* handle);
ImmediateTensorHandlePtr* handle);
// Executes an AssignVariableOp using `ctx`, assigning the variable associated
// with `variable_handle` with `value`. `dtype` must be the datatype of the
// underlying variable for `variable_handle`. Note that it is illegal to assign
// a variable to a Tensor with a different dtype than what the variable was
// created with.
Status AssignVariable(ImmediateExecutionContext* ctx,
ImmediateExecutionTensorHandle* variable_handle,
DataType dtype, ImmediateExecutionTensorHandle* value);
// Executes a ReadVariableOp using `ctx`. This reads the underlying variable
// value of `variable_handle` and copies the value to `output`. `dtype` must be
// the dtype of the variable associated with `variable_handle`.
Status ReadVariable(ImmediateExecutionContext* ctx,
ImmediateExecutionTensorHandle* variable_handle,
DataType dtype, ImmediateTensorHandlePtr* output);
// Executes DestroyResourceOp on `handle`, using `ctx`. This is equivalent to
// the cleanup that occurs in a tf.Variable's EagerResourceDeleter:
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L289-L290
Status DestroyResource(AbstractContextInterface* ctx,
AbstractTensorHandleInterface* handle);
Status DestroyResource(ImmediateExecutionContext* ctx,
ImmediateExecutionTensorHandle* handle);
} // namespace internal
} // namespace tensorflow
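A hedged sketch of the intended call sequence for the declarations above, mirroring the variable_ops_test.cc changes later in this diff; `ctx` is assumed to be a live ImmediateExecutionContext (e.g. an EagerContext), and namespace qualifiers are omitted for brevity.
Status RoundTripVariableExample(ImmediateExecutionContext* ctx) {
  // Create an uninitialized DT_FLOAT scalar resource variable.
  ImmediateTensorHandlePtr variable;
  TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
      ctx, DT_FLOAT, /*shape=*/{}, &variable));
  // Assign a scalar value of 42 to it.
  AbstractTensorPtr value(ctx->CreateFloatScalar(42.0f));
  ImmediateTensorHandlePtr value_handle(ctx->CreateLocalHandle(value.get()));
  TF_RETURN_IF_ERROR(internal::AssignVariable(ctx, variable.get(), DT_FLOAT,
                                              value_handle.get()));
  // Read the value back, then destroy the underlying resource.
  ImmediateTensorHandlePtr read_back;
  TF_RETURN_IF_ERROR(
      internal::ReadVariable(ctx, variable.get(), DT_FLOAT, &read_back));
  return internal::DestroyResource(ctx, variable.get());
}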

View File

@ -17,9 +17,8 @@ limitations under the License.
#include <memory>
#include "tensorflow/c/experimental/saved_model/core/ops/owned_eager_context.h"
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor.h"
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/tensor.h"
@ -30,6 +29,13 @@ limitations under the License.
namespace tensorflow {
namespace {
ImmediateTensorHandlePtr CreateScalarTensorHandle(EagerContext* context,
float value) {
AbstractTensorPtr tensor(context->CreateFloatScalar(value));
ImmediateTensorHandlePtr handle(context->CreateLocalHandle(tensor.get()));
return handle;
}
class VariableOpsTest : public ::testing::Test {
public:
VariableOpsTest()
@ -55,7 +61,7 @@ class VariableOpsTest : public ::testing::Test {
// Sanity check for variable creation
TEST_F(VariableOpsTest, CreateVariableSuccessful) {
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
AbstractTensorHandlePtr handle;
ImmediateTensorHandlePtr handle;
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
context(), DT_FLOAT, {}, &handle));
// The created TensorHandle should be a DT_Resource
@ -65,7 +71,7 @@ TEST_F(VariableOpsTest, CreateVariableSuccessful) {
// Sanity check for variable destruction
TEST_F(VariableOpsTest, DestroyVariableSuccessful) {
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
AbstractTensorHandlePtr handle;
ImmediateTensorHandlePtr handle;
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
context(), DT_FLOAT, {}, &handle));
@ -73,5 +79,28 @@ TEST_F(VariableOpsTest, DestroyVariableSuccessful) {
TF_EXPECT_OK(internal::DestroyResource(context(), handle.get()));
}
// Sanity check for handle assignment and reading
TEST_F(VariableOpsTest, AssignVariableAndReadSuccessful) {
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
ImmediateTensorHandlePtr variable;
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
context(), DT_FLOAT, {}, &variable));
// Create a Scalar float TensorHandle with value 42, and assign it to
// the variable.
ImmediateTensorHandlePtr my_value = CreateScalarTensorHandle(context(), 42.0);
TF_EXPECT_OK(internal::AssignVariable(context(), variable.get(), DT_FLOAT,
my_value.get()));
// Read back the value from the variable, and check that it is 42.
ImmediateTensorHandlePtr read_value_handle;
TF_EXPECT_OK(internal::ReadVariable(context(), variable.get(), DT_FLOAT,
&read_value_handle));
Status status;
AbstractTensorPtr read_value(read_value_handle->Resolve(&status));
TF_EXPECT_OK(status);
EXPECT_FLOAT_EQ(42.0, *static_cast<float*>(read_value->Data()));
}
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,39 @@
# This package contains classes corresponding to Revived SavedObjectGraph types
# used by SavedModel. See https://cs.opensource.google/tensorflow/tensorflow/+/c575e2ba93c442121d98d3f125d83fed1339924d:tensorflow/core/protobuf/saved_object_graph.proto;l=56-62
package(
default_visibility = [
# Restricting visibility for now
"//tensorflow/c/experimental/saved_model/core:__pkg__",
],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "constant",
srcs = [
"constant.cc",
],
hdrs = [
"constant.h",
],
deps = [
":tensorhandle_convertible",
"//tensorflow/c:tensor_interface",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:tensor_handle",
],
)
cc_library(
name = "tensorhandle_convertible",
hdrs = [
"tensorhandle_convertible.h",
],
deps = [
"//tensorflow/c/eager:immediate_execution_tensor_handle",
],
)

View File

@ -0,0 +1,46 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include <memory>
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
Constant::Constant(ImmediateTensorHandlePtr handle)
: TensorHandleConvertible(std::move(handle)) {}
Status Constant::Create(ImmediateExecutionContext* ctx,
AbstractTensorInterface* tensor,
std::unique_ptr<Constant>* output) {
ImmediateExecutionTensorHandle* handle = ctx->CreateLocalHandle(tensor);
if (handle == nullptr) {
return errors::Internal("Failed to convert tensor to tensorhandle");
}
output->reset(new Constant(ImmediateTensorHandlePtr(handle)));
return Status();
}
} // namespace tensorflow

View File

@ -0,0 +1,55 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_CONSTANT_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_CONSTANT_H_
#include <memory>
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/tensor.pb.h"
namespace tensorflow {
// This class corresponds to python's tf.constant, which is effectively a
// TensorHandle explicitly initialized to some value.
// For now this doesn't do much beyond wrapping Context's CreateLocalHandle
// method and offering a subclass of TensorHandleConvertible. Note that,
// similar to Python's eager-mode logic, we bypass calling the "Const" op:
// https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/framework/constant_op.py#L301
class Constant : public TensorHandleConvertible {
public:
static Status Create(ImmediateExecutionContext* ctx,
AbstractTensorInterface* tensor,
std::unique_ptr<Constant>* output);
// RevivedConstant is movable, but not copyable.
Constant(Constant&& other) = default;
Constant& operator=(Constant&& other) = default;
~Constant() override = default;
private:
explicit Constant(ImmediateTensorHandlePtr handle);
Constant(const Constant&) = delete;
Constant& operator=(const Constant&) = delete;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_CONSTANT_H_
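A rough sketch of using Constant::Create, not part of this diff; TensorInterface comes from tensorflow/c/tf_tensor_internal.h, as used in saved_model_utils.cc below, and namespace qualifiers are omitted.
Status WrapScalarConstantExample(ImmediateExecutionContext* ctx) {
  Tensor tensor(DT_FLOAT, TensorShape({}));
  tensor.scalar<float>()() = 42.0f;
  TensorInterface tensor_interface(std::move(tensor));
  std::unique_ptr<Constant> constant;
  TF_RETURN_IF_ERROR(Constant::Create(ctx, &tensor_interface, &constant));
  // The revived constant exposes its handle via TensorHandleConvertible.
  Status status;
  AbstractTensorPtr resolved(constant->handle()->Resolve(&status));
  return status;
}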

View File

@ -0,0 +1,49 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSORHANDLE_CONVERTIBLE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSORHANDLE_CONVERTIBLE_H_
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
namespace tensorflow {
// A common interface for objects that can be converted to a TensorHandle.
// Examples of objects that implement this include Variables, Constants, Assets,
// etc. This is used to convert captured objects into a ConcreteFunction's
// captured TensorHandles:
// https://github.com/tensorflow/tensorflow/blob/676a68963ea4b64fe479b9cede06aa8f5b290ab8/tensorflow/python/saved_model/load.py#L229-L240
class TensorHandleConvertible {
public:
explicit TensorHandleConvertible(ImmediateTensorHandlePtr handle)
: handle_(std::move(handle)) {}
ImmediateExecutionTensorHandle* handle() { return handle_.get(); }
// TensorHandleConvertible is movable, but not copyable.
TensorHandleConvertible(TensorHandleConvertible&& other) = default;
TensorHandleConvertible& operator=(TensorHandleConvertible&& other) = default;
virtual ~TensorHandleConvertible() = default;
protected:
TensorHandleConvertible(const TensorHandleConvertible&) = delete;
TensorHandleConvertible& operator=(const TensorHandleConvertible&) = delete;
ImmediateTensorHandlePtr handle_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSORHANDLE_CONVERTIBLE_H_
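A hypothetical subclass sketch (the name is invented and not part of this diff): any revived object that owns a tensor handle can expose it by deriving from TensorHandleConvertible, as Constant does elsewhere in this change.
class RevivedAssetExample : public TensorHandleConvertible {
 public:
  explicit RevivedAssetExample(ImmediateTensorHandlePtr handle)
      : TensorHandleConvertible(std::move(handle)) {}
};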

View File

@ -13,30 +13,26 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_EAGER_OP_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_EAGER_OP_H_
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
#include <memory>
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/tf_tensor_internal.h"
namespace tensorflow {
namespace internal {
struct AbstractOperationInterfaceDeleter {
void operator()(AbstractOperationInterface* p) const {
if (p != nullptr) {
p->Release();
}
Status TensorProtoToConstant(ImmediateExecutionContext* ctx,
const TensorProto& proto,
std::unique_ptr<Constant>* output) {
tensorflow::Tensor tensor;
bool parse_result = tensor.FromProto(proto);
if (!parse_result) {
return errors::Internal("Failed to parse tensor from tensorproto");
}
};
TensorInterface tensor_interface(std::move(tensor));
return Constant::Create(ctx, &tensor_interface, output);
}
} // namespace internal
using AbstractOpPtr =
std::unique_ptr<AbstractOperationInterface,
internal::AbstractOperationInterfaceDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_EAGER_OP_H_

View File

@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_
// Some internal utility functions for the SavedModelAPI, factored out into a
// separately unit-testable header.
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/core/framework/tensor.pb.h"
namespace tensorflow {
namespace internal {
// Load a TensorProto into a tensorflow::Constant. This is similar to the
// constant loading logic in python:
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/saved_model/load.py#L437
Status TensorProtoToConstant(ImmediateExecutionContext* ctx,
const TensorProto& proto,
std::unique_ptr<Constant>* output);
} // namespace internal
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_
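A minimal sketch of the helper above, not part of this diff; it mirrors the round trip exercised in saved_model_utils_test.cc later in this change (namespace qualifiers omitted).
Status ReviveScalarConstantExample(ImmediateExecutionContext* ctx,
                                   std::unique_ptr<Constant>* out) {
  Tensor tensor(DT_FLOAT, TensorShape({}));
  tensor.scalar<float>()() = 42.0f;
  TensorProto proto;
  tensor.AsProtoTensorContent(&proto);
  return internal::TensorProtoToConstant(ctx, proto, out);
}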

View File

@ -0,0 +1,199 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
#include <string.h>
#include <memory>
#include <vector>
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace {
// Converts a tensorflow::DatatypeSet to std::vector<DataType>.
// This is needed for GTest's ::testing::ValuesIn, since
// DataTypeSet doesn't fulfill all the constraints of an STL-like iterable.
std::vector<DataType> DataTypeSetToVector(DataTypeSet set) {
std::vector<DataType> result;
result.reserve(set.size());
for (DataType dt : set) {
result.push_back(dt);
}
return result;
}
// Returns a vector of shapes intended to be "interesting" test cases.
std::vector<std::vector<int64>> InterestingShapes() {
std::vector<std::vector<int64>> interesting_shapes;
interesting_shapes.push_back({}); // Scalar
interesting_shapes.push_back({10}); // 1D Vector
interesting_shapes.push_back({3, 3}); // 2D Matrix
interesting_shapes.push_back({1, 4, 6, 10}); // Higher Dimension Tensor
return interesting_shapes;
}
// Fills a numeric tensor with `value`.
void FillNumericTensor(Tensor* tensor, int8 value) {
switch (tensor->dtype()) {
#define CASE(type) \
case DataTypeToEnum<type>::value: { \
const auto& flattened = tensor->flat<type>(); \
for (int i = 0; i < tensor->NumElements(); ++i) { \
flattened(i) = value; \
} \
break; \
}
TF_CALL_INTEGRAL_TYPES(CASE);
TF_CALL_double(CASE);
TF_CALL_float(CASE);
#undef CASE
default:
CHECK(false) << "Unsupported data type: "
<< DataTypeString(tensor->dtype());
break;
}
}
// Checks that the underlying buffer data is equal for two numeric tensors.
// Note: the caller must ensure that the dtypes and sizes of the underlying
// buffers are the same before calling this.
void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a,
void* b) {
switch (dtype) {
#define CASE(type) \
case DataTypeToEnum<type>::value: { \
type* typed_a = static_cast<type*>(a); \
type* typed_b = static_cast<type*>(b); \
for (int64 i = 0; i < num_elements; ++i) { \
if (DataTypeIsFloating(dtype)) { \
EXPECT_FLOAT_EQ(typed_a[i], typed_b[i]); \
} else { \
EXPECT_EQ(typed_a[i], typed_b[i]); \
} \
} \
break; \
}
TF_CALL_INTEGRAL_TYPES(CASE);
TF_CALL_double(CASE);
TF_CALL_float(CASE);
#undef CASE
default:
CHECK(false) << "Unsupported data type: " << DataTypeString(dtype);
}
}
class ConstantTest : public ::testing::TestWithParam<
std::tuple<DataType, std::vector<int64>, bool>> {
public:
ConstantTest()
: device_mgr_(std::make_unique<StaticDeviceMgr>(DeviceFactory::NewDevice(
"CPU", {}, "/job:localhost/replica:0/task:0"))),
ctx_(new EagerContext(
SessionOptions(),
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
tensorflow::ContextMirroringPolicy::MIRRORING_NONE,
/* async= */ false,
/* lazy_copy_function_remote_inputs= */ false, device_mgr_.get(),
/* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
/* custom_kernel_creator= */ nullptr,
/* cluster_flr= */ nullptr)) {}
EagerContext* context() { return ctx_.get(); }
private:
std::unique_ptr<StaticDeviceMgr> device_mgr_;
EagerContextPtr ctx_;
};
// Basic sanity check that roundtripping a Tensor->Tensorproto->Constant
// preserves values.
TEST_P(ConstantTest, CreateConstantSuccessful) {
// Get test parameters
auto& test_params = GetParam();
DataType dtype = std::get<0>(test_params);
TensorShape shape(std::get<1>(test_params));
bool tensorproto_use_tensor_content = std::get<2>(test_params);
// Construct a Tensor with the given dtype + shape
Tensor expected(dtype, shape);
FillNumericTensor(&expected, 42);
// Serialize it to a Tensorproto
TensorProto proto;
if (tensorproto_use_tensor_content) {
expected.AsProtoTensorContent(&proto);
} else {
expected.AsProtoField(&proto);
}
// Revival should succeed w/o errors
std::unique_ptr<Constant> revived;
TF_EXPECT_OK(internal::TensorProtoToConstant(context(), proto, &revived));
// The revived tensorhandle should have the exact same dtype, shape, +
// approx equivalent data to the original.
ImmediateExecutionTensorHandle* handle = revived->handle();
Status status;
AbstractTensorPtr revived_tensor(handle->Resolve(&status));
TF_EXPECT_OK(status) << "Failed to convert tensorhandle to tensor";
EXPECT_EQ(revived_tensor->Type(), expected.dtype());
EXPECT_EQ(revived_tensor->NumElements(), expected.NumElements());
EXPECT_EQ(revived_tensor->NumDims(), expected.dims());
for (int i = 0; i < expected.dims(); ++i) {
EXPECT_EQ(revived_tensor->Dim(i), expected.dim_size(i));
}
CheckBufferDataIsEqual(expected.dtype(), expected.NumElements(),
revived_tensor->Data(), expected.data());
}
// Test against combinations of tensors that vary along:
// 1. dtype
// 2. shape
// 3. TensorProto serialization: tensor_content vs. repeated typed fields
INSTANTIATE_TEST_SUITE_P(
ConstantIntegerDtypesTest, ConstantTest,
::testing::Combine(
::testing::ValuesIn(DataTypeSetToVector(kDataTypeIsInteger)),
::testing::ValuesIn(InterestingShapes()),
::testing::Values(false, true)));
INSTANTIATE_TEST_SUITE_P(
ConstantFloatingDtypesTest, ConstantTest,
::testing::Combine(::testing::Values(DT_FLOAT, DT_DOUBLE),
::testing::ValuesIn(InterestingShapes()),
::testing::Values(false, true)));
} // namespace
} // namespace tensorflow

View File

@ -178,7 +178,7 @@ cc_library(
":tensorhandle_list_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:tensor_handle_interface",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/eager:tfe_tensorhandle_internal",
],
)
@ -190,7 +190,7 @@ cc_library(
],
deps = [
"//tensorflow/c:conversion_macros",
"//tensorflow/c/eager:tensor_handle_interface",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
],
)

View File

@ -17,7 +17,7 @@ limitations under the License.
#include <stddef.h>
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h"

View File

@ -19,7 +19,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
// Internal structures used by the SavedModel C API. These are likely to
// change and should not be depended on.
@ -29,7 +29,7 @@ typedef struct TF_TensorHandleList TF_TensorHandleList;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(
std::vector<tensorflow::AbstractTensorHandleInterface*>,
std::vector<tensorflow::ImmediateExecutionTensorHandle*>,
TF_TensorHandleList)
} // namespace tensorflow

View File

@ -54,6 +54,20 @@ class AbstractTensorInterface {
virtual ~AbstractTensorInterface() {}
};
namespace internal {
struct AbstractTensorInterfaceDeleter {
void operator()(AbstractTensorInterface* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using AbstractTensorPtr =
std::unique_ptr<AbstractTensorInterface,
internal::AbstractTensorInterfaceDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_TENSOR_INTERFACE_H_
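A short sketch, not part of this diff: the custom deleter calls Release() rather than delete, matching the reference-counting contract of the C API objects; `ctx` is assumed to be an ImmediateExecutionContext, as in the tests above.
void MakeScalarTensorExample(tensorflow::ImmediateExecutionContext* ctx) {
  tensorflow::AbstractTensorPtr scalar(ctx->CreateFloatScalar(3.14f));
  // Release() runs automatically when `scalar` goes out of scope.
}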

View File

@ -259,9 +259,6 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) {
RunTest(x, x_init_value, y, y_shape);
}
// TODO(rocm):
// Re-enable this test once 3D pooling is supported on ROCm platform
#ifndef TENSORFLOW_USE_ROCM
TEST_F(NNGradTest, MaxPool3DGradHelper) {
TensorShape x_shape({1, 3, 3, 3, 1});
TensorShape y_shape({1, 1, 1, 1, 1});
@ -274,7 +271,6 @@ TEST_F(NNGradTest, MaxPool3DGradHelper) {
SetRandomValuesForMaxPooling<float>(&x_init_value);
RunTest(x, x_init_value, y, y_shape);
}
#endif
TEST_F(NNGradTest, AvgPoolGradHelper) {
TensorShape x_shape({1, 2, 2, 1});
@ -287,9 +283,6 @@ TEST_F(NNGradTest, AvgPoolGradHelper) {
RunTest(x, x_shape, y, y_shape);
}
// TODO(rocm):
// Re-enable this test once 3D pooling is supported on ROCm platform
#ifndef TENSORFLOW_USE_ROCM
TEST_F(NNGradTest, AvgPool3DGradHelper) {
TensorShape x_shape({1, 3, 3, 3, 1});
TensorShape y_shape({1, 1, 1, 1, 1});
@ -300,7 +293,6 @@ TEST_F(NNGradTest, AvgPool3DGradHelper) {
auto y = AvgPool3D(scope_, x, ksize, strides, "SAME");
RunTest(x, x_shape, y, y_shape);
}
#endif
TEST_F(NNGradTest, LRN) {
TensorShape x_shape({1, 1, 2, 1});

View File

@ -69,6 +69,7 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"@llvm-project//llvm:ARMCodeGen", # fixdeps: keep
"@llvm-project//llvm:PowerPCCodeGen", # fixdeps: keep
"@llvm-project//llvm:Support",
"@llvm-project//llvm:Target",
"@llvm-project//llvm:X86CodeGen", # fixdeps: keep
"//tensorflow/core:regexp_internal",

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "absl/base/call_once.h"
#include "llvm-c/Target.h"
#include "llvm/Support/ManagedStatic.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/aot/quantize.h"

View File

@ -108,8 +108,7 @@ class XlaExecutableClosure {
explicit XlaExecutableClosure(
xla::LocalClient* client, xla::LocalExecutable* executable,
const XlaCompiler::CompilationResult* compilation_result,
std::map<int, OptionalTensor> resource_var_snapshots,
int num_constant_args)
ResourceVarsSnapshot resource_var_snapshots, int num_constant_args)
: client_(client),
executable_(executable),
compilation_result_(compilation_result),
@ -124,7 +123,7 @@ class XlaExecutableClosure {
const XlaCompiler::CompilationResult* compilation_result() const {
return compilation_result_;
}
const std::map<int, OptionalTensor>& resource_var_snapshots() const {
const ResourceVarsSnapshot& resource_var_snapshots() const {
return resource_var_snapshots_;
}
int num_constant_args() const { return num_constant_args_; }
@ -133,7 +132,7 @@ class XlaExecutableClosure {
xla::LocalClient* client_;
xla::LocalExecutable* executable_;
const XlaCompiler::CompilationResult* compilation_result_;
std::map<int, OptionalTensor> resource_var_snapshots_;
ResourceVarsSnapshot resource_var_snapshots_;
int num_constant_args_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosure);
@ -276,10 +275,10 @@ static Status BuildCompilationCache(OpKernelContext* ctx,
static Status CompileToLocalExecutable(
OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars,
const XlaPlatformInfo& platform_info, absl::Span<const int> resources,
const XlaPlatformInfo& platform_info,
absl::Span<VariableInfo const> variable_infos,
absl::Span<const int> constants, bool lazy, xla::LocalClient** client,
std::map<int, OptionalTensor>* variables,
const XlaCompiler::CompilationResult** kernel,
const XlaCompiler::CompilationResult** compilation_result,
xla::LocalExecutable** executable) {
// We store information about the JIT-compiled XLA computation
// in the ResourceMgr.
@ -299,7 +298,6 @@ static Status CompileToLocalExecutable(
// this is more obviously correct.)
core::ScopedUnref cache_ref(cache);
TF_RETURN_IF_ERROR(SnapshotResourceVariables(ctx, resources, variables));
*client = static_cast<xla::LocalClient*>(cache->client());
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
@ -337,11 +335,11 @@ static Status CompileToLocalExecutable(
std::vector<XlaCompiler::Argument> args;
TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
constant_args, *variables, ctx, &args));
constant_args, variable_infos, ctx, &args));
return cache->Compile(options, function, args, compile_options,
lazy ? XlaCompilationCache::CompileMode::kLazy
: XlaCompilationCache::CompileMode::kStrict,
kernel, executable);
compilation_result, executable);
}
void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
@ -349,16 +347,22 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
<< Canonicalize(function_.name(), AttrSlice(&function_.attr()));
xla::LocalClient* client;
const XlaCompiler::CompilationResult* kernel;
const XlaCompiler::CompilationResult* compilation_result;
xla::LocalExecutable* executable;
std::map<int, OptionalTensor> variables;
ResourceVarsSnapshot variables;
{
std::vector<VariableInfo> variable_infos;
OP_REQUIRES_OK(
ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos));
OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
Status s = CompileToLocalExecutable(
ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_,
resources_, constants_, /*lazy=*/false, &client, &variables, &kernel,
&executable);
variable_infos, constants_, /*lazy=*/false, &client,
&compilation_result, &executable);
OP_REQUIRES_OK(ctx, s);
OP_REQUIRES_OK(ctx, SnapshotResourceVariables(ctx, resources_,
variable_infos, &variables));
}
se::Stream* stream =
@ -373,7 +377,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
client, allocator,
/*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
platform_info_.UseMultipleStreams());
launch_context.PopulateInputs(ctx, kernel, variables,
launch_context.PopulateInputs(ctx, compilation_result, variables,
/*missing_ctx_input_prefix=*/0);
// Execute the computation.
@ -413,7 +417,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
executable->executable()->module().input_output_alias_config();
OP_REQUIRES_OK(
ctx, launch_context.PopulateOutputs(
ctx, kernel, run_result.ConsumeValueOrDie(),
ctx, compilation_result, run_result.ConsumeValueOrDie(),
/*missing_ctx_input_prefix=*/0, input_output_alias, variables));
VLOG(1) << "Done";
}
@ -494,7 +498,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
xla::LocalClient* client;
const XlaCompiler::CompilationResult* kernel;
xla::LocalExecutable* executable;
std::map<int, OptionalTensor> variables;
ResourceVarsSnapshot variables;
bool cannot_compile_cluster;
{
@ -506,9 +510,16 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
cannot_compile_cluster) {
executable = nullptr;
} else {
std::vector<VariableInfo> variable_infos;
OP_REQUIRES_OK(
ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos));
OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
Status status = CompileToLocalExecutable(
ctx, function_, has_ref_vars_, platform_info_, resources_, constants_,
/*lazy=*/!must_compile_, &client, &variables, &kernel, &executable);
ctx, function_, has_ref_vars_, platform_info_, variable_infos,
constants_,
/*lazy=*/!must_compile_, &client, &kernel, &executable);
OP_REQUIRES_OK(ctx, SnapshotResourceVariables(ctx, resources_,
variable_infos, &variables));
if (must_compile_ || status.code() != error::UNIMPLEMENTED) {
OP_REQUIRES_OK(ctx, status);
}

View File

@ -1837,7 +1837,7 @@ absl::flat_hash_map<string, std::vector<string>>* GetWhitelistTable() {
"ConcatOffset", "Const", "MirrorPad", "Pack", "Pad", "PadV2", "Reverse",
"ReverseV2", "ReverseSequence", "Slice", "Split", "SplitV",
"StridedSlice", "StridedSliceGrad", "ResourceStridedSliceAssign",
"Tile", "Transpose", "InvertPermutation", "Unpack"}}};
"Tile", "Transpose", "InvertPermutation", "Unpack", "DeviceIndex"}}};
// clang-format on
return result;
}

View File

@ -28,32 +28,23 @@ limitations under the License.
namespace tensorflow {
namespace {
std::map<int, OptionalTensor> GetVariables(OpKernelContext* ctx) {
std::map<int, OptionalTensor> variables;
for (int64 i = 0; i < ctx->num_inputs(); ++i) {
// Returns argument indices corresponding to the resource variable inputs of
// kernel context `ctx`.
static std::vector<int> GetResourceVariableIndices(OpKernelContext* ctx) {
std::vector<int> out;
for (int64 i = 0; i < ctx->num_inputs(); i++) {
if (ctx->input(i).dtype() == DT_RESOURCE) {
core::RefCountPtr<Var> variable;
ResourceHandle handle = HandleFromInput(ctx, i);
OptionalTensor& optional = variables[i];
optional.name = handle.name();
if (LookupResource(ctx, handle, &variable).ok()) {
tf_shared_lock lock(*variable->mu());
optional.present = true;
optional.value = *variable->tensor();
}
out.push_back(i);
}
}
return variables;
return out;
}
} // namespace
Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
const XlaDevice::Metadata& metadata,
const XlaCompiler::CompilationResult* result,
xla::LocalExecutable* executable) {
std::map<int, OptionalTensor> variables = GetVariables(ctx);
xla::LocalExecutable* executable,
const ResourceVarsSnapshot& variable_args) {
xla::LocalClient* client = metadata.client();
// Builds an XLA allocator for the device.
@ -62,7 +53,7 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
/*allocate_xla_tensors=*/true,
/*use_multiple_streams=*/metadata.UseMultipleStreams());
launch_context.PopulateInputs(ctx, result, variables,
launch_context.PopulateInputs(ctx, result, variable_args,
/*missing_ctx_input_prefix=*/0);
se::Stream* stream =
@ -87,7 +78,7 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
executable->executable()->module().input_output_alias_config();
TF_RETURN_IF_ERROR(launch_context.PopulateOutputs(
ctx, result, run_result.ConsumeValueOrDie(),
/*missing_ctx_input_prefix=*/0, input_output_alias, variables));
/*missing_ctx_input_prefix=*/0, input_output_alias, variable_args));
return Status::OK();
}
@ -115,7 +106,7 @@ Status XlaCompileOnDemandOp::ShouldArgumentBeConstant(
Status XlaCompileOnDemandOp::Compile(
OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
const XlaCompiler::CompilationResult** result,
xla::LocalExecutable** executable) {
ResourceVarsSnapshot* variable_args, xla::LocalExecutable** executable) {
std::map<int, Tensor> constant_arguments;
for (int64 i = 0; i < ctx->num_inputs(); ++i) {
const Tensor& device_tensor = ctx->input(i);
@ -190,12 +181,18 @@ Status XlaCompileOnDemandOp::Compile(
// rather than a one-element tuple.
compile_options.always_return_tuple = false;
std::map<int, OptionalTensor> variable_args = GetVariables(ctx);
std::vector<int> variables_indices = GetResourceVariableIndices(ctx);
std::vector<XlaCompiler::Argument> args;
TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
constant_arguments, variable_args, ctx, &args));
{
std::vector<VariableInfo> variable_infos;
TF_RETURN_IF_ERROR(
GetVariableInfosFromCtxInputs(ctx, variables_indices, &variable_infos));
TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
TF_RETURN_IF_ERROR(SnapshotResourceVariables(
ctx, variables_indices, variable_infos, variable_args));
TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
constant_arguments, variable_infos, ctx, &args));
}
return cache->CompileSingleOp(options, args, ctx, compile_options, result,
executable);
@ -206,8 +203,10 @@ void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) {
xla::LocalExecutable* executable;
const XlaDevice::Metadata* metadata;
OP_REQUIRES_OK(ctx, XlaDevice::GetMetadata(ctx, &metadata));
OP_REQUIRES_OK(ctx, Compile(ctx, *metadata, &result, &executable));
OP_REQUIRES_OK(ctx, Run(ctx, *metadata, result, executable));
ResourceVarsSnapshot variable_args;
OP_REQUIRES_OK(ctx,
Compile(ctx, *metadata, &result, &variable_args, &executable));
OP_REQUIRES_OK(ctx, Run(ctx, *metadata, result, executable, variable_args));
}
} // namespace tensorflow

View File

@ -20,6 +20,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_JIT_XLA_COMPILE_ON_DEMAND_OP_H_
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/framework/function.h"
@ -47,10 +48,12 @@ class XlaCompileOnDemandOp : public OpKernel {
bool* result);
Status Compile(OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
const XlaCompiler::CompilationResult** result,
ResourceVarsSnapshot* variable_args,
xla::LocalExecutable** executable);
Status Run(OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
const XlaCompiler::CompilationResult* result,
xla::LocalExecutable* executable);
xla::LocalExecutable* executable,
const ResourceVarsSnapshot& variable_args);
};
} // namespace tensorflow
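A hedged sketch of the variable-snapshot sequence the kernels in this diff now share (get variable infos, lock them, then snapshot); `ctx` and `resource_indices` are assumed to come from the enclosing OpKernel, as in XlaLocalLaunchBase::Compute, and namespace qualifiers are omitted.
Status SnapshotForCompileExample(OpKernelContext* ctx,
                                 absl::Span<const int> resource_indices,
                                 ResourceVarsSnapshot* snapshot) {
  std::vector<VariableInfo> variable_infos;
  TF_RETURN_IF_ERROR(
      GetVariableInfosFromCtxInputs(ctx, resource_indices, &variable_infos));
  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
  return SnapshotResourceVariables(ctx, resource_indices, variable_infos,
                                   snapshot);
}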

View File

@ -52,7 +52,8 @@ const char kPossibleNonVariableResourceHintMessage[] =
"resource inputs to XLA.";
} // anonymous namespace
VariableInfo::VariableInfo(int index, Var* var) : index_(index), var_(var) {}
VariableInfo::VariableInfo(int index, absl::string_view name, Var* var)
: index_(index), name_(name), var_(var) {}
VariableInfo::VariableInfo(VariableInfo&& other)
: index_(other.index_), var_(other.var_), lock_held_(other.lock_held_) {
other.index_ = -1;
@ -87,16 +88,15 @@ VariableInfo::~VariableInfo() {
// Returns a vector of VariableInfo instances for the resource variable inputs
// to the kernel with context `ctx`. The input indices for the resource
// variable inputs are in `variable_indices`.
static Status GetVariableInfosFromCtxInputs(
OpKernelContext* ctx, absl::Span<const int> variable_indices,
std::vector<VariableInfo>* result) {
Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx,
absl::Span<const int> variable_indices,
std::vector<VariableInfo>* result) {
std::vector<const ResourceHandle*> resource_handles;
absl::c_transform(
variable_indices, std::back_inserter(resource_handles),
[&](int variable_idx) { return &HandleFromInput(ctx, variable_idx); });
std::vector<core::RefCountPtr<Var>> variables;
Status s = LookupResources(ctx, resource_handles, &variables);
if (!s.ok()) {
errors::AppendToMessage(&s, kPossibleNonVariableResourceHintMessage);
@ -109,7 +109,9 @@ static Status GetVariableInfosFromCtxInputs(
// *Release* the variable because we're going to unref it later in
// ~VariableInfo.
Var* variable = variables[i].release();
result->emplace_back(variable_indices[i], variable);
int input_idx = variable_indices[i];
std::string var_name = HandleFromInput(ctx, input_idx).name();
result->emplace_back(input_idx, var_name, variable);
}
return Status::OK();
@ -162,21 +164,12 @@ Status LockVariables(absl::Span<VariableInfo> variables) {
Status SnapshotResourceVariables(OpKernelContext* ctx,
absl::Span<const int> variable_indices,
std::map<int, OptionalTensor>* result) {
std::vector<VariableInfo> variable_infos;
TF_RETURN_IF_ERROR(
GetVariableInfosFromCtxInputs(ctx, variable_indices, &variable_infos));
TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
absl::Span<VariableInfo const> variable_infos,
ResourceVarsSnapshot* result) {
for (int i = 0; i < variable_indices.size(); i++) {
if (variable_infos[i].var()) {
OptionalTensor& tensor = (*result)[variable_indices[i]];
tensor.name = HandleFromInput(ctx, variable_indices[i]).name();
tensor.present = true;
tensor.value = *variable_infos[i].var()->tensor();
} else {
(*result)[variable_indices[i]] = OptionalTensor();
}
Var* var = variable_infos[i].var();
(*result)[variable_indices[i]] =
var ? absl::make_optional(*var->tensor()) : absl::nullopt;
}
return Status::OK();
}
@ -197,8 +190,7 @@ XlaComputationLaunchContext::XlaComputationLaunchContext(
void XlaComputationLaunchContext::PopulateInputs(
OpKernelContext* ctx,
const XlaCompiler::CompilationResult* compilation_result,
const std::map<int, OptionalTensor>& variables,
int missing_ctx_input_prefix) {
const ResourceVarsSnapshot& variables, int missing_ctx_input_prefix) {
// Build ShapedBuffers that point directly to the Tensor buffers.
arg_ptrs_ =
std::vector<ShapedBuffer*>(compilation_result->xla_input_shapes.size());
@ -210,7 +202,7 @@ void XlaComputationLaunchContext::PopulateInputs(
CHECK_GE(arg_num, missing_ctx_input_prefix);
const xla::Shape& shape = compilation_result->xla_input_shapes[i];
const Tensor* t = variables.count(arg_num)
? &(variables.at(arg_num).value)
? &(variables.at(arg_num).value())
: &(ctx->input(arg_num - missing_ctx_input_prefix));
CHECK(t);
@ -262,7 +254,7 @@ static const Tensor* FindAliasedTensorForOutput(
int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix,
const xla::HloInputOutputAliasConfig& input_output_alias,
absl::Span<const int> input_mapping,
const std::map<int, OptionalTensor>& resource_var_snapshots) {
const ResourceVarsSnapshot& resource_var_snapshots) {
if (MustAliasOutput(input_output_alias, output_num)) {
int xla_param = input_output_alias.GetAliasedParameter({output_num})
.value()
@ -274,8 +266,8 @@ static const Tensor* FindAliasedTensorForOutput(
// entry time.
if (input_tensor->dtype() == DT_RESOURCE) {
auto& v = resource_var_snapshots.at(missing_ctx_input_prefix + tf_param);
CHECK(v.present);
return &v.value;
CHECK(v.has_value());
return &v.value();
}
return input_tensor;
}
@ -298,9 +290,9 @@ static Tensor GetOrCreateTensorForOutput(
int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix,
const xla::HloInputOutputAliasConfig& input_output_alias,
absl::Span<const int> input_mapping,
const std::map<int, OptionalTensor>& resource_var_snapshots,
DataType output_dtype, const TensorShape& output_shape,
se::DeviceMemoryBase output_buffer, Allocator* output_allocator) {
const ResourceVarsSnapshot& resource_var_snapshots, DataType output_dtype,
const TensorShape& output_shape, se::DeviceMemoryBase output_buffer,
Allocator* output_allocator) {
if (const Tensor* aliased_tensor = FindAliasedTensorForOutput(
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
input_mapping, resource_var_snapshots)) {
@ -431,13 +423,13 @@ static xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
// TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
// not a Tensor.
Var* variable = nullptr;
TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
ctx, HandleFromInput(ctx, actual_input_index), &variable,
[&write](Var** ptr) {
*ptr = new Var(write.type);
return Status::OK();
}));
variable_infos.emplace_back(actual_input_index, variable);
const ResourceHandle handle = HandleFromInput(ctx, actual_input_index);
TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(ctx, handle, &variable,
[&write](Var** ptr) {
*ptr = new Var(write.type);
return Status::OK();
}));
variable_infos.emplace_back(actual_input_index, handle.name(), variable);
}
return variable_infos;
}
@ -447,7 +439,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
const XlaCompiler::CompilationResult* compilation_result,
ScopedShapedBuffer output, int missing_ctx_input_prefix,
const xla::HloInputOutputAliasConfig& input_output_alias,
const std::map<int, OptionalTensor>& resource_var_snapshots) {
const ResourceVarsSnapshot& resource_var_snapshots) {
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
Allocator* allocator = ctx->device()->GetAllocator({});
@ -484,10 +476,36 @@ Status XlaComputationLaunchContext::PopulateOutputs(
stream->ThenRecordEvent(definition_event.get());
}
std::vector<TensorShape> output_tensor_shapes;
output_tensor_shapes.reserve(ctx->num_outputs());
if (output.on_host_shape().is_dynamic()) {
TF_ASSIGN_OR_RETURN(
auto transfer_manager,
xla::TransferManager::GetForPlatform(stream->parent()->platform()));
xla::Shape output_host_shape = output.on_host_shape();
xla::Shape output_device_shape = output.on_device_shape();
TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
stream, &output, &output_host_shape, &output_device_shape));
output.set_shapes(output_host_shape, output_device_shape);
for (int i = 0; i < ctx->num_outputs(); ++i) {
const xla::Shape& subshape =
xla::ShapeUtil::GetSubshape(output_host_shape, {i});
TensorShape shape;
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
output_tensor_shapes.push_back(shape);
}
} else {
for (int i = 0; i < ctx->num_outputs(); ++i) {
output_tensor_shapes.push_back(compilation_result->outputs[i].shape);
}
}
// Copy XLA results to the OpOutputList.
int output_num = 0;
for (int i = 0; i < ctx->num_outputs(); ++i) {
const TensorShape& shape = compilation_result->outputs[i].shape;
const TensorShape& shape = output_tensor_shapes[i];
const DataType& type = compilation_result->outputs[i].type;
VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
<< DataTypeString(type);
@ -564,12 +582,21 @@ Status XlaComputationLaunchContext::PopulateOutputs(
Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
const std::map<int, Tensor>& constant_args,
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
absl::Span<VariableInfo const> variable_args, OpKernelContext* ctx,
std::vector<XlaCompiler::Argument>* args) {
args->resize(ctx->num_inputs());
absl::flat_hash_map<int, const VariableInfo*> variable_info_lookup;
for (const VariableInfo& info : variable_args) {
CHECK(!info.var() || info.lock_held())
<< "Need to hold the lock on resource variables "
"before calling BuildXlaCompilerArguments";
variable_info_lookup.emplace(info.index(), &info);
}
for (int64 input_num = 0; input_num < ctx->num_inputs(); ++input_num) {
XlaCompiler::Argument& arg = (*args)[input_num];
if (constant_args.count(input_num) > 0) {
// Handles compile-time constants.
const Tensor& input = constant_args.at(input_num);
@ -578,7 +605,7 @@ Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
arg.type = input.dtype();
arg.shape = input.shape();
arg.constant_value = input;
} else if (variable_args.count(input_num) == 0) {
} else if (variable_info_lookup.count(input_num) == 0) {
// Handles the non-constant arguments.
const Tensor& input = ctx->input(input_num);
TF_RET_CHECK(input.dtype() != DT_RESOURCE);
@ -594,14 +621,14 @@ Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
// Handles resource variables.
const Tensor& input = ctx->input(input_num);
TF_RET_CHECK(input.dtype() == DT_RESOURCE);
const OptionalTensor& variable = variable_args.at(input_num);
arg.name = variable.name;
const VariableInfo& variable = *variable_info_lookup[input_num];
arg.name = std::string(variable.name());
arg.kind = XlaCompiler::Argument::kResource;
arg.resource_kind = XlaResource::kVariable;
if (variable.present) {
const Tensor& value = variable.value;
arg.type = value.dtype();
arg.shape = value.shape();
if (variable.var()) {
const Tensor* value = variable.var()->tensor();
arg.type = value->dtype();
arg.shape = value->shape();
arg.initialized = true;
} else {
// The values of uninitialized variables are not passed as inputs, since

View File

@ -34,36 +34,17 @@ limitations under the License.
namespace tensorflow {
// Struct that represents a possibly-absent Tensor.
struct OptionalTensor {
string name; // A descriptive name
bool present = false; // Is the tensor present?
Tensor value; // If present, what is the Tensor's value?
};
// Takes a snapshot of the values of resource variable arguments, whose indices
// are specified in `variable_indices` argument. We snapshot tensors that back
// resource variables since concurrent updates may modify the shape, and it is
// important that the shapes used for compilation match the true shapes of the
// buffers.
//
// We snapshot the entire set of resource variables as one atomic operation.
// This models Read->* dependencies between resource variable operations. See
// jit/resource_operation_safety_analysis for details.
//
// Returns a map of TensorFlow argument index to resource variable. If a
// resource variable is not initialized, the corresponding OptionalTensor
// will have its `present` field set to false.
Status SnapshotResourceVariables(OpKernelContext* ctx,
absl::Span<const int> variable_indices,
std::map<int, OptionalTensor>* result);
// Snapshot of resource variables for a TF kernel invocation, mapping from
// parameter number to values at execution time. If the resource variable is not
// initialized, the value will not be present.
using ResourceVarsSnapshot = absl::flat_hash_map<int, absl::optional<Tensor>>;
// Information about the state of a variable passed as input to the _XlaCompile
// and _XlaRun operators. Unlocks the resource variable and decrements its
// refcount on destruction.
class VariableInfo {
public:
explicit VariableInfo(int index, Var* var);
explicit VariableInfo(int index, absl::string_view name, Var* var);
VariableInfo(VariableInfo&& other);
VariableInfo& operator=(VariableInfo&& other);
@ -79,6 +60,9 @@ class VariableInfo {
// "empty", i.e. it does not track a resource variable.
Var* var() const { return var_; }
// Returns the variable name.
absl::string_view name() const { return name_; }
// Returns true if the resource variable lock was successfully acquired by
// this thread.
bool lock_held() const { return lock_held_; }
@ -88,6 +72,7 @@ class VariableInfo {
private:
int index_;
std::string name_;
Var* var_;
// We can't use an optional<mutex_lock> here because it confuses the compiler's
@ -96,6 +81,20 @@ class VariableInfo {
bool lock_held_ = false;
};
// Takes a snapshot of the values of resource variable arguments, whose indices
// are specified in `variable_indices` argument. We snapshot tensors that back
// resource variables since concurrent updates may modify the shape, and it is
// important that the shapes used for compilation match the true shapes of the
// buffers.
//
// We snapshot the entire set of resource variables as one atomic operation.
// This models Read->* dependencies between resource variable operations. See
// jit/resource_operation_safety_analysis for details.
Status SnapshotResourceVariables(OpKernelContext* ctx,
absl::Span<const int> variable_indices,
absl::Span<VariableInfo const> variable_infos,
ResourceVarsSnapshot* result);
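// Example usage, as a rough sketch (assumes the caller has already computed
// the resource variable input indices in `variable_indices`; error handling
// via TF_RETURN_IF_ERROR as in the .cc file):
//
//   std::vector<VariableInfo> variable_infos;
//   TF_RETURN_IF_ERROR(
//       GetVariableInfosFromCtxInputs(ctx, variable_indices, &variable_infos));
//   TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
//   ResourceVarsSnapshot snapshot;
//   TF_RETURN_IF_ERROR(SnapshotResourceVariables(ctx, variable_indices,
//                                                variable_infos, &snapshot));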
// Acquires the mutexes for all the variables in `variables` using a
// deadlock-safe protocol (acquire the mutexes in increasing-address order).
//
@ -104,6 +103,13 @@ class VariableInfo {
Status LockVariables(absl::Span<VariableInfo> variables)
TF_EXCLUSIVE_LOCK_FUNCTION();
// Returns a vector of VariableInfo instances for the resource variable inputs
// to the kernel with context `ctx`. The input indices for the resource
// variable inputs are in `variable_indices`.
Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx,
absl::Span<const int> variable_indices,
std::vector<VariableInfo>* result);
// Helper class to perform the marshalling of TensorFlow inputs and outputs to
// ShapedBuffers suitable for passing to an XLA computation.
class XlaComputationLaunchContext {
@ -123,9 +129,10 @@ class XlaComputationLaunchContext {
// Builds a XlaCompiler::Argument vector from the arguments to an XlaLaunch
// op.
// Precondition: variables in `variable_args` are locked.
static Status BuildXlaCompilerArguments(
const std::map<int, Tensor>& constant_args,
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
absl::Span<VariableInfo const> variable_args, OpKernelContext* ctx,
std::vector<XlaCompiler::Argument>* args);
// Add all inputs within `ctx` as XLA arguments (returned by arguments()).
@ -137,7 +144,7 @@ class XlaComputationLaunchContext {
// (in other words, no inputs actually required by the kernel can be missing).
void PopulateInputs(OpKernelContext* ctx,
const XlaCompiler::CompilationResult* compilation_result,
const std::map<int, OptionalTensor>& variables,
const ResourceVarsSnapshot& variables,
int missing_ctx_input_prefix);
// Given the XLA output in `output`, populate all outputs of `ctx`. Also
@ -155,7 +162,7 @@ class XlaComputationLaunchContext {
const XlaCompiler::CompilationResult* compilation_result,
xla::ScopedShapedBuffer output, int missing_ctx_input_prefix,
const xla::HloInputOutputAliasConfig& input_output_alias,
const std::map<int, OptionalTensor>& resource_var_snapshots);
const ResourceVarsSnapshot& resource_var_snapshots);
// Return the argument list. Only valid after PopulateInputs() has been
// called.

View File

@ -0,0 +1,265 @@
# MLIR CodeGen for XLA
<!--*
# Document freshness: For more information, see go/fresh-source.
freshness: { owner: 'timshen' reviewed: '2020-06-16' }
*-->
XLA operates on `HloInstruction` and performs many optimizations on this
representation, sharing a lot of these between targeted devices. At some point a
linear schedule is computed and a memory buffer is assigned to each value
statically. The device-specific codegen operates by traversing this sequence and
calling "emitters" to generate a representation suitable for the device (for
example a single LLVM function per XLA computation on CPU, or a sequence of
"thunks" encapsulating GPU operations and possibly generated PTX when targeting
GPU).
As a staging step, we're currently intercepting the pipeline right after XLA
completes the buffer-assignment phase and emitting an MLIR module in the `lhlo`
dialect instead. From there we perform the codegen using MLIR components
(mainly Linalg, affine, and the GPU dialect) depending on the device.
Below is the plan of record to incrementally migrate XLA/GPU by using `lhlo` as
the codegen input.
## Tasks
| | Host | Device
| ------------- | ------------------------ | ------------------------
| Input format | HloInstruction* (Task 1) | HloInstruction* (Task 1)
| Output format | xla::Thunk (Task 2) | LLVM IR (Task 3)
* **Task 1** changes both host and device input format from HloInstruction* to
LHLO.
* **Task 2** changes output format of host from thunks to "some landing pad
for host" (see below).
* **Task 3** migrates device output from LLVM IR to some form of MLIR. It's
  optional to this project; see the section "Migrating Device LLVM IR" for
  details.
This project prioritizes having end-to-end runnable models with LHLO emitters
enabled as much as possible. This implies the following ordered list of
objectives, by priority:
* Make XLA/GPU runnable with LHLO emitters, with existing Thunks and emitters
unmodified.
* Eliminate the references to HloInstruction\* in LHLO, case by case:
* Switch a legacy emitter to an MLIR-based emitter (e.g. Linalg), or
* Mechanically translate the existing emitter to take MLIR representation
(migrate to Standard with GPU Dialect).
## Migrating Thunks (Task 2)
xla::gpu::Thunk is a data structure that (a simplified sketch follows this
list):
* Can be called into from the host (xla::gpu::Thunk::ExecuteOnStream()).
* Carries various data in its subclasses.
* Interacts with BufferAllocation::Slice and StreamExecutor.
* Launches kernels.
* Calls into all runtime libraries.
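To make this concrete, below is a heavily simplified, self-contained C++ sketch
of a Thunk-like interface. It is illustrative only: `Status`, `Stream`, and
`BufferSlice` are stand-ins for the real `xla::Status`, `se::Stream`, and
`BufferAllocation::Slice`, and the real `xla::gpu::Thunk` carries more state
and has different signatures.

```c++
#include <string>
#include <utility>
#include <vector>

// Illustrative stand-ins for the real XLA types.
struct Status { bool ok = true; };
struct Stream {};
struct BufferSlice { int allocation_index = 0; long offset = 0; long size = 0; };

// Simplified Thunk-like interface: host-side code that knows how to enqueue
// one step of the compiled program onto a stream.
class Thunk {
 public:
  virtual ~Thunk() = default;
  virtual Status ExecuteOnStream(Stream* stream) = 0;
};

// Subclasses carry op-specific data; a kernel-launching thunk holds the kernel
// name and the buffer slices it reads and writes.
class KernelThunk : public Thunk {
 public:
  KernelThunk(std::string kernel_name, std::vector<BufferSlice> args)
      : kernel_name_(std::move(kernel_name)), args_(std::move(args)) {}

  Status ExecuteOnStream(Stream* stream) override {
    // The real implementation resolves the slices to device memory and then
    // launches the kernel on `stream` via StreamExecutor.
    (void)stream;
    return Status{};
  }

 private:
  std::string kernel_name_;
  std::vector<BufferSlice> args_;
};
```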
The cost of migrating Thunks includes:
* Representing op-specific configuration data (e.g. convolution configs).
* Migrating op shape and operand shapes.
* Representing a tree of thunks (while, condition, etc).
The migration work is independent from LHLO / emitter migration. Under limited
resources, it's prioritized behind LHLO / emitter migration.
We have several choices on how to lower the host-side part from LHLO:
* TFRT
* (Pro) great CUDA and HIP wrappers for use.
* (Pro) easy to implement library calls (cuDNN, cuBLAS, cuFFT, etc), as
TFRT ops are interpreted by C++ code.
* (Con) host side is under development and not tested.
* (Con) the JAX integration isn't clear from a runtime point of view.
* Jitted CPU code
* (Pro) easy to lower: create a few loops and conditions and it's done.
* (Con) GPUDialect doesn't yet model chains/streams/asynchronicity/device
allocation.
* (Con) CUDA / HIP runtime support is minimal (toolkit path, version,
dynamic loading, etc).
* Existing (interpreting) XLA runtime
Tentative conclusion: Use jitted CPU code during the transition, and optionally
adopt TFRT in the end.
## Migrating Device LLVM IR (Task 3)
An elemental emitter generates a target op by filling it in element by element. Each
output element depends on a set of elements from the operands. All elements are
described by combining the buffer with dynamic indices. It's sufficient to
describe almost all "math" ops, but for performance reasons only a large subset
of "math" ops are implemented directly in (Cpu|Gpu)ElementalIrEmitter.
ElementalIrEmitter is unique in that:
* A large portion of the code is shared between XLA/GPU and CPU.
* It represents a large portion of ops seen in models, including all
element-wise ops.
* Most fusions solely depend on ElementalIrEmitter.
* It's structurally simple, as it describes a data dependency DAG between op
elements and operand elements.
* It's mostly portable and high-level (e.g. unlike GPU kReduce and GPU kCopy).
* Dynamic shape support is easy for at least element-wise ops.
Now, for all ops, elementally-emitted or not, there are several flavors of the
end state of each XLA op:
1. Device code stays as LLVM IR.
1. Refactor the old emitter to be like LHLO -> MLIR LLVM Dialect:
* (Cost) Will be throw-away work if we want to ultimately migrate to
Standard.
* (Benefit) It is easy and mechanical. Can be done in a short period.
* (Cost) It doesn't bring more benefit than option 1 (keeping LLVM IR).
1. Refactor old emitters to be like LHLO -> MLIR GPU + Standard + Loops:
* (Cost) Lifting existing emitters to Standard introduces some challenges.
Pointers and GEPs need to be converted to MemRefs and SubViews. Ensuring
amdgpu completeness is another one.
* (Cost) XLA/GPU heavily relies on LLVM metadata:
* `range` for block/thread indices.
* `align`, `dereferenceable`, `invariant.load`, `alias.scope`,
`noalias` for load/stores.
* `llvm.loop.unroll.disable`, `llvm.loop.unroll.full`,
`llvm.loop.vectorize.enable` for sequential loops.
* (Benefit) Can be long-term. More portable.
1. Refactor old emitters to be LHLO -> Linalg, and write new Linalg emitters
* (Cost) This is case by case. Compared to previous options, a new
implementation that matches XLA's performance needs to go through the
benchmark <-> optimize workflow, which can be a significant cost for
some ops.
* (Benefit) unified stack; community support; portability; more
optimization potentials.
## Prioritization
While all three tasks mentioned above are parallelizable, under limited
resources they have to be serialized. The prioritization focuses on visible
results for completion of each task.
The prioritization is: Task 1 (LHLO for legacy emitters) > Task 2 (Thunks) > Task
3 (MLIR emitters).
By the end of Task 1, users of XLA can generate LHLO (e.g. from the kernel
generator) and execute it. The compilation format will not be serializable MLIR.
By the end of Task 2, LHLO lowers to proper, serializable MLIR. This enables
offline compilation.
By the end of Task 3, all XLA emitters are MLIR-based in their implementation.
## Detailed Design
### Step 1: (Task 1) Complete LHLO and Make Legacy Emitters Take LHLO
This step makes all existing XLA/GPU emitters interact with MLIR ops. This step
is pure refactoring and NFC.
This step is mostly mechanical, but it's worth noticing the following
discrepancies between an unnested HloComputation and LHLO:
* Each HloInstruction has direct access to its operands (a data-flow DAG). By
  contrast, each LHLO op only has access to its operand buffers (a bipartite
  graph between ops and buffers). LHLO ops have to go through use-def chains to
  access their operand ops (see the sketch after this list).
* Unnested legacy emitters empirically almost never access their operands. The
only exception is kReduce.
* Unnested legacy emitters access BufferAssignment only for getting slices,
not for accessing aux data structures like dataflow\_analysis() or
alias\_analysis(). llvm\_ir builds its own alias\_analysis() based on slice
information.
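A minimal sketch of that first point, in MLIR's C++ API (the helper below is
hypothetical and only meant to show the extra indirection; it is not an
existing function):

```c++
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Value.h"      // from @llvm-project

// Hypothetical helper: recover the producer of operand `i` of an LHLO-style
// op. Unlike HloInstruction, which points directly at its operand
// instructions, an LHLO op only sees a buffer Value, so the "operand op" has
// to be reached through use-def chains.
mlir::Operation* OperandDefiningOp(mlir::Operation* lhlo_op, unsigned i) {
  mlir::Value buffer = lhlo_op->getOperand(i);
  // getDefiningOp() typically yields the allocation (or null for a block
  // argument); finding the op that last *wrote* the buffer requires scanning
  // buffer.getUsers(), which is the extra hop the bullet above refers to.
  return buffer.getDefiningOp();
}
```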
The conclusion is that LHLO should fit right in without major hassle.
### Step 2: (Optional) Profiling Support
**This step is only needed if we start to discard some of the XLA Thunk logic
(see the next step).**
Before actually turning on any MLIR-based emitters, we need profiling for
MLIR-based emitters.
Currently XLA performs its own profiling by calling into StreamExecutor's timer.
The timer under the hood inserts two events before and after a kernel launch,
and measures the sync time between these two events.
There are roughly two approaches to support profiling in MLIR:
* Run a profiler end-to-end
* Add a profile op for each op in LHLO, using an injected profiler.
The "end-to-end" approach is transparent to MLIR, but suffers the same problem
that makes XLA not use it in the first place: library calls collected by a
profiler (nvprof/...) can't easily relate to HLO ops. For example, cuDNN
launches multiple kernels for each HLO, and it's hard to tell which kernels
correspond to which HLO.
The "injected profiler" approach requires:
* LHLO to take a profiler as a parameter.
* inserting profile.start / profile.end before and after each op (a sketch of
  this step follows the list).
* a pass that lowers profile.{start,end} to a C++ implementation.
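A minimal sketch of the injection step (`ProfileStartOp` and `ProfileEndOp` are
placeholder op classes used only for illustration; defining the real ops, their
dialect, and the lowering is exactly the work this step entails):

```c++
#include "mlir/IR/Builders.h"      // from @llvm-project
#include "mlir/IR/Function.h"      // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/IR/Operation.h"     // from @llvm-project

// Hypothetical pass body: wrap each op in the function with profile.start /
// profile.end markers; a later pass would lower these markers to calls into
// the injected profiler.
void InjectProfileOps(mlir::FuncOp func) {
  func.walk([](mlir::Operation* op) {
    // Skip the function itself, terminators, and already-inserted markers.
    if (mlir::isa<mlir::FuncOp>(op) ||
        op->hasTrait<mlir::OpTrait::IsTerminator>() ||
        mlir::isa<ProfileStartOp>(op) || mlir::isa<ProfileEndOp>(op))
      return;
    mlir::OpBuilder before(op);  // Insertion point is right before `op`.
    before.create<ProfileStartOp>(op->getLoc());
    mlir::OpBuilder after(op->getContext());
    after.setInsertionPointAfter(op);
    after.create<ProfileEndOp>(op->getLoc());
  });
}
```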
The exact profiling can't be easily done for MLIR-generated ops, since:
* MLIR doesn't have a timer, nor does it depend on TFRT / StreamExecutor.
* MLIR doesn't easily call into C functions with complicated parameters.
### Step 3: (Task 2) Migrating Thunks
This step migrates all host ops and library calls. This step will eliminate most
of the thunks and produce serializable MLIR instead.
There are roughly three kinds of thunks:
* KernelThunk, which launches a kernel.
* Control flow thunks, which carry host control flow logic (conditional, while,
  for, sequence) and launch body kernels.
* Library thunks: cuDNN, cuBLAS, cuFFT, NCCL, etc.
The **bottom line** is to:
* Create a Thunk dialect that provides (de)serialize logic for all existing
C++-based Thunks.
* Change emitters to emit a graph of Thunk dialect.
**Optionally**, we can relieve some thunks from C++ implementation. KernelThunk
can lower to the GPU LaunchKernelOp. Control flow thunks can leverage the CFG
Dialect for loops and conditions, combined with LaunchKernelOp. This optional
step requires profiling and stream support.
### Step 4: (Task 3) Migrated ElementalIrEmitter
Once profiling is ready, we can complete and tune all ElementalIrEmitter-based
emitters in MLIR. Then we turn them on by default, assuming that all of these
MLIR-based emitters use a single stream.
Notice that it's beneficial to migrate XLA/CPU's ElementalIrEmitter as well,
since they share a large portion of the code.
With all benchmarking and performance hunting done (TODO: define performance
parity), we turn on the new MLIR-based elemental emitter, and delete the legacy
ElementalIrEmitter.
This step also provides easy fusion transitions (nested ops) for the later
migration.
### Step 5: Multi-Stream Support or Drop
We can't delete
[some of the emitters](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/service/gpu/stream_assignment.cc#L140)
until we either support multi-stream in MLIR or drop the feature. It's a
relatively large amount of work in MLIR and a small amount of gain for XLA. We
should investigate the current users of multi-stream XLA/GPU, and try to delete
this feature if reasonable.
### Step 6: (Task 3) Migrated Device Ops
This step migrates all unnested ops, then we can delete all unnested emitters.
This calls for a rewrite/refactor of kCopy and kReduce. kReduce has already
seen plenty of work, so the actual amount of work that still needs to be done
remains to be seen.

View File

@ -314,7 +314,6 @@ tf_cc_test(
cc_library(
name = "tensorflow_lite_legalize_tf",
srcs = [
"transforms/device_index_selector.cc",
"transforms/dilated_conv.cc",
"transforms/generated_legalize_tf.inc",
"transforms/generated_lower_static_tensor_list.inc",

View File

@ -446,7 +446,7 @@ static void GenOperandResultVerifier(raw_ostream &os,
auto desc =
definit->getDef()->getValueAsString("tflRuntimeTypeDescription");
// Emit a loop to check all the dynamic values in the pack.
// Emit a loop to check all operands.
os << formatv(" for (Value v : top.getODS{0}{1}s({2})) {{\n",
// Capitalize the first letter to match the function name
valueKind.substr(0, 1).upper(), valueKind.substr(1),
@ -455,14 +455,10 @@ static void GenOperandResultVerifier(raw_ostream &os,
os << " (void)v;\n"
<< " if (!("
<< tgfmt(pred.getCondition(), &fctx.withSelf("v.getType()")) << ")) {\n"
<< " if (failure_on_operand_type_mismatch) {\n"
<< formatv(
" return op->emitOpError(\"{0} #\") << index "
" return op->emitOpError(\"{0} #\") << index "
"<< \" must be {1}, but got \" << v.getType();\n",
valueKind, desc)
<< " } else {\n"
<< " return ::mlir::LogicalResult::Failure;\n"
<< " }\n"
<< " }\n" // if
<< " ++index;\n"
<< " }\n"; // for
@ -487,8 +483,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
mlir::tblgen::FmtContext verify_ctx;
os << "::mlir::LogicalResult " << op.getCppClassName()
<< "::VerifyTflRuntimeConstraints(::mlir::Operation *op, bool "
"failure_on_operand_type_mismatch) {\n";
<< "::VerifyTflRuntimeConstraints(::mlir::Operation *op) {\n";
os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
verify_ctx.withOp("top");
@ -529,11 +524,8 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue()));
os << tgfmt(
" if (!($0)) {\n "
" if (failure_on_operand_type_mismatch) {\n"
" return top.emitOpError(\"failed to verify that $1\");\n"
" } else {\n"
" return ::mlir::LogicalResult::Failure;\n }\n }\n",
" if (!($0))\n"
" return top.emitOpError(\"failed to verify that $1\");\n",
&verify_ctx, tgfmt(pred.getCondition(), &verify_ctx), desc);
}
os << " return top.verify();\n}\n";

View File

@ -240,10 +240,10 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) {
}
for (auto fn : module.getOps<FuncOp>()) {
if (fn.getBlocks().size() != 1) {
if (!llvm::hasSingleElement(fn)) {
return fn.emitError("should have exactly one basic block"), false;
}
auto& bb = fn.getBlocks().front();
auto& bb = fn.front();
for (auto arg : bb.getArguments()) {
if (!HasValidTFLiteType(arg, fn))
@ -1089,7 +1089,7 @@ void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) {
dict_attr.get("outputs").dyn_cast_or_null<mlir::StringAttr>()) {
str.getValue().split(output_names, ',', /*MaxSplit=*/-1,
/*KeepEmpty=*/false);
auto term = fn.getBlocks().back().getTerminator();
auto term = fn.back().getTerminator();
if (output_names.size() != term->getNumOperands()) {
fn.emitWarning() << "output names (" << output_names.size()
<< ") != terminator operands (" << term->getNumOperands()

View File

@ -94,8 +94,7 @@ def TFL_RuntimeVerification : OpInterface<"TflRuntimeVerifyOpInterface"> {
let methods = [
StaticInterfaceMethod<
[{Returns whether the op's operands/results are supported by runtime.}],
"LogicalResult", "VerifyTflRuntimeConstraints",
(ins "Operation*":$op, "bool":$failure_on_operand_type_mismatch)
"LogicalResult", "VerifyTflRuntimeConstraints", (ins "Operation*":$op)
>,
];
}

View File

@ -138,6 +138,11 @@ bool IsI32Type(Type element_type) {
return element_type.isInteger(32) && !element_type.isUnsignedInteger();
}
// Return true when the given element_type is I64.
bool IsI64Type(Type element_type) {
return element_type.isInteger(64) && !element_type.isUnsignedInteger();
}
// Return true if the given Add operation has the CPU kernel supported shapes.
bool VerifyAddOpShapeConstraints(AddOp op) {
auto element_type = getElementTypeOrSelf(op.output().getType());
@ -174,7 +179,8 @@ bool VerifySubOpShapeConstraints(SubOp op) {
// Allows F32, QUI8, and QI16 outputs when the operands have valid shapes,
// which are broadcastable shapes up to five dimension or have same shapes.
if (element_type.isF32() || IsI32Type(element_type) ||
IsQUI8Type(element_type) || IsQI16Type(element_type)) {
IsI64Type(element_type) || IsQUI8Type(element_type) ||
IsQI16Type(element_type)) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/5);
@ -758,6 +764,22 @@ OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
return new_concat.getResult();
}
//===----------------------------------------------------------------------===//
// CustomOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(CustomOp op) {
OpaqueElementsAttr opaque_attr =
op.custom_option().cast<OpaqueElementsAttr>();
if (!opaque_attr.getType().hasStaticShape())
return op.emitOpError("custom_option should have a static shape.");
if (opaque_attr.getValue().size() !=
opaque_attr.getType().cast<ShapedType>().getDimSize(0))
return op.emitOpError(
"custom_option should have the same length of content with shape.");
return success();
}
//===----------------------------------------------------------------------===//
// FullyConnectedOp
//===----------------------------------------------------------------------===//
@ -2169,6 +2191,10 @@ static LogicalResult Verify(TransposeOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// WhileOp
//===----------------------------------------------------------------------===//
LogicalResult Verify(WhileOp op) {
if (op.getNumOperands() != op.getNumResults())
return op.emitOpError(llvm::formatv(
@ -2178,18 +2204,6 @@ LogicalResult Verify(WhileOp op) {
return success();
}
static LogicalResult Verify(CustomOp op) {
OpaqueElementsAttr opaque_attr =
op.custom_option().cast<OpaqueElementsAttr>();
if (!opaque_attr.getType().hasStaticShape())
return op.emitOpError("custom_option should have a static shape.");
if (opaque_attr.getValue().size() !=
opaque_attr.getType().cast<ShapedType>().getDimSize(0))
return op.emitOpError(
"custom_option should have the same length of content with shape.");
return success();
}
namespace {
// Canonicalize While op so that results and operands match and external values
// are via implicit capture rather than via block args.

View File

@ -571,6 +571,8 @@ def TFL_TransposeConvOp: TFL_Op<"transpose_conv", [
TFL_OperandHasRank<2, 4>,
PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 2>>,
AccumulatorUniformScale<3, 1, 2>,
TFL_ChannelDimIndexInterface, AffineOpCoefficient<0, 2>,
TFL_GpuTargetOp,
TFL_SparseOp]> {
let summary = "Transpose convolution operator";
@ -596,6 +598,8 @@ def TFL_TransposeConvOp: TFL_Op<"transpose_conv", [
let verifier = [{ return Verify(*this); }];
let extraClassDeclaration = [{
// ChannelDimIndexInterface:
int GetChannelDimIndex() { return 0; }
// SparseOpInterface:
std::vector<int> GetSparseOperands() { return {1}; }
std::vector<std::vector<int>> GetFloatBlockSize() { return {}; }
@ -953,14 +957,14 @@ in the batch dimensions and broadcasting.
}];
let arguments = (ins
TFL_TensorOf<[F32]>:$x,
TFL_TensorOf<[F32]>:$y,
TFL_TensorOf<[F32, QI8]>:$x,
TFL_TensorOf<[F32, QI8]>:$y,
DefaultValuedAttr<BoolAttr, "false">:$adj_x,
DefaultValuedAttr<BoolAttr, "false">:$adj_y
);
let results = (outs
TFL_TensorOf<[F32]>:$output
TFL_TensorOf<[F32, QI8]>:$output
);
let hasOptions = 1;
@ -2860,11 +2864,11 @@ def TFL_SubOp : TFL_Op<"sub", [
}];
let arguments = (
ins TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$lhs,
TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$rhs,
ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$lhs,
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$rhs,
TFL_AFAttr:$fused_activation_function);
let results = (outs TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$output);
let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$output);
let hasFolder = 1;

View File

@ -26,6 +26,7 @@ filegroup(
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
"//tensorflow/compiler/mlir/lite:tf_tfl_translate",
"@llvm-project//llvm:FileCheck",
"@llvm-project//llvm:not",
],
)

View File

@ -0,0 +1,257 @@
# RUN: not tf_tfl_translate -tf-upgrade-legacy=false -tf-input-arrays=Placeholder,Placeholder_1 -tf-input-shapes=1,2:1 -tf-output-arrays=cond/Merge -tf-enable-shape-inference-on-import=false -mlir-print-debuginfo -output-mlir %s -o - 2>&1 | FileCheck %s
# CHECK: error: The graph has Control Flow V1 ops. TFLite converter doesn't support Control Flow V1 ops. Consider using Control Flow V2 ops instead.
node {
name: "Const"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
dim {
size: 2
}
dim {
size: 2
}
}
tensor_content: "\315\314\314=\315\314L>\232\231\231>\315\314\314>"
}
}
}
}
node {
name: "Placeholder"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: -1
}
dim {
size: 2
}
}
}
}
}
node {
name: "Placeholder_1"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_BOOL
}
}
attr {
key: "shape"
value {
shape {
}
}
}
}
node {
name: "cond/Switch"
op: "Switch"
input: "Placeholder_1"
input: "Placeholder_1"
attr {
key: "T"
value {
type: DT_BOOL
}
}
}
node {
name: "cond/switch_t"
op: "Identity"
input: "cond/Switch:1"
attr {
key: "T"
value {
type: DT_BOOL
}
}
}
node {
name: "cond/switch_f"
op: "Identity"
input: "cond/Switch"
attr {
key: "T"
value {
type: DT_BOOL
}
}
}
node {
name: "cond/pred_id"
op: "Identity"
input: "Placeholder_1"
attr {
key: "T"
value {
type: DT_BOOL
}
}
}
node {
name: "cond/MatMul"
op: "MatMul"
input: "cond/MatMul/Switch:1"
input: "cond/MatMul/Switch_1:1"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "transpose_a"
value {
b: false
}
}
attr {
key: "transpose_b"
value {
b: false
}
}
}
node {
name: "cond/MatMul/Switch"
op: "Switch"
input: "Placeholder"
input: "cond/pred_id"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@Placeholder"
}
}
}
}
node {
name: "cond/MatMul/Switch_1"
op: "Switch"
input: "Const"
input: "cond/pred_id"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@Const"
}
}
}
}
node {
name: "cond/Add"
op: "Add"
input: "cond/Add/Switch"
input: "cond/Add/Switch_1"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "cond/Add/Switch"
op: "Switch"
input: "Placeholder"
input: "cond/pred_id"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@Placeholder"
}
}
}
}
node {
name: "cond/Add/Switch_1"
op: "Switch"
input: "Const"
input: "cond/pred_id"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@Const"
}
}
}
}
node {
name: "cond/Merge"
op: "Merge"
input: "cond/Add"
input: "cond/MatMul"
attr {
key: "N"
value {
i: 2
}
}
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "init"
op: "NoOp"
}
versions {
producer: 134
}

View File

@ -9,6 +9,15 @@ func @add(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
// CHECK: return
}
func @sub(%arg0: tensor<1xi64>, %arg1: tensor<1xi64>) -> tensor<1xi64> {
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64>
return %0: tensor<1xi64>
// CHECK-LABEL: sub
// CHECK: tfl.sub %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xi64>
// CHECK: return
}
// CHECK-LABEL: testAddHighDimsHaveSameShape
func @testAddHighDimsHaveSameShape(%arg0: tensor<1x2x3x4x5x6x7x8xi32>, %arg1: tensor<1x2x3x4x5x6x7x8xi32>) -> tensor<1x2x3x4x5x6x7x8xi32> {
// CHECK: tfl.add %arg0, %arg1 {fused_activation_function = "NONE"}
@ -990,6 +999,13 @@ func @batch_to_space_nd(%arg0: tensor<4x2x2x3xf32>, %arg1: tensor<2xi32>, %arg2:
// CHECK: "tfl.batch_to_space_nd"(%arg0, %arg1, %arg2) : (tensor<4x2x2x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<?xf32>
}
func @batch_to_space_nd_unsupported(%arg0: tensor<?x1x1x1x4xf32>, %arg1: tensor<3xi32>, %arg2: tensor<3x2xi32>) -> tensor<?x3x3x3x4xf32> {
%0 = "tf.BatchToSpaceND"(%arg0, %arg1, %arg2) : (tensor<?x1x1x1x4xf32>, tensor<3xi32>, tensor<3x2xi32>) -> tensor<?x3x3x3x4xf32>
return %0 : tensor<?x3x3x3x4xf32>
// CHECK-LABEL: batch_to_space_nd_unsupported
// CHECK: "tf.BatchToSpaceND"
}
func @space_to_batch_nd(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<2x2xi32>) -> tensor<?xf32> {
%0 = "tf.SpaceToBatchND"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>

View File

@ -269,6 +269,14 @@ func @testSub(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
return %0#0 : tensor<? x i32>
}
// CHECK-LABEL: testSubInt64
func @testSubInt64(tensor<? x i64>, tensor<? x i64>) -> tensor<? x i64> {
^bb0(%arg0: tensor<? x i64>, %arg1: tensor<? x i64>):
// CHECK: tfl.sub %arg0, %arg1 {fused_activation_function = "RELU6"}
%0 = tfl.sub %arg0, %arg1 {fused_activation_function = "RELU6"} : tensor<? x i64>
return %0#0 : tensor<? x i64>
}
// CHECK-LABEL: testMul
func @testMul(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
^bb0(%arg0: tensor<? x i32>, %arg1: tensor<? x i32>):

View File

@ -70,6 +70,7 @@ func @prepareAdd(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
}
// CHECK-LABEL: prepareConv2DSplat
// PerTensor-LABEL: prepareConv2DSplat
func @prepareConv2DSplat(%arg0: tensor<1x5x5x3xf32>) -> tensor<1x5x5x3xf32> {
%w = constant dense<127.0> : tensor<3x3x3x3xf32>
%b = constant dense<0.0> : tensor<3xf32>
@ -89,6 +90,7 @@ func @prepareConv2DSplat(%arg0: tensor<1x5x5x3xf32>) -> tensor<1x5x5x3xf32> {
}
// CHECK-LABEL: prepareConv2D
// PerTensor-LABEL: prepareConv2D
func @prepareConv2D(%arg0: tensor<1x5x5x1xf32>) -> tensor<1x5x5x3xf32> {
%w = constant dense<[[[[0.0]]], [[[127.0]]], [[[-127.0]]]]> : tensor<3x1x1x1xf32>
%b = constant dense<0.0> : tensor<3xf32>
@ -108,6 +110,7 @@ func @prepareConv2D(%arg0: tensor<1x5x5x1xf32>) -> tensor<1x5x5x3xf32> {
}
// CHECK-LABEL: prepareDepthwiseConv2D
// PerTensor-LABEL: prepareDepthwiseConv2D
func @prepareDepthwiseConv2D(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> {
%w = constant dense<127.0> : tensor<32x3x3x3xf32>
%b = constant dense<0.0> : tensor<32xf32>
@ -127,6 +130,7 @@ func @prepareDepthwiseConv2D(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112
}
// CHECK-LABEL: QuantizeFullyConnected
// PerTensor-LABEL: QuantizeFullyConnected
func @QuantizeFullyConnected(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> {
%w = constant dense<127.0> : tensor<32x12xf32>
%b = constant dense<0.0> : tensor<32xf32>
@ -143,3 +147,22 @@ func @QuantizeFullyConnected(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112
// PerTensor: %[[dq:.*]] = "tfl.dequantize"(%0) : (tensor<32x12x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>) -> tensor<32x12xf32>
// PerTensor: "tfl.fully_connected"(%arg0, %[[dq]]
}
// CHECK-LABEL: QuantizeTransposeConv
// PerTensor-LABEL: QuantizeTransposeConv
func @QuantizeTransposeConv(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<4xi32>) -> tensor<1x32x42x128xf32> {
%w = constant dense<127.0> : tensor<1x32x42x128xf32>
%b = constant dense<0.0> : tensor<1x32x42x128xf32>
%tc = "tfl.transpose_conv"(%arg1, %arg0, %w, %b) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<1x32x42x128xf32>) -> tensor<1x32x42x128xf32>
return %tc : tensor<1x32x42x128xf32>
// CHECK: %[[CST:.*]] = constant dense<1.270000e+02> : tensor<1x32x42x128xf32>
// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CST]]) {qtype = tensor<1x32x42x128x!quant.uniform<i8<-127:127>:f32:0, {1.000000e+00}>>, volatile}
// CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) : (tensor<1x32x42x128x!quant.uniform<i8<-127:127>:f32:0, {1.000000e+00}>>) -> tensor<1x32x42x128xf32>
// CHECK: "tfl.transpose_conv"(%arg1, %arg0, %[[DEQUANTIZE]]
// PerTensor: %[[CST:.*]] = constant dense<1.270000e+02> : tensor<1x32x42x128xf32>
// PerTensor: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CST]]) {qtype = tensor<1x32x42x128x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>, volatile}
// PerTensor: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) : (tensor<1x32x42x128x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>) -> tensor<1x32x42x128xf32>
// PerTensor: "tfl.transpose_conv"(%arg1, %arg0, %[[DEQUANTIZE]]
}

View File

@ -528,6 +528,26 @@ func @PadStridedSliceNewAxisMask2(%arg0: tensor<4x64x64x1xf32>) -> tensor<1x4x64
return %1 : tensor<1x4x64x64xf32>
}
// CHECK-LABEL: @StridedSliceRewriteMasks
func @StridedSliceRewriteMasks(%arg0: tensor<8x4x16x2xf32>) -> tensor<8x4x16x1xf32> {
%cst = "tf.Const"() {device = "", value = dense<[1, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
%cst_0 = "tf.Const"() {device = "", value = dense<[1, 0, 0]> : tensor<3xi32>} : () -> tensor<3xi32>
%cst_1 = "tf.Const"() {device = "", value = dense<1> : tensor<3xi32>} : () -> tensor<3xi32>
// CHECK: %[[CST:.*]] = constant dense<[1, 0, 0, 1]> : tensor<4xi32>
// CHECK: %[[CST0:.*]] = constant dense<[1, 0, 0, 0]> : tensor<4xi32>
// CHECK: %[[CST1:.*]] = constant dense<1> : tensor<4xi32>
// CHECK: %[[RESULT:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST0]], %[[CST1]])
// CHECK-SAME: begin_mask = 7 : i64
// CHECK-SAME: ellipsis_mask = 0 : i64
// CHECK-SAME: end_mask = 14 : i64
// CHECK-SAME: new_axis_mask = 0 : i64
// CHECK-SAME: shrink_axis_mask = 0 : i64
%0 = "tf.StridedSlice"(%arg0, %cst, %cst_0, %cst_1) {begin_mask = 1 : i64, device = "", ellipsis_mask = 2 : i64, end_mask = 4 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<8x4x16x2xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x4x16x1xf32>
return %0 : tensor<8x4x16x1xf32>
}
// CHECK-LABEL: @MatrixSetDiagV2Conversion
func @MatrixSetDiagV2Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
%cst = constant dense<0> : tensor<i32>

View File

@ -39,22 +39,18 @@ namespace tensorflow {
void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
mlir::OpPassManager* pass_manager) {
pass_manager->addPass(mlir::TFL::CreatePrepareQuantizePass(quant_specs));
pass_manager->addPass(mlir::TFL::CreateQuantizePass());
bool emit_quant_adaptor_ops =
quant_specs.inference_type != quant_specs.inference_input_type;
pass_manager->addPass(
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
if (quant_specs.default_ranges.first.hasValue() ||
quant_specs.default_ranges.second.hasValue()) {
pass_manager->addPass(mlir::TFL::CreateDefaultQuantParamsPass(
quant_specs.default_ranges.first.getValueOr(0.0),
quant_specs.default_ranges.second.getValueOr(0.0),
quant_specs.IsSignedInferenceType()));
pass_manager->addPass(mlir::TFL::CreateQuantizePass());
pass_manager->addPass(
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
}
pass_manager->addPass(mlir::TFL::CreateQuantizePass());
bool emit_quant_adaptor_ops =
quant_specs.inference_type != quant_specs.inference_input_type;
pass_manager->addPass(
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
}
void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
@ -63,7 +59,7 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
standard_pipeline_options.enable_inliner = false;
standard_pipeline_options.form_clusters = pass_config.form_clusters;
mlir::TF::CreateTFStandardPipeline(*pass_manager, standard_pipeline_options);
pass_manager->addPass(mlir::TFL::CreateDeviceIndexSelectorPass());
pass_manager->addPass(mlir::TF::CreateDeviceIndexSelectorPass());
if (pass_config.shape_inference) {
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
@ -212,9 +208,6 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
// Saved model pass to mark global tensors immutable.
pm.addPass(mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());
// Used to mark non-exported functions in saved model private.
pm.addPass(mlir::tf_saved_model::
CreateMarkFunctionVisibilityUsingSavedModelLinkagePass());
// Op fusion pass.
pm.addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());

View File

@ -172,7 +172,7 @@ int main(int argc, char **argv) {
input_file_name, input_mlir, use_splatted_constant, custom_opdefs,
debug_info_file, input_arrays, input_dtypes, input_shapes,
output_arrays,
/*prune_unused_nodes=*/true, &source_mgr, &context);
/*prune_unused_nodes=*/true, upgrade_legacy, &source_mgr, &context);
}
// If errors occur, the library call in the above already logged the error

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/Parser.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/FileUtilities.h" // from @llvm-project
@ -28,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
@ -39,19 +41,47 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {
namespace {
using mlir::MLIRContext;
using mlir::ModuleOp;
using mlir::Operation;
using mlir::OwningModuleRef;
using stream_executor::port::StatusOr;
bool IsControlFlowV1Op(Operation* op) {
return mlir::isa<mlir::tf_executor::SwitchOp>(op) ||
mlir::isa<mlir::tf_executor::MergeOp>(op) ||
mlir::isa<mlir::tf_executor::EnterOp>(op) ||
mlir::isa<mlir::tf_executor::ExitOp>(op) ||
mlir::isa<mlir::tf_executor::NextIterationSinkOp>(op) ||
mlir::isa<mlir::tf_executor::NextIterationSourceOp>(op);
}
mlir::LogicalResult IsValidGraph(mlir::ModuleOp module) {
auto result = module.walk([&](Operation* op) {
return IsControlFlowV1Op(op) ? mlir::WalkResult::interrupt()
: mlir::WalkResult::advance();
});
if (result.wasInterrupted()) {
module.emitError(
"The graph has Control Flow V1 ops. TFLite converter doesn't support "
"Control Flow V1 ops. Consider using Control Flow V2 ops instead. See "
"https://www.tensorflow.org/api_docs/python/tf/compat/v1/"
"enable_control_flow_v2.");
return mlir::failure();
}
return mlir::success();
}
} // namespace
StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
const std::string& input_filename, bool input_mlir,
bool use_splatted_constant, const std::vector<std::string>& extra_tf_opdefs,
absl::string_view debug_info_file, absl::string_view input_arrays,
absl::string_view input_dtypes, absl::string_view input_shapes,
absl::string_view output_arrays, bool prune_unused_nodes,
llvm::SourceMgr* source_mgr, MLIRContext* context) {
bool enable_upgrade_legacy, llvm::SourceMgr* source_mgr,
MLIRContext* context) {
// Set up the input file.
std::string error_message;
auto file = mlir::openInputFile(input_filename, &error_message);
@ -86,14 +116,14 @@ StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
input_shapes, output_arrays, /*control_output_arrays=*/"",
prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
/*graph_as_function=*/false, /*upgrade_legacy=*/true,
/*graph_as_function=*/false, enable_upgrade_legacy,
/*enable_shape_inference=*/false, context);
}
return tensorflow::GraphdefToMlirTranslateFunction(
file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
input_shapes, output_arrays, /*control_output_arrays=*/"",
prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
/*graph_as_function=*/false, /*upgrade_legacy=*/true,
/*graph_as_function=*/false, enable_upgrade_legacy,
/*enable_shape_inference=*/false, context);
}
@ -104,7 +134,8 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
mlir::PassManager* pass_manager) {
mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(),
/*propagate=*/true);
if (failed(pass_manager->run(module))) {
if (failed(IsValidGraph(module)) || failed(pass_manager->run(module))) {
return statusHandler.ConsumeStatus();
}

View File

@ -41,7 +41,8 @@ LoadFromGraphdefOrMlirSource(
absl::string_view debug_info_file, absl::string_view input_arrays,
absl::string_view input_dtypes, absl::string_view input_shapes,
absl::string_view output_arrays, bool prune_unused_nodes,
llvm::SourceMgr* source_mgr, mlir::MLIRContext* context);
bool enable_upgrade_legacy, llvm::SourceMgr* source_mgr,
mlir::MLIRContext* context);
// Load Saved model (either v1 or v2) into MLIR.
stream_executor::port::StatusOr<mlir::OwningModuleRef> ImportSavedModel(

View File

@ -28,9 +28,11 @@ limitations under the License.
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Threading.h"
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
#include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
@ -767,13 +769,26 @@ void LegalizeTF::runOnFunction() {
[](Operation* op) {
auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
if (!tfl_op) return false;
return succeeded(tfl_op.VerifyTflRuntimeConstraints(
tfl_op.getOperation(),
/*failure_on_operand_type_mismatch=*/false));
return succeeded(tfl_op.VerifyTflRuntimeConstraints(op));
}));
} else {
target.addLegalDialect<TensorFlowLiteDialect>();
}
// Ignore transient errors by registering a no-op handler.
// Applying legalization patterns will emit unwanted, transient errors when
// the replaced TFLite ops do not meet the sanity checks. In order to ignore
// the transient errors, the following lines override a diagnostic handler
// with a no-op handler only while this pass runs.
uint64_t current_thread_id = llvm::get_threadid();
ScopedDiagnosticHandler scoped_diag_handler(
context, [&current_thread_id](Diagnostic&) -> LogicalResult {
// Consume only errors that are coming from the same thread in order not
// to ignore errors from other passes that are running. Things running
// in the pass manager can be multi-threaded.
return success(current_thread_id == llvm::get_threadid());
});
// Keep trying to convert.
// TODO(karimnosseir): This is similar to what apply greedy patterns does.
// Look if there is a function that tries until it converge.

View File

@ -91,9 +91,6 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass();
// Verifies runtime constraints.
std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass();
// Creates function pass to select device index/fold tf.DeviceIndex.
std::unique_ptr<OperationPass<FuncOp>> CreateDeviceIndexSelectorPass();
} // namespace TFL
} // namespace mlir

View File

@ -52,7 +52,7 @@ class PostQuantizePass : public PassWrapper<PostQuantizePass, FunctionPass> {
void RemoveQuantizationAdaptorOps(FuncOp func) {
mlir::OpBuilder builder(func.getBody());
auto& bb = func.getBlocks().front();
auto& bb = func.front();
auto* terminator = bb.getTerminator();
int num_args = bb.getNumArguments();

View File

@ -584,46 +584,50 @@ struct ConvertTFStridedSlice : public RewritePattern {
const int ellipsis_filled_dim_size = input_size - begin_shape[0] + 1;
llvm::APInt new_begin_mask = strided_slice_op.begin_mask();
llvm::APInt new_end_mask = strided_slice_op.end_mask();
int64_t begin_mask = strided_slice_op.begin_mask().getSExtValue();
int64_t end_mask = strided_slice_op.end_mask().getSExtValue();
int64_t new_begin_mask = 0;
int64_t new_end_mask = 0;
SmallVector<int32_t, 4> padded_begin;
SmallVector<int32_t, 4> padded_end;
SmallVector<int32_t, 4> padded_stride;
// Before the ellipsis.
uint64_t index = 1;
int count = 0;
while (index < ellipsis_mask) {
padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(count));
padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(count));
padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(count));
index <<= 1;
count++;
int index = 0;
int new_index = 0;
while (((ellipsis_mask >> index) & 1) == 0) {
padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(index));
padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(index));
padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(index));
if ((begin_mask >> index) & 1) new_begin_mask |= (1 << new_index);
if ((end_mask >> index) & 1) new_end_mask |= (1 << new_index);
++index;
++new_index;
}
// Ellipsis.
for (int i = 0; i < ellipsis_filled_dim_size; ++i) {
new_begin_mask |= ellipsis_mask;
new_end_mask |= ellipsis_mask;
for (; new_index < index + ellipsis_filled_dim_size; ++new_index) {
new_begin_mask |= (1 << new_index);
new_end_mask |= (1 << new_index);
// Mimic the begin/end/strides mask behavior.
padded_begin.push_back(0);
padded_end.push_back(0);
padded_stride.push_back(1);
ellipsis_mask <<= 1;
}
// Account for ellipsis mask.
count++;
++index;
// After the ellipsis.
for (; count < begin_shape[0]; ++count) {
padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(count));
padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(count));
padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(count));
for (; index < begin_shape[0]; ++index) {
padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(index));
padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(index));
padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(index));
if ((begin_mask >> index) & 1) new_begin_mask |= (1 << new_index);
if ((end_mask >> index) & 1) new_end_mask |= (1 << new_index);
}
auto attribute_type = rewriter.getIntegerType(64);
@ -645,7 +649,7 @@ struct ConvertTFStridedSlice : public RewritePattern {
end_op.getResult(), stride_op.getResult(),
rewriter.getIntegerAttr(attribute_type, new_begin_mask),
rewriter.getIntegerAttr(attribute_type, new_end_mask),
rewriter.getI64IntegerAttr(0),
/*ellipsis_mask=*/rewriter.getI64IntegerAttr(0),
rewriter.getIntegerAttr(attribute_type,
strided_slice_op.new_axis_mask()),
rewriter.getIntegerAttr(attribute_type,
@ -655,10 +659,12 @@ struct ConvertTFStridedSlice : public RewritePattern {
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
// TODO(renjieliu): Consider expand the transformation for shrink
// mask as well.
TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
// TODO(renjieliu): Consider expanding the transformation for shrink mask as
// well.
if (strided_slice_op.shrink_axis_mask().getZExtValue()) return failure();
// Handle new axis mask.
uint64_t new_axis_mask = strided_slice_op.new_axis_mask().getZExtValue();
if (new_axis_mask != 0) {

View File

@ -34,8 +34,7 @@ class RuntimeVerifyPass
void RuntimeVerifyPass::runOnFunction() {
getFunction().walk([&](TflRuntimeVerifyOpInterface op) {
if (failed(op.VerifyTflRuntimeConstraints(
op.getOperation(), /*failure_on_operand_type_mismatch=*/true)))
if (failed(op.VerifyTflRuntimeConstraints(op.getOperation())))
signalPassFailure();
});
}

View File

@ -57,6 +57,7 @@ gentbl(
td_srcs = [
":tensorflow_ops_td_files",
],
test = True,
)
gentbl(
@ -88,6 +89,7 @@ gentbl(
td_srcs = [
":tensorflow_ops_td_files",
],
test = True,
)
gentbl(
@ -112,6 +114,7 @@ gentbl(
"@llvm-project//mlir:include/mlir/IR/OpBase.td",
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td",
],
test = True,
)
gentbl(
@ -137,6 +140,7 @@ gentbl(
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td",
"@llvm-project//mlir:include/mlir/IR/OpBase.td",
],
test = True,
)
gentbl(
@ -161,6 +165,7 @@ gentbl(
"@llvm-project//mlir:include/mlir/IR/OpBase.td",
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td",
],
test = True,
)
gentbl(
@ -475,6 +480,7 @@ cc_library(
"transforms/cluster_outlining.cc",
"transforms/collection_ops_util.cc",
"transforms/decompose_resource_ops_pass.cc",
"transforms/device_index_selector.cc",
"transforms/einsum.cc",
"transforms/executor_island_coarsening.cc",
"transforms/executor_tpuv1_inline_tpu_island.cc",
@ -491,7 +497,6 @@ cc_library(
"transforms/graph_pruning.cc",
"transforms/launch_to_device_attribute.cc",
"transforms/layout_optimization.cc",
"transforms/mark_function_visibility.cc",
"transforms/materialize_mlir_passthrough_op.cc",
"transforms/optimize.cc",
"transforms/optimize_global_tensors.cc",
@ -661,7 +666,9 @@ cc_library(
":tensorflow_types",
":translate_utils",
"//tensorflow/cc/saved_model:bundle_v2",
"//tensorflow/cc/saved_model:constants",
"//tensorflow/cc/saved_model:loader_lite",
"//tensorflow/cc/saved_model:loader_util",
"//tensorflow/compiler/jit:shape_inference_helpers",
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
"//tensorflow/compiler/tf2xla:functionalize_control_flow",
@ -673,6 +680,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler/utils:transitive_fanin",
"//tensorflow/core/platform:protobuf_internal",
"//tensorflow/core/platform:types",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/algorithm:container",
@ -682,7 +690,6 @@ cc_library(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
@ -1349,6 +1356,7 @@ cc_library(
srcs = ["utils/tpu_rewrite_device_util.cc"],
hdrs = ["utils/tpu_rewrite_device_util.h"],
deps = [
":tensorflow",
"//tensorflow/compiler/xla:array4d",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/service:computation_placer",
@ -1359,6 +1367,7 @@ cc_library(
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
)
@ -1367,6 +1376,7 @@ tf_cc_test(
size = "small",
srcs = ["utils/tpu_rewrite_device_util_test.cc"],
deps = [
":device_util",
":tpu_rewrite_device_util",
"//tensorflow/core:framework",
"//tensorflow/core:test",

View File

@ -299,13 +299,13 @@ ParseResult ParseReplicateOp(OpAsmParser* parser, OperationState* state) {
parser->parseRegion(body, region_args, region_arg_types))
return failure();
if (body.getBlocks().size() > 1)
return parser->emitError(loc) << "expects a single block region";
// Ensure that the region is well formed: it contains at least a block with
// a ReturnOp terminator.
ReplicateOp::ensureTerminator(body, parser->getBuilder(), state->location);
if (!llvm::hasSingleElement(body))
return parser->emitError(loc) << "expects a single block region";
Operation& terminator = body.front().back();
if (!isa<ReturnOp>(terminator))
return parser->emitError(loc) << "expects a tf_device.return terminator";

View File

@ -220,13 +220,13 @@ ParseResult ParseGraphOp(OpAsmParser &parser, OperationState &result) {
Region &body = *result.addRegion();
if (parser.parseRegion(body, llvm::None, llvm::None)) return failure();
if (body.getBlocks().size() > 1)
return parser.emitError(loc) << "expects a single block region";
// Ensure that the region is well formed: it contains at least a block with
// a FetchOp terminator.
GraphOp::ensureTerminator(body, parser.getBuilder(), result.location);
if (!llvm::hasSingleElement(body))
return parser.emitError(loc) << "expects a single block region";
// Get the results type from the terminator type inside the graph.
Operation &fetch = body.back().back();
if (!isa<FetchOp>(fetch))

File diff suppressed because it is too large

View File

@ -232,6 +232,7 @@ else_branch: A function that takes 'inputs' and returns a list of
def TF_YieldOp : TF_Op<"Yield", [Terminator]> {
let summary = "Yield operation";
let description = [{
The "yield" operation represents a return operation within the conditional
and body of structured control flow (e.g., if and while). The operation
@ -497,6 +498,7 @@ Inserts a placeholder for a tensor that will be always fed.
def TF_PlaceholderWithDefaultOp : TF_Op<"PlaceholderWithDefault", [NoSideEffect]> {
let summary = "Placeholder op";
let description = [{
A placeholder op that passes through input when its output is not fed.
}];
@ -839,9 +841,6 @@ def TF_XlaShardingOp : TF_Op<"XlaSharding", [NoSideEffect]> {
An op which shards the input based on the given sharding attribute.
}];
let description = [{
}];
let arguments = (ins
TF_Tensor:$input,
@ -858,9 +857,6 @@ An op which shards the input based on the given sharding attribute.
def TF_InfeedDequeueTupleOp : TF_Op<"InfeedDequeueTuple", []> {
let summary = "Fetches multiple values from infeed as an XLA tuple.";
let description = [{
}];
let arguments = (ins
OptionalAttr<StrAttr>:$_XlaSharding
);
@ -904,9 +900,6 @@ def TF_BatchDatasetV2Op : TF_Op<"BatchDatasetV2", [NoSideEffect]> {
Creates a dataset that batches `batch_size` elements from `input_dataset`.
}];
let description = [{
}];
let arguments = (ins
TF_VariantTensor:$input_dataset,
I64Tensor:$batch_size,
@ -1048,4 +1041,46 @@ operation create / operate on a copy of `x`.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_BesselI0eOp : TF_Op<"BesselI0e", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes the Bessel i0e function of `x` element-wise.";
let description = [{
Exponentially scaled modified Bessel function of order 0 defined as
`bessel_i0e(x) = exp(-abs(x)) bessel_i0(x)`.
This function is faster and numerically stabler than `bessel_i0(x)`.
}];
let arguments = (ins
TF_FpTensor:$x
);
let results = (outs
TF_FpTensor:$y
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_BesselI1eOp : TF_Op<"BesselI1e", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes the Bessel i1e function of `x` element-wise.";
let description = [{
Exponentially scaled modified Bessel function of order 1 defined as
`bessel_i1e(x) = exp(-abs(x)) bessel_i1(x)`.
This function is faster and numerically stabler than `bessel_i1(x)`.
}];
let arguments = (ins
TF_FpTensor:$x
);
let results = (outs
TF_FpTensor:$y
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
#endif // TF_OPS
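For reference, the unscaled modified Bessel functions that the two ops above rescale by `exp(-abs(x))` are the standard series (background only, not part of the generated op definitions):

I_0(x) = \sum_{k \ge 0} \frac{(x^2/4)^k}{(k!)^2}, \qquad I_1(x) = \frac{x}{2} \sum_{k \ge 0} \frac{(x^2/4)^k}{k!\,(k+1)!}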

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "mlir/IR/Identifier.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/OpImplementation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/SymbolTable.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
@ -76,6 +77,23 @@ static LogicalResult Verify(GlobalTensorOp global_tensor) {
return success();
}
static LogicalResult Verify(SessionInitializerOp session_initializer) {
mlir::SymbolTable symbol_table(
session_initializer.getParentOfType<ModuleOp>());
auto init_func_op =
symbol_table.lookup<mlir::FuncOp>(session_initializer.initializer());
if (!init_func_op)
return session_initializer.emitOpError()
<< "the initializer function does not exist";
if (!init_func_op.getType().getResults().empty())
return session_initializer.emitOpError()
<< "the initializer function should have no output";
return success();
}
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc"
@ -212,14 +230,36 @@ static LogicalResult VerifySavedModelModule(
}
}
for (auto func : module.getOps<FuncOp>()) {
if (HasAnyTfSavedModelArgAttr(func)) {
if (!IsExported(func)) {
return func.emitError()
<< "can only apply 'tf_saved_model' argument attributes "
"to exported functions";
}
const bool is_exported = IsExported(func);
if (is_exported && !func.isPublic()) {
return func.emitError()
<< "exported function @" << func.getName() << " should be public";
}
if (!is_exported && func.isPublic()) {
return func.emitError() << "non-exported function @" << func.getName()
<< " should be private";
}
if (!is_exported && HasAnyTfSavedModelArgAttr(func)) {
return func.emitError() << "can only apply 'tf_saved_model' argument "
"attributes to exported functions";
}
}
auto session_initializers = module.getOps<SessionInitializerOp>();
if (!session_initializers.empty() &&
!llvm::hasSingleElement(session_initializers)) {
return (*++session_initializers.begin()).emitError()
<< "there must be no more than one session_initializer op";
}
auto is_init = [&session_initializers](mlir::FuncOp func) {
if (session_initializers.empty()) return false;
return (*session_initializers.begin()).initializer() == func.getName();
};
SymbolTable symbol_table(module);
auto symbol_uses = SymbolTable::getSymbolUses(&module.getBodyRegion());
if (!symbol_uses.hasValue()) {
@ -230,6 +270,12 @@ static LogicalResult VerifySavedModelModule(
auto func = symbol_table.lookup<FuncOp>(
symbol_use.getSymbolRef().cast<FlatSymbolRefAttr>().getValue());
if (func && IsExported(func)) {
// If it is an init function, then it can be used by the unique
// session_initializer op.
if (is_init(func) &&
llvm::isa<SessionInitializerOp>(symbol_use.getUser()))
continue;
return symbol_use.getUser()
->emitError("exported function cannot be internally referenced")
.attachNote(func.getLoc())
@ -349,5 +395,39 @@ GlobalTensorOp LookupBoundInput(FuncOp func, int arg_index,
return symbol_table.lookup<GlobalTensorOp>(attr.getValue());
}
SessionInitializerOp GetSessionInitializerOp(mlir::ModuleOp op) {
auto initializers = op.getOps<SessionInitializerOp>();
if (initializers.empty()) return {};
return *initializers.begin();
}
class OptimizeSessionInitializerPattern
: public OpRewritePattern<SessionInitializerOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(SessionInitializerOp op,
PatternRewriter &rewriter) const override {
SymbolTable symbol_table(op.getParentOfType<ModuleOp>());
auto init_func_op = symbol_table.lookup<mlir::FuncOp>(op.initializer());
// The init function can only be referenced from the SessionInitializerOp,
// and there is at most one SessionInitializerOp in the module. So, when the
// init function body is empty (just its terminator), both ops have no other
// uses or effects and can simply be erased.
if (init_func_op.front().begin()->isKnownTerminator()) {
rewriter.eraseOp(init_func_op);
rewriter.eraseOp(op);
return success();
}
return failure();
}
};
void SessionInitializerOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<OptimizeSessionInitializerPattern>(context);
}
} // namespace tf_saved_model
} // namespace mlir
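To illustrate the canonicalization above, a hypothetical module fragment where the pattern fires because the initializer body holds only its terminator (names are illustrative):

"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
func @init() attributes {sym_visibility = "private"} {
  return
}
// After running `canonicalize`, both the session_initializer op and @init are erased.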

View File

@ -61,6 +61,10 @@ GlobalTensorOp LookupBoundInput(FuncOp func, int arg_index,
// should have.
Type GetBoundInputArgTypeFor(GlobalTensorOp global_tensor);
// Returns the session initializer of this module if it exists. Returns null
// otherwise.
SessionInitializerOp GetSessionInitializerOp(mlir::ModuleOp op);
} // namespace tf_saved_model
} // namespace mlir

View File

@ -128,4 +128,30 @@ def TfSavedModel_GlobalTensorOp : TfSavedModel_Op<"global_tensor"> {
let verifier = [{ return Verify(*this); }];
}
def TfSavedModel_SessionInitializerOp: TfSavedModel_Op<"session_initializer"> {
let summary = "Initializes TensorFlow session state.";
let description = [{
The session initializer op marks a function that must be called by an
external agent exactly once to initialize TensorFlow session state, and this
must happen before any other exported functions are called. There must be no
more than one session initializer in a saved model.
The `initializer` represents the initialization function. The function must
have no output and should be called only once.
This is used, for example, to initialize hash tables stored in resources and
accessed by resource name (rather than as resource handles or bound inputs,
which is how `global_tensor`s are referenced).
}];
let arguments = (ins
FlatSymbolRefAttr:$initializer
);
let verifier = [{ return Verify(*this); }];
let hasCanonicalizer = 1;
}
#endif // SAVED_MODEL_DIALECT
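A minimal sketch of how this op is used in a saved model module, mirroring the test cases later in this diff (the `@init` name is illustrative):

module attributes {tf_saved_model.semantics} {
  "tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
  // The initializer must exist, return no results, and, being non-exported, be private.
  func @init() attributes {sym_visibility = "private"} {
    return
  }
}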

View File

@ -1,47 +0,0 @@
// RUN: tf-opt -tf-saved-model-mark-func-visibility -split-input-file %s | FileCheck --check-prefix=SAVEDMODEL %s
// RUN: tf-opt -tf-mark-func-visibility -split-input-file -verify-diagnostics %s | FileCheck %s
module attributes {tf_saved_model.semantics} {
// SAVEDMODEL: func @func_exported_1() attributes {tf_saved_model.exported_names = ["func_exported_1"]}
func @func_exported_1() attributes {tf_saved_model.exported_names = ["func_exported_1"]} {
"tf.some_call"() {callee = {callee = {callee = @child}}} : () -> ()
return
}
// SAVEDMODEL: func @func_exported_2() attributes {tf_saved_model.exported_names = ["func_exported_2"]}
func @func_exported_2() attributes {tf_saved_model.exported_names = ["func_exported_2"]} {
"tf.some_call"() {callee = {callee = {callee = @child}}} : () -> ()
return
}
// SAVEDMODEL: func @func_not_exported() attributes {sym_visibility = "private"}
func @func_not_exported() {
return
}
}
// -----
module {
// CHECK: func @func_with_entry_spec(%arg0: tensor<1xi32>) -> tensor<1xi32> attributes {tf.entry_function = {inputs = "x", outputs = "y"}}
func @func_with_entry_spec(%arg0: tensor<1xi32>) -> tensor<1xi32> attributes {tf.entry_function = {inputs = "x", outputs = "y"}} {
return %arg0 : tensor<1xi32>
}
// CHECK: func @func_without_entry_spec(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> attributes {sym_visibility = "private"}
func @func_without_entry_spec(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
%0 = "tf.AddV2"(%arg0, %arg1) {T = i32, device = ""} : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
return %0 : tensor<*xi32>
}
}
// -----
module {
// expected-error @+1 {{can't overwrite the visibility of function private_func_with_entry_spec with private visibility}}
func @private_func_with_entry_spec(%arg0: tensor<1xi32>) -> tensor<1xi32> attributes {tf.entry_function = {inputs = "x", outputs = "y"}, sym_visibility = "private"} {
return %arg0 : tensor<1xi32>
}
}

View File

@ -433,4 +433,28 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
// CHECK: return %[[CAST_RESULT_0]], %[[CAST_RESULT_1]], %[[ADDI]]
return %27, %28, %2 : tensor<*xui8>, tensor<*xi8>, tensor<*xi8>
}
// CHECK-LABEL: infer_device_launch
func @infer_device_launch(%arg0: tensor<1x8x2xi32>) -> (tensor<*xf32>, tensor<*xf32>) {
%0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
%1 = "tf_device.launch"() ({
%2 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<1x8x2xi32>) -> tensor<1x8x2xf32>
tf_device.return %2 : tensor<1x8x2xf32>
// CHECK: () -> tensor<1x8x2xf32>
}) {device = "/device:CPU:0"} : () -> tensor<*xf32>
// CHECK: "tf.Cast"(%{{.*}}) {Truncate = false} : (tensor<1x8x2xf32>) -> tensor<*xf32>
// CHECK: (tensor<i32>, tensor<1x8x2xf32>) -> (tensor<1x8x1xf32>, tensor<1x8x1xf32>)
%3:2 = "tf.Split"(%0, %1) {device = ""} : (tensor<i32>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>)
%4 = addf %1, %1 : tensor<*xf32>
return %3#0, %3#1 : tensor<*xf32>, tensor<*xf32>
}
// CHECK-LABEL: func @tensor_cast(%arg0: tensor<1xi32>) -> tensor<1xi32>
func @tensor_cast(%arg0: tensor<1xi32>) -> tensor<*xi32> {
// CHECK: %[[RESULT:.*]] = tensor_cast
// CHECK-SAME: tensor<1xi32> to tensor<1xi32>
// CHECK: return %[[RESULT]] : tensor<1xi32>
%1 = tensor_cast %arg0 : tensor<1xi32> to tensor<*xi32>
return %1 : tensor<*xi32>
}
}

View File

@ -4,8 +4,6 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "lit_test")
def tf_saved_model_test(name, data, tags = None):
"""Create a SavedModel test."""
if tags == None:
tags = ["no_rocm"]
native.py_binary(
name = name,
testonly = 1,
@ -26,5 +24,5 @@ def tf_saved_model_test(name, data, tags = None):
name = name + ".py",
data = [name] + data,
driver = "@llvm-project//mlir:run_lit.sh",
tags = tags,
tags = tags + ["no_rocm"],
)

View File

@ -46,7 +46,10 @@ def set_tf_options():
# This function needs to take a "create_module_fn", as opposed to just the
# module itself, because the creation of the module has to be delayed until
# after absl and tensorflow have run various initialization steps.
def do_test(signature_def_map, show_debug_info=False):
def do_test(signature_def_map,
init_op=None,
canonicalize=False,
show_debug_info=False):
"""Runs test.
1. Performs absl and tf "main"-like initialization that must run before almost
@ -61,6 +64,9 @@ def do_test(signature_def_map, show_debug_info=False):
Args:
signature_def_map: A map from string key to signature_def. The key will be
used as function name in the resulting MLIR.
init_op: The initializer op for the saved model. If set, it will generate an
initializer graph in the resulting MLIR.
canonicalize: If true, the canonicalizer will be run on the resulting MLIR.
show_debug_info: If true, shows debug locations in the resulting MLIR.
"""
@ -84,6 +90,7 @@ def do_test(signature_def_map, show_debug_info=False):
builder.add_meta_graph_and_variables(
sess, [tf.saved_model.tag_constants.SERVING],
signature_def_map,
main_op=init_op,
strip_default_attrs=True)
builder.save()
@ -97,6 +104,9 @@ def do_test(signature_def_map, show_debug_info=False):
mlir = pywrap_mlir.experimental_run_pass_pipeline(mlir,
'tf-standard-pipeline',
show_debug_info)
if canonicalize:
mlir = pywrap_mlir.experimental_run_pass_pipeline(mlir, 'canonicalize',
show_debug_info)
print(mlir)
app.run(app_main)

View File

@ -0,0 +1,92 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# RUN: %p/hash_table_v1 | FileCheck %s
# pylint: disable=missing-docstring,line-too-long
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v1 as tf
from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1
# Verify that the tf.versions attribute exists. It is difficult to enforce
# contents, since the version numbers change over time. The conversion logic
# itself is verified in the common graphdef converter, so here just assert
# it is being invoked.
# CHECK: module
# CHECK-SAME: tf.versions
# CHECK-SAME: bad_consumers
# CHECK-SAME: min_consumer
# CHECK-SAME: producer
# CHECK: "tf_saved_model.session_initializer"() {initializer = [[init:@.*]]} : () -> ()
# CHECK: "tf_saved_model.global_tensor"()
# CHECK: func [[init]]
# CHECK-NEXT: [[R5:%.*]] = "tf.Const"()
# CHECK-NEXT: [[R6:%.*]] = "tf.Const"()
# CHECK-NEXT: [[R7:%.*]] = "tf.HashTableV2"()
# CHECK-SAME: shared_name = "[[hash_table:.*]]"
# CHECK-NEXT: "tf.LookupTableImportV2"([[R7]], [[R5]], [[R6]])
# CHECK: func {{@[a-zA-Z_0-9]+}}(
# CHECK-SAME: [[ARG0:%.*]]: tensor<i32>
# CHECK-SAME: [[ARG1:%.*]]: tensor<!tf.resource
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["key"]
# CHECK-NEXT: [[R0:%.*]] = "tf.Const"()
# CHECK-NEXT: [[R1:%.*]] = "tf.HashTableV2"()
# CHECK-SAME: shared_name = "[[hash_table]]"
# CHECK-NEXT: [[R2:%.*]] = "tf.LookupTableFindV2"([[R1]], [[ARG0]], [[R0]])
# CHECK-NEXT: [[R3:%.*]] = "tf.ReadVariableOp"([[ARG1]])
# CHECK-NEXT: [[R4:%.*]] = "tf.AddV2"([[R2]], [[R3]])
# CHECK-NEXT: return [[R4]]
def Test():
z = tf.compat.v1.get_variable(
name='y',
shape=(),
initializer=tf.random_normal_initializer(),
trainable=True)
table_initializer = tf.lookup.KeyValueTensorInitializer(
keys=[1, 2, 3, 4],
values=[5, 6, 7, 8],
key_dtype=tf.int32,
value_dtype=tf.float32)
table = tf.lookup.StaticHashTable(
table_initializer, default_value=tf.constant(0.0))
x = tf.placeholder(tf.int32, shape=(), name='input')
y = table.lookup(x)
r = tf.add(y, z)
tensor_info_x = tf.compat.v1.saved_model.utils.build_tensor_info(x)
tensor_info_r = tf.compat.v1.saved_model.utils.build_tensor_info(r)
return {
'key': (tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
inputs={'x': tensor_info_x},
outputs={'r': tensor_info_r},
method_name='some_function'))
}
if __name__ == '__main__':
common_v1.set_tf_options()
common_v1.do_test(Test(), tf.tables_initializer())

View File

@ -0,0 +1,74 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# RUN: %p/remove_init_variable_v1 | FileCheck %s
# pylint: disable=missing-docstring,line-too-long
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v1 as tf
from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1
# Verify that the tf.versions attribute exists. It is difficult to enforce
# contents, since the version numbers change over time. The conversion logic
# itself is verified in the common graphdef converter, so here just assert
# it is being invoked.
# CHECK: module
# CHECK-SAME: tf.versions
# CHECK-SAME: bad_consumers
# CHECK-SAME: min_consumer
# CHECK-SAME: producer
# CHECK: "tf_saved_model.global_tensor"() {is_mutable, sym_name = "[[VAR:[a-zA-Z_0-9]+]]", type = tensor<1x3xf32>, value = {{.*}} : tensor<1x3xf32>} : () -> ()
# CHECK-NOT: session_initializer
# CHECK: func {{@[a-zA-Z_0-9]+}}(
# CHECK-SAME: [[ARG0:%.*]]: tensor<3x1xf32> {tf_saved_model.index_path = ["x"]},
# CHECK-SAME: [[ARG1:%.*]]: tensor<!tf.resource<tensor<1x3xf32>>> {tf_saved_model.bound_input = @[[VAR]]})
# CHECK-SAME: -> (tensor<3x3xf32> {tf_saved_model.index_path = ["r"]})
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["key"]
# CHECK-NEXT: [[R0:%.*]] = "tf.ReadVariableOp"([[ARG1]]) {{{.*}}} : (tensor<!tf.resource<tensor<1x3xf32>>>) -> tensor<1x3xf32>
# CHECK-NEXT: [[R1:%.*]] = "tf.MatMul"([[ARG0]], [[R0]]) {{{.*}}} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
# CHECK-NEXT: return [[R1]] : tensor<3x3xf32>
def Test():
x = tf.constant([[1.0], [1.0], [1.0]])
y = tf.compat.v1.get_variable(
name='y',
shape=(1, 3),
initializer=tf.random_normal_initializer(),
trainable=True)
r = tf.matmul(x, y)
tensor_info_x = tf.compat.v1.saved_model.utils.build_tensor_info(x)
tensor_info_r = tf.compat.v1.saved_model.utils.build_tensor_info(r)
return {
'key': (tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
inputs={'x': tensor_info_x},
outputs={'r': tensor_info_r},
method_name='some_function'))
}
if __name__ == '__main__':
common_v1.set_tf_options()
common_v1.do_test(
Test(), tf.initializers.global_variables(), canonicalize=True)

View File

@ -1,96 +0,0 @@
// RUN: tf-opt -tf-saved-model-mark-func-visibility -symbol-dce -split-input-file %s | FileCheck %s
module attributes {tf_saved_model.semantics} {
// Test case: Unused function should be deleted.
// CHECK-NOT: func @unused
func @unused() {
return
}
}
// -----
module attributes {tf_saved_model.semantics} {
// Test case: Root calls child. Child should not be deleted.
// CHECK: func @root
func @root() attributes {tf_saved_model.exported_names = ["root"]} {
"tf.some_call"() { callee = @child } : () -> ()
return
}
// CHECK: func @child
func @child() {
return
}
}
// -----
module attributes {tf_saved_model.semantics} {
// Test case: Don't crash on an attribute that doesn't reference a func.
"tf.some_opaque_global_variable"() { sym_name = "some_global" } : () -> ()
func @root2() attributes {tf_saved_model.exported_names = ["root2"]} {
"tf.do_something_with_a_global"() { global = @some_global } : () -> ()
return
}
}
// -----
module attributes {tf_saved_model.semantics} {
// Test case: Delete recursively dead cycle.
// CHECK-NOT: func @recursively_dead0
func @recursively_dead0() {
"tf.some_call"() { callee = @recursively_dead1 } : () -> ()
return
}
// CHECK-NOT: func @recursively_dead1
func @recursively_dead1() {
"tf.some_call"() { callee = @recursively_dead0 } : () -> ()
return
}
}
// -----
module attributes {tf_saved_model.semantics} {
// Test case: Root calls child with a deeply nested symbol reference.
// Child should not be deleted.
// CHECK: func @root
func @root() attributes {tf_saved_model.exported_names = ["root"]} {
"tf.some_call"() {callee = {callee = {callee = @child}}} : () -> ()
return
}
// CHECK: func @child
func @child() {
return
}
}
// -----
// Test case: If the module doesn't have tf_saved_model semantics, then this
// pass shouldn't do anything.
module {
// CHECK: func @not_dead()
func @not_dead() {
return
}
}

View File

@ -64,7 +64,7 @@ module attributes {tf_saved_model.semantics} {
return
}
func @f_callee(%arg0: tensor<!tf.resource<tensor<f32>>>) {
func @f_callee(%arg0: tensor<!tf.resource<tensor<f32>>>) attributes {sym_visibility = "private"} {
return
}
}

View File

@ -2,6 +2,11 @@
module attributes {tf_saved_model.semantics} {
// CHECK: tf_saved_model.session_initializer
"tf_saved_model.session_initializer"() {
initializer = @init
} : () -> ()
// Representation for constants: (immutable) global tensor.
// CHECK: tf_saved_model.global_tensor
"tf_saved_model.global_tensor"() {
@ -35,7 +40,18 @@ module attributes {tf_saved_model.semantics} {
return %arg0 : tensor<f32>
}
func @f() {
func @f() attributes {sym_visibility = "private"} {
return
}
// Representation for init functions
// CHECK: func @init
// CHECK-SAME: exported_names = ["__tf_saved_model_session_initializer"]
func @init(
%arg1: tensor<!tf.resource<tensor<1x64xf32>>> {tf_saved_model.bound_input = @some_constant}
) attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer"]}
{
"tf.some_call"(%arg1) : (tensor<!tf.resource<tensor<1x64xf32>>>) -> ()
return
}

View File

@ -3,7 +3,7 @@
module attributes {tf_saved_model.semantics} {
// expected-error@+1 {{unknown tf_saved_model dialect arg attribute 'tf_saved_model.not_a_real_arg_attr'}}
func @f(%arg0: tensor<f32> {tf_saved_model.not_a_real_arg_attr = 1 : i32}) {
func @f(%arg0: tensor<f32> {tf_saved_model.not_a_real_arg_attr = 1 : i32}) attributes {sym_visibility = "private"} {
return
}
@ -233,7 +233,7 @@ module attributes {tf_saved_model.semantics} {
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<?xf32>, value = dense<1.> : tensor<1xf32> } : () -> ()
// expected-error@+1 {{can only apply 'tf_saved_model' argument attributes to exported functions}}
func @f(%arg0: tensor<!tf.resource<tensor<?xf32>>> {tf_saved_model.bound_input = @v})
-> (tensor<?xf32> {tf_saved_model.index_path = []}) {
-> (tensor<?xf32> {tf_saved_model.index_path = []}) attributes {sym_visibility = "private"} {
%0 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<?xf32>>>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -258,3 +258,97 @@ module attributes {tf_saved_model.semantics} {
// expected-error@+1 {{'type' attribute for immutable 'tf_saved_model.global_tensor' should have a static shape}}
"tf_saved_model.global_tensor"() { sym_name = "v", type = tensor<?xf32>, value = dense<1.> : tensor<1xf32> } : () -> ()
}
// -----
module attributes {tf_saved_model.semantics} {
// expected-error@+1 {{the initializer function does not exist}}
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
}
// -----
module attributes {tf_saved_model.semantics} {
// expected-error@+1 {{the initializer function should have no output}}
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
func @init() -> tensor<1xf32> attributes {sym_visibility = "private"} {
%0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32>
return %0 : tensor<1xf32>
}
}
// -----
module attributes {tf_saved_model.semantics} {
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
// expected-error@+1 {{there must be no more than one session_initializer op}}
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
func @init() -> tensor<1xf32> attributes {sym_visibility = "private"} {
%0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32>
return %0 : tensor<1xf32>
}
}
// -----
module attributes {tf_saved_model.semantics} {
// expected-error@+1 {{exported function @f should be public}}
func @f(
%arg0: tensor<f32> {tf.resource_name = "resource"}
) attributes { sym_visibility = "private", tf_saved_model.exported_names = ["foo.some_func"] } {
return
}
}
// -----
module attributes {tf_saved_model.semantics} {
// expected-error@+1 {{non-exported function @f should be private}}
func @f(
%arg0: tensor<f32> {tf.resource_name = "resource"}
) {
return
}
}
// -----
module attributes {tf_saved_model.semantics} {
// expected-error@+1 {{the initializer function does not exist}}
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
}
// -----
module attributes {tf_saved_model.semantics} {
// expected-error@+1 {{the initializer function should have no output}}
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
func @init() -> (tensor<1xf32> {tf_saved_model.index_path = ["output"]})
attributes { tf_saved_model.exported_names = ["__tf_saved_model_session_initializer"] } {
%0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32>
return %0 : tensor<1xf32>
}
}
// -----
module attributes {tf_saved_model.semantics} {
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
// expected-error@+1 {{there must be no more than one session_initializer op}}
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
func @init() -> (tensor<1xf32> {tf_saved_model.index_path = ["output"]})
attributes { tf_saved_model.exported_names = ["__tf_saved_model_session_initializer"] } {
%0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32>
return %0 : tensor<1xf32>
}
}

View File

@ -1,7 +1,7 @@
// RUN: tf-opt -tf-saved-model-optimize-global-tensors -split-input-file %s | FileCheck %s
//===----------------------------------------------------------------------===//
// Freezing.
// Immutability.
//===----------------------------------------------------------------------===//
module attributes {tf_saved_model.semantics} {
@ -142,3 +142,89 @@ module attributes {tf_saved_model.semantics} {
// Test running the pass on a module that does not have
// tf_saved_model.semantics.
module {}
// -----
// Test use as an input to an unhandled op
module attributes {tf_saved_model.semantics} {
// CHECK: "tf_saved_model.global_tensor"() {
// CHECK-SAME: is_mutable
// CHECK-SAME: } : () -> ()
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
attributes {tf_saved_model.exported_names = ["f"]} {
"tf.unhandled_op"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> ()
return
}
}
// -----
// Test use as a region capture in an unhandled op
module attributes {tf_saved_model.semantics} {
// CHECK: "tf_saved_model.global_tensor"() {
// CHECK-SAME: is_mutable
// CHECK-SAME: } : () -> ()
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
attributes {tf_saved_model.exported_names = ["f"]} {
"tf.unhandled"() ({
%val = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
"tf.unhandled_terminator"() : () -> ()
}) : () -> ()
return
}
}
// -----
// Test use as a region capture as well as an input to the same unhandled op.
module attributes {tf_saved_model.semantics} {
// CHECK: "tf_saved_model.global_tensor"() {
// CHECK-SAME: is_mutable
// CHECK-SAME: } : () -> ()
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
// CHECK: "tf_saved_model.global_tensor"() {
// CHECK-SAME: is_mutable
// CHECK-SAME: } : () -> ()
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "u", type = tensor<f32>, value = dense<22.> : tensor<f32> } : () -> ()
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}, %arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @u})
attributes {tf_saved_model.exported_names = ["f"]} {
%0 = "tf.unhandled"(%arg0) ({
%val = "tf.ReadVariableOp"(%arg1) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
"tf.unhandled_terminator"() : () -> ()
}) : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<!tf.resource<tensor<f32>>>)
return
}
}
// -----
// Test multiple global tensors used as operands of an unhandled op.
module attributes {tf_saved_model.semantics} {
// CHECK: "tf_saved_model.global_tensor"() {
// CHECK-SAME: is_mutable
// CHECK-SAME: } : () -> ()
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
// CHECK: "tf_saved_model.global_tensor"() {
// CHECK-SAME: is_mutable
// CHECK-SAME: } : () -> ()
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "u", type = tensor<f32>, value = dense<22.> : tensor<f32> } : () -> ()
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}, %arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @u})
attributes {tf_saved_model.exported_names = ["f"]} {
"tf.unhandled"(%arg0, %arg1) : (tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<f32>>>) -> ()
return
}
}

Some files were not shown because too many files have changed in this diff