Merge branch 'upstream/master' into interface_16x8

Elena Zhelezina 2020-05-08 11:35:26 +01:00
commit 0391c064f5
2679 changed files with 114105 additions and 52428 deletions


@ -356,9 +356,10 @@ build:rbe_linux --linkopt=-lm
build:rbe_cpu_linux --config=rbe_linux
build:rbe_cpu_linux --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain"
build:rbe_cpu_linux --extra_toolchains="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:cc-toolchain-k8"
build:rbe_cpu_linux --extra_execution_platforms"=@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
build:rbe_cpu_linux --host_platform="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
build:rbe_cpu_linux --platforms="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
build:rbe_cpu_linux --extra_execution_platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform"
build:rbe_cpu_linux --extra_execution_platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform"
build:rbe_cpu_linux --host_platform="@ubuntu16.04-manylinux2010-py3_config_platform//:platform"
build:rbe_cpu_linux --platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform"
build:rbe_linux_cuda_base --config=rbe_linux
build:rbe_linux_cuda_base --repo_env=TF_NEED_TENSORRT=1
@ -380,17 +381,37 @@ build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-gcc7_
build:rbe_linux_cuda_nvcc --define=using_cuda_nvcc=true
test:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
build:rbe_linux_cuda_clang --config=rbe_linux_cuda_base
build:rbe_linux_cuda_clang --crosstool_top="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_clang --extra_toolchains="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_clang --extra_execution_platforms="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_clang --host_platform="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_clang --platforms="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_clang --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_clang --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_clang --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_clang --define=using_cuda_clang=true
test:rbe_linux_cuda_clang --config=rbe_linux_cuda_base
build:rbe_linux_cuda_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda_nvcc_base --crosstool_top="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_nvcc_base --extra_toolchains="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_nvcc_base --extra_execution_platforms="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc_base --host_platform="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc_base --platforms="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_nvcc_base --define=using_cuda_nvcc=true
build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
build:rbe_linux_cuda_clang_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda_clang_base --crosstool_top="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_clang_base --extra_toolchains="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_clang_base --extra_execution_platforms="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_clang_base --host_platform="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_clang_base --platforms="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_clang_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_clang_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_clang_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_clang_base --define=using_cuda_clang=true
build:rbe_linux_cuda_clang_py27 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
build:rbe_linux_cuda_clang_py35 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
build:rbe_linux_cuda_clang_py36 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
build:rbe_linux_cuda_clang_py37 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
build:rbe_linux_cuda_clang_py38 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
common:rbe_gpu_linux --config=rbe_linux_cuda_nvcc


@ -38,6 +38,9 @@ state what is wrong:
- Producing correct results, but the model is slower than expected (model generated from old converter)
**RNN conversion support**
If converting TF RNN to TFLite fused RNN ops, please prefix [RNN] in the title.
**Any other info / logs**
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.

.github/stale.yml (new file, 39 lines)

@ -0,0 +1,39 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
#
# THIS IS A GENERATED DOCKERFILE.
#
# This file was assembled from multiple pieces, whose use is documented
# throughout. Please refer to the TensorFlow dockerfiles documentation
# for more information.
# Number of days of inactivity before an Issue or Pull Request becomes stale
daysUntilStale: 7
# Number of days of inactivity before a stale Issue or Pull Request is closed
daysUntilClose: 7
# Issues or Pull Requests with these labels will never be considered stale. Set to `[]` to disable
onlyLabels:
- stat:awaiting response
# Comment to post when marking as stale. Set to `false` to disable
markComment: >
This issue has been automatically marked as stale because it has not had
recent activity. It will be closed if no further activity occurs. Thank you.
# Comment to post when removing the stale label. Set to `false` to disable
unmarkComment: false
closeComment: >
Closing as stale. Please reopen if you'd like to work on this further.
limitPerRun: 30
# Limit to only `issues` or `pulls`
only: issues


@ -1171,14 +1171,16 @@ def system_specific_test_config(environ_cp):
test_only_filters = ['-oss_serial']
if is_windows():
test_and_build_filters.append('-no_windows')
if environ_cp.get('TF_NEED_CUDA', None) == '1':
if ((environ_cp.get('TF_NEED_CUDA', None) == '1') or
(environ_cp.get('TF_NEED_ROCM', None) == '1')):
test_and_build_filters += ['-no_windows_gpu', '-no_gpu']
else:
test_and_build_filters.append('-gpu')
elif is_macos():
test_and_build_filters += ['-gpu', '-nomac', '-no_mac']
elif is_linux():
if environ_cp.get('TF_NEED_CUDA', None) == '1':
if ((environ_cp.get('TF_NEED_CUDA', None) == '1') or
(environ_cp.get('TF_NEED_ROCM', None) == '1')):
test_and_build_filters.append('-no_gpu')
write_to_bazelrc('test --test_env=LD_LIBRARY_PATH')
else:
@ -1416,6 +1418,10 @@ def main():
write_action_env_to_bazelrc('LD_LIBRARY_PATH',
environ_cp.get('LD_LIBRARY_PATH'))
if (environ_cp.get('TF_NEED_ROCM') == '1' and environ_cp.get('ROCM_PATH')):
write_action_env_to_bazelrc('ROCM_PATH', environ_cp.get('ROCM_PATH'))
write_action_env_to_bazelrc('ROCM_ROOT', environ_cp.get('ROCM_PATH'))
environ_cp['TF_NEED_CUDA'] = str(
int(get_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False)))
if (environ_cp.get('TF_NEED_CUDA') == '1' and


@ -523,6 +523,12 @@ package_group(
],
)
package_group(name = "ndarray_tensor_allow_list")
# Packages that use composite tensors or dispatch.
# TODO(b/154762408) Remove this package group once it's no longer needed.
package_group(name = "composite_tensor_whitelist")
filegroup(
name = "intel_binary_blob",
data = if_mkl_ml(


@ -116,7 +116,7 @@ from tensorflow.python.lib.io import file_io as _fi
# Get sitepackages directories for the python installation.
_site_packages_dirs = []
_site_packages_dirs += [_site.USER_SITE]
_site_packages_dirs += [] if _site.USER_SITE is None else [_site.USER_SITE]
_site_packages_dirs += [_p for _p in _sys.path if 'site-packages' in _p]
if 'getsitepackages' in dir(_site):
_site_packages_dirs += _site.getsitepackages()


@ -126,7 +126,7 @@ from tensorflow.python.lib.io import file_io as _fi
# Get sitepackages directories for the python installation.
_site_packages_dirs = []
_site_packages_dirs += [_site.USER_SITE]
_site_packages_dirs += [] if _site.USER_SITE is None else [_site.USER_SITE]
_site_packages_dirs += [_p for _p in _sys.path if 'site-packages' in _p]
if 'getsitepackages' in dir(_site):
_site_packages_dirs += _site.getsitepackages()


@ -58,6 +58,7 @@ filegroup(
name = "pywrap_required_hdrs",
srcs = [
"c_api_internal.h",
"conversion_macros.h",
"python_api.h",
"tensor_interface.h",
"tf_status_helper.h",
@ -86,6 +87,13 @@ tf_cuda_library(
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//tensorflow:chromiumos": [
":tf_attrtype",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/platform:platform",
],
"//conditions:default": [
":tf_attrtype",
"//tensorflow/core:core_cpu",
@ -118,6 +126,13 @@ cc_library(
visibility = ["//visibility:public"],
)
cc_library(
name = "c_api_macros",
hdrs = ["c_api_macros.h"],
copts = tf_copts(),
visibility = ["//visibility:public"],
)
tf_cuda_library(
name = "c_api",
hdrs = [
@ -327,6 +342,9 @@ tf_cuda_library(
":checkpoint_reader",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_internal",
"//tensorflow/c/eager:tfe_context_internal",
"//tensorflow/c/eager:tfe_op_internal",
"//tensorflow/c/eager:tfe_tensorhandle_internal",
"//tensorflow/compiler/jit:flags",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
@ -517,12 +535,12 @@ tf_cuda_cc_test(
":test_op1.so",
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
],
kernels = [":test_op_kernel"],
linkopts = select({
"//tensorflow:macos": ["-headerpad_max_install_names"],
"//conditions:default": [],
}),
tags = [
"no_windows", # TODO(b/155444728)
"noasan",
],
# We must ensure that the dependencies can be dynamically linked since
@ -531,6 +549,7 @@ tf_cuda_cc_test(
deps = [
":c_api",
":c_test_util",
":test_op_kernel",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:grad_ops",
"//tensorflow/cc/saved_model:signature_constants",
@ -597,6 +616,7 @@ tf_cc_test(
":c_api",
":c_api_internal",
":c_test_util",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
@ -721,3 +741,11 @@ tf_cuda_library(
],
alwayslink = 1,
)
cc_library(
name = "conversion_macros",
hdrs = [
"conversion_macros.h",
],
visibility = ["//tensorflow:__subpackages__"],
)


@ -39,6 +39,7 @@ limitations under the License.
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eval_const_tensor.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
@ -53,7 +54,6 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/validate.h"
#include "tensorflow/core/lib/gtl/array_slice.h"


@ -21,6 +21,9 @@ limitations under the License.
#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
@ -686,8 +689,7 @@ TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type,
std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len);
status->status = tensorflow::Status::OK();
return new TFE_TensorHandle{
tensorflow::TensorHandle::CreateLocalHandle(tensor)};
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
}
namespace {
@ -708,7 +710,7 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
// New server created for new server_def. Unused if updating server_def.
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
tensorflow::GrpcServer* grpc_server =
dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
if (grpc_server == nullptr) {
@ -822,14 +824,13 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
const int num_inputs = input_shapes->num_items;
NodeDef node_def;
node_def.set_name(tfe_op->operation->Name());
node_def.set_op(tfe_op->operation->Name());
tensorflow::AbstractOperationInterface* op = tensorflow::unwrap(tfe_op);
node_def.set_name(op->Name());
node_def.set_op(op->Name());
for (int i = 0; i < num_inputs; ++i) {
node_def.add_input("dummy_input");
}
OperationFromInterface(tfe_op->operation)
->Attrs()
.FillAttrValueMap(node_def.mutable_attr());
OperationFromInterface(op)->Attrs().FillAttrValueMap(node_def.mutable_attr());
const tensorflow::OpRegistrationData* op_reg_data;
status->status =


@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/c_test_util.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/lib/hash/hash.h"


@ -38,7 +38,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/status.h"


@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_C_API_MACROS_H_
#define TENSORFLOW_C_C_API_MACROS_H_
#ifdef SWIG
#define TF_CAPI_EXPORT
#else
#if defined(_WIN32)
#ifdef TF_COMPILE_LIBRARY
#define TF_CAPI_EXPORT __declspec(dllexport)
#else
#define TF_CAPI_EXPORT __declspec(dllimport)
#endif // TF_COMPILE_LIBRARY
#else
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
#endif // _WIN32
#endif // SWIG
#endif // TENSORFLOW_C_C_API_MACROS_H_


@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_CONVERSION_MACROS_H_
#define TENSORFLOW_C_CONVERSION_MACROS_H_
#define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper) \
inline cpp_impl *unwrap(wrapper *w) { \
return reinterpret_cast<cpp_impl *>(w); \
} \
\
inline const cpp_impl *unwrap(const wrapper *w) { \
return reinterpret_cast<const cpp_impl *>(w); \
} \
\
inline wrapper *wrap(cpp_impl *i) { return reinterpret_cast<wrapper *>(i); } \
inline const wrapper *wrap(const cpp_impl *i) { \
return reinterpret_cast<const wrapper *>(i); \
}
#endif // TENSORFLOW_C_CONVERSION_MACROS_H_
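As a usage sketch, this macro is instantiated once per opaque C type to generate the wrap/unwrap pair used throughout the eager C API changes below. The exact contents of the new tfe_*_internal.h headers are not shown in this diff, so the instantiation here is an assumption; the type names are taken from this change (tfe_tensorhandle_internal depends on tensor_handle_interface and conversion_macros).

#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"

// Opaque C handle backed by tensorflow::AbstractTensorHandleInterface.
typedef struct TFE_TensorHandle TFE_TensorHandle;

namespace tensorflow {
// Generates inline wrap()/unwrap() overloads that reinterpret_cast between
// TFE_TensorHandle* and AbstractTensorHandleInterface*.
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface,
                            TFE_TensorHandle)
}  // namespace tensorflow

With these in place, calls such as tensorflow::unwrap(h)->NumDims(&num_dims) and tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor)), both visible in c_api.cc below, convert between the C handle and the C++ interface without allocating a separate wrapper object (the old new TFE_TensorHandle{...} / delete h pattern).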


@ -16,6 +16,7 @@ load(
"//tensorflow/core/platform:build_config_root.bzl",
"tf_cuda_tests_tags",
)
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
package(
licenses = ["notice"], # Apache 2.0
@ -41,12 +42,20 @@ tf_cuda_library(
":context_interface",
":operation_interface",
":tensor_handle_interface",
":tfe_context_internal",
":tfe_cancellation_manager_internal",
":tfe_executor_internal",
":tfe_monitoring_internal",
":tfe_op_attrs_internal",
":tfe_op_internal",
":tfe_tensor_debug_info_internal",
":tfe_tensorhandle_internal",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:fixed_array",
"@com_google_absl//absl/types:span",
"@com_google_absl//absl/types:variant",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_internal",
"//tensorflow/c:tf_status_internal",
"//tensorflow/c:tf_tensor_internal",
"//tensorflow/core:core_cpu",
"//tensorflow/core/common_runtime/eager:attr_builder",
@ -100,6 +109,11 @@ filegroup(
"dlpack.h",
"operation_interface.h",
"tensor_handle_interface.h",
"tfe_cancellation_manager_internal.h",
"tfe_executor_internal.h",
"tfe_monitoring_internal.h",
"tfe_op_attrs_internal.h",
"tfe_tensor_debug_info_internal.h",
],
visibility = [
"//tensorflow/core:__pkg__",
@ -107,33 +121,27 @@ filegroup(
],
)
tf_cuda_library(
cc_library(
name = "c_api_internal",
srcs = [
hdrs = [
"c_api_experimental.h",
"c_api_unified_experimental.h",
"c_api_internal.h",
],
hdrs = ["c_api_internal.h"],
visibility = [
"//learning/deepmind/courier:__subpackages__",
"//tensorflow:internal",
],
deps = [
":c_api",
":context_interface",
":operation_interface",
":tensor_handle_interface",
"//tensorflow/c:c_api",
":tfe_cancellation_manager_internal",
":tfe_context_internal",
":tfe_executor_internal",
":tfe_monitoring_internal",
":tfe_op_attrs_internal",
":tfe_op_internal",
":tfe_tensor_debug_info_internal",
":tfe_tensorhandle_internal",
"//tensorflow/c:c_api_internal",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:framework_lite",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/common_runtime/eager:attr_builder",
"//tensorflow/core/common_runtime/eager:eager_executor",
],
)
@ -177,13 +185,110 @@ cc_library(
":operation_interface",
":tensor_handle_interface",
"//tensorflow/c:tensor_interface",
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "tfe_context_internal",
hdrs = ["tfe_context_internal.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
":context_interface",
"//tensorflow/c:conversion_macros",
],
)
cc_library(
name = "tfe_cancellation_manager_internal",
hdrs = ["tfe_cancellation_manager_internal.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/core:framework",
],
)
cc_library(
name = "tfe_executor_internal",
hdrs = ["tfe_executor_internal.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/core/common_runtime/eager:eager_executor",
],
)
cc_library(
name = "tfe_monitoring_internal",
hdrs = ["tfe_monitoring_internal.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
],
)
cc_library(
name = "tfe_op_attrs_internal",
hdrs = ["tfe_op_attrs_internal.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/c:conversion_macros",
"//tensorflow/c:tf_status",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:attr_builder",
],
)
cc_library(
name = "tfe_op_internal",
hdrs = ["tfe_op_internal.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
":operation_interface",
"//tensorflow/c:conversion_macros",
],
)
cc_library(
name = "tfe_tensor_debug_info_internal",
hdrs = ["tfe_tensor_debug_info_internal.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/core:lib",
],
)
cc_library(
name = "tfe_tensorhandle_internal",
hdrs = ["tfe_tensorhandle_internal.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
":tensor_handle_interface",
"//tensorflow/c:conversion_macros",
],
)
tf_cuda_library(
name = "c_api_test_util",
testonly = 1,
@ -213,7 +318,8 @@ tf_cuda_cc_test(
],
extra_copts = tfe_xla_copts(),
tags = [
"guitar",
"noguitar", # TODO(b/155445984): flaky
#"guitar",
"multi_gpu",
],
deps = [
@ -221,6 +327,8 @@ tf_cuda_cc_test(
":c_api_experimental",
":c_api_internal",
":c_api_test_util",
":tfe_op_internal",
":tfe_tensorhandle_internal",
"//tensorflow/c:c_test_util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@ -239,12 +347,16 @@ tf_cuda_cc_test(
srcs = [
"c_api_remote_test.cc",
],
# TODO(b/136478427): Figure out how to correctly shut the server down
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
tags = ["noasan"], # leaks gRPC server instances
deps = [
":c_api",
":c_api_experimental",
":c_api_internal",
":c_api_test_util",
":tfe_tensorhandle_internal",
"//tensorflow/c:c_test_util",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@ -256,11 +368,42 @@ tf_cuda_cc_test(
],
)
tf_cuda_cc_test(
name = "c_api_cluster_test",
size = "small",
srcs = [
"c_api_cluster_test.cc",
],
# TODO(b/136478427): Figure out how to correctly shut the server down
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
tags = ["noasan"], # leaks gRPC server instances
deps = [
":c_api",
":c_api_experimental",
":c_api_internal",
":c_api_test_util",
":tfe_tensorhandle_internal",
"//tensorflow/c:c_test_util",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/platform:env",
"@com_google_absl//absl/strings",
],
)
tf_cuda_library(
name = "c_api_experimental",
srcs = [
"c_api_experimental.cc",
"c_api_unified_experimental.cc",
"c_api_unified_experimental_eager.cc",
"c_api_unified_experimental_graph.cc",
"c_api_unified_experimental_internal.h",
],
hdrs = [
"c_api_experimental.h",
@ -275,6 +418,9 @@ tf_cuda_library(
"//conditions:default": [
":c_api",
":c_api_internal",
":tfe_context_internal",
":tfe_op_internal",
":tfe_tensorhandle_internal",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_internal",
"//tensorflow/core:core_cpu",
@ -357,6 +503,7 @@ tf_cuda_cc_test(
":c_api",
":c_api_experimental",
":c_api_test_util",
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/cc/profiler",
"//tensorflow/core:lib",
@ -438,8 +585,9 @@ cc_library(
deps = [
":c_api",
":c_api_experimental",
":c_api_internal",
":tfe_tensorhandle_internal",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c:tf_status_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",


@ -26,7 +26,6 @@ limitations under the License.
// clang-format on
#include "absl/algorithm/container.h"
#include "absl/container/fixed_array.h"
#include "absl/memory/memory.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
@ -34,6 +33,9 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_tensor_internal.h"
#ifdef PLATFORM_GOOGLE
#include "tensorflow/c/eager/c_api_tfrt.h"
@ -298,7 +300,7 @@ tensorflow::Status CreateRemoteContexts(
std::vector<bool> filtered_device_mask;
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->FilterDevicesForRemoteWorkers(
remote_worker, base_request.cluster_device_attributes(),
&filtered_device_mask);
@ -383,7 +385,7 @@ tensorflow::Status UpdateRemoteContexts(
std::vector<bool> filtered_device_mask;
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->FilterDevicesForRemoteWorkers(
remote_worker, base_request.cluster_device_attributes(),
&filtered_device_mask);
@ -464,7 +466,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
// New server created for new server_def. Unused if updating server_def.
std::unique_ptr<tensorflow::ServerInterface> new_server;
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
tensorflow::GrpcServer* grpc_server;
if (reset_context) {
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
@ -498,6 +500,17 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
grpc_server->master_env()->worker_cache->GetEagerClientCache(
&remote_eager_workers));
// For cluster update, use a status group to aggregate statuses from
// * adding and removing remote devices
// * creating remote contexts on newly added workers
// * updating remote contexts on existing workers
// * updating the master context
// Note that we should not return immediately on errors in the middle of these
// updates to prevent cluster from having inconsistent context views.
//
// Unused if `reset_context` is True.
tensorflow::StatusGroup sg;
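A minimal sketch of the aggregation pattern described in the comment above, relying only on the two StatusGroup calls that appear later in this function, Update() and as_summary_status(); StepOne and StepTwo are placeholder functions, not part of the commit.

// Illustrative only: StepOne()/StepTwo() stand in for the update steps above.
tensorflow::StatusGroup sg;
sg.Update(StepOne());    // an error is recorded, but execution continues
sg.Update(StepTwo());    // later failures are recorded as well
// Surface every recorded error at once instead of stopping at the first one.
LOG_AND_RETURN_IF_ERROR(sg.as_summary_status());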
// When updating an existing context, populate the following lists with:
// * added_workers: set(remote_workers) - set(curr_remote_workers)
// * removed_workers: set(curr_remote_workers) - set(remote_workers)
@ -533,7 +546,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
DifferentiateWorkerLists(&curr_remote_workers, &remote_workers,
&added_workers, &removed_workers,
&existing_workers);
LOG_AND_RETURN_IF_ERROR(GetReplacedFromExistingWorkers(
sg.Update(GetReplacedFromExistingWorkers(
&existing_workers, context_id, context->GetContextViewId(), server_def,
remote_eager_workers.get(), &replaced_workers));
if (VLOG_IS_ON(1)) {
@ -557,11 +570,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
existing_workers.end());
}
}
LOG_AND_RETURN_IF_ERROR(
RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr));
LOG_AND_RETURN_IF_ERROR(AddRemoteDevicesToMgr(
added_workers, grpc_server->master_env()->worker_cache,
remote_device_mgr));
sg.Update(RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr));
sg.Update(AddRemoteDevicesToMgr(added_workers,
grpc_server->master_env()->worker_cache,
remote_device_mgr));
}
std::vector<tensorflow::DeviceAttributes> cluster_device_attributes;
@ -582,7 +594,6 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
}
// Initialize remote eager workers.
// TODO(b/138847548) Create remote eager contexts in async mode by default.
if (reset_context) {
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
ctx, remote_workers, context_id, context_view_id, keep_alive_secs,
@ -594,7 +605,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
// existing workers to also have the updated context_view_id, so
// we must set their context_view_id to the existing master's
// context_view_id + 1.
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
sg.Update(CreateRemoteContexts(
ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs,
server_def, remote_eager_workers.get(), context->Executor().Async(),
context->LazyCopyFunctionRemoteInputs(), base_request));
@ -604,20 +615,19 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
VLOG(1) << "Updating cluster with existing worker " << w;
}
}
LOG_AND_RETURN_IF_ERROR(UpdateRemoteContexts(
ctx, existing_workers, added_workers, removed_workers, context_id,
context_view_id + 1, server_def, remote_eager_workers.get(),
base_request));
sg.Update(UpdateRemoteContexts(ctx, existing_workers, added_workers,
removed_workers, context_id,
context_view_id + 1, server_def,
remote_eager_workers.get(), base_request));
}
}
tensorflow::RemoteRendezvous* r =
grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
auto session_name = tensorflow::strings::StrCat("eager_", context_id);
auto* device_mgr = grpc_server->worker_env()->device_mgr;
std::shared_ptr<tensorflow::WorkerSession> worker_session;
if (reset_context) {
tensorflow::RemoteRendezvous* r =
grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
auto* device_mgr = grpc_server->worker_env()->device_mgr;
std::shared_ptr<tensorflow::WorkerSession> worker_session;
TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession(
session_name, server_def, base_request.cluster_device_attributes(),
true));
@ -644,13 +654,13 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
// GrpcServer cannot be destroyed after it is started.
LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
} else {
LOG_AND_RETURN_IF_ERROR(
grpc_server->worker_env()->session_mgr->UpdateSession(
session_name, server_def, base_request.cluster_device_attributes(),
true));
LOG_AND_RETURN_IF_ERROR(context->UpdateRemoteMaster(
grpc_server->worker_env(), std::move(remote_eager_workers),
added_workers, removed_workers, context_id, r));
sg.Update(grpc_server->worker_env()->session_mgr->UpdateSession(
session_name, server_def, base_request.cluster_device_attributes(),
/*isolate_session_state=*/true));
sg.Update(context->UpdateRemoteMaster(context_id,
std::move(remote_eager_workers),
added_workers, removed_workers));
LOG_AND_RETURN_IF_ERROR(sg.as_summary_status());
}
#undef LOG_AND_RETURN_IF_ERROR
@ -684,8 +694,13 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
if (opts->use_tfrt) {
#ifdef PLATFORM_GOOGLE
status->status = tensorflow::Status::OK();
return new TFE_Context{new tfrt::ContextInterface()};
tfrt::SmallVector<std::string, 4> op_handler_chains;
tfrt::SmallVector<tensorflow::DeviceAttributes, 4> device_attributes;
status->status = tfrt::ListOpHandlerChains(
opts->session_options.options, &op_handler_chains, &device_attributes);
if (!status->status.ok()) return nullptr;
return tensorflow::wrap(
new tfrt::ContextInterface(op_handler_chains, device_attributes));
#else
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
return nullptr;
@ -702,14 +717,14 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr.get());
return new TFE_Context{new tensorflow::EagerContext(
return tensorflow::wrap(new tensorflow::EagerContext(
opts->session_options.options,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
opts->device_placement_policy),
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
/*device_mgr_owned*/ true, r,
tensorflow::GetDefaultCustomKernelCreator())};
tensorflow::GetDefaultCustomKernelCreator()));
}
TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
@ -720,14 +735,14 @@ TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr);
return new TFE_Context{new tensorflow::EagerContext(
return tensorflow::wrap(new tensorflow::EagerContext(
opts->session_options.options,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
opts->device_placement_policy),
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
opts->async, opts->lazy_remote_inputs_copy, device_mgr,
/*device_mgr_owned*/ false, r,
tensorflow::GetDefaultCustomKernelCreator())};
tensorflow::GetDefaultCustomKernelCreator()));
}
void TFE_DeleteContext(TFE_Context* ctx) {
@ -735,23 +750,18 @@ void TFE_DeleteContext(TFE_Context* ctx) {
return;
}
// context->RefCountIsOne() should be true here.
// TODO(iga): Remove EagerContext refcounting.
ctx->context->Release();
delete ctx;
// ctx->RefCountIsOne() should be true here.
tensorflow::unwrap(ctx)->Release();
}
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
TF_DeviceList* l = new TF_DeviceList;
ctx->context->ListDevices(&l->response);
tensorflow::unwrap(ctx)->ListDevices(&l->response);
return l;
}
void TFE_ContextClearCaches(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->ClearCachesAndThreadExecutors();
tensorflow::unwrap(ctx)->ClearCachesAndThreadExecutors();
}
// Set server_def on the context, possibly updating it.
@ -773,7 +783,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
if (server_def.has_cluster_device_filters()) {
const auto& cdf = server_def.cluster_device_filters();
for (const auto& jdf : cdf.jobs()) {
const string& remote_prefix = "/job:" + jdf.name() + "/task:";
const string remote_prefix = "/job:" + jdf.name() + "/task:";
for (const auto& tdf : jdf.tasks()) {
const int32_t task_index = tdf.first;
std::vector<string> device_filters(tdf.second.device_filters_size());
@ -782,7 +792,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
}
const string remote_worker = remote_prefix + std::to_string(task_index);
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status =
context->SetRemoteDeviceFilters(remote_worker, device_filters);
}
@ -804,7 +814,7 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
#else // !defined(IS_MOBILE_PLATFORM)
tensorflow::ServerDef server_def;
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
if (!server_def.ParseFromArray(proto, proto_len)) {
status->status = tensorflow::errors::InvalidArgument(
"Invalid tensorflow.ServerDef protocol buffer");
@ -834,7 +844,7 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
return false;
#else // !defined(IS_MOBILE_PLATFORM)
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
tensorflow::GrpcServer* grpc_server =
static_cast<tensorflow::GrpcServer*>(context->GetServer());
@ -890,7 +900,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
status->status = tensorflow::Status::OK();
#else // !defined(IS_MOBILE_PLATFORM)
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->SyncExecutors();
#endif // !IS_MOBILE_PLATFORM
}
@ -898,7 +908,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetThreadLocalDevicePlacementPolicy(
static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
}
@ -909,7 +919,7 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy(
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
return static_cast<TFE_ContextDevicePlacementPolicy>(
context->GetDevicePlacementPolicy());
}
@ -919,8 +929,7 @@ TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
status->status = tensorflow::TF_TensorToTensor(t, &tensor);
if (!status->status.ok()) return nullptr;
return new TFE_TensorHandle{
tensorflow::TensorHandle::CreateLocalHandle(tensor)};
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
}
void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
@ -928,84 +937,84 @@ void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
tensorflow::profiler::TraceMe activity(
"TFE_DeleteTensorHandle", tensorflow::profiler::TraceMeLevel::kInfo);
if (h->handle) {
h->handle->Release();
if (h) {
tensorflow::unwrap(h)->Release();
}
delete h;
}
TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
return static_cast<TF_DataType>(h->handle->DataType());
return static_cast<TF_DataType>(tensorflow::unwrap(h)->DataType());
}
int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return -1;
}
int num_dims = -1;
status->status = h->handle->NumDims(&num_dims);
status->status = tensorflow::unwrap(h)->NumDims(&num_dims);
return num_dims;
}
int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return -1;
}
int64 num_elements = -1;
status->status = h->handle->NumElements(&num_elements);
status->status = tensorflow::unwrap(h)->NumElements(&num_elements);
return num_elements;
}
int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return -1;
}
int64 dim = -1;
status->status = h->handle->Dim(dim_index, &dim);
status->status = tensorflow::unwrap(h)->Dim(dim_index, &dim);
return dim;
}
const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr;
}
return h->handle->DeviceName(&status->status);
return tensorflow::unwrap(h)->DeviceName(&status->status);
}
const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr;
}
return h->handle->BackingDeviceName(&status->status);
return tensorflow::unwrap(h)->BackingDeviceName(&status->status);
}
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr;
}
return new TFE_TensorHandle{h->handle->Copy()};
return tensorflow::wrap(tensorflow::unwrap(h)->Copy());
}
TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr;
}
tensorflow::AbstractTensorInterface* t = h->handle->Resolve(&status->status);
tensorflow::AbstractTensorInterface* t =
tensorflow::unwrap(h)->Resolve(&status->status);
if (t == nullptr) {
return nullptr;
}
@ -1014,22 +1023,22 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
}
void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr;
}
tensorflow::TensorHandle* handle =
tensorflow::TensorHandleFromInterface(h->handle);
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
if (VariantDeviceIsCustom(handle->device())) {
const tensorflow::Tensor* t;
status->status = handle->Tensor(&t);
return t->data();
}
if (handle->IsRemote()) {
if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
status->status = tensorflow::errors::InvalidArgument(
"TFE_TensorHandleDevicePointer may not be called on a remote tensor "
"handle.");
"TFE_TensorHandleDevicePointer may not be called on a ",
handle->TypeString(), " tensor handle.");
return nullptr;
}
tensorflow::Device* device(absl::get<tensorflow::Device*>(handle->device()));
@ -1055,7 +1064,7 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
void* deallocator_arg, TF_Status* status) {
tensorflow::Device* device = nullptr;
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->FindDeviceFromName(device_name, &device);
tensorflow::CustomDevice* custom_device = nullptr;
if (!status->status.ok()) {
@ -1081,11 +1090,11 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
tensorflow::TensorShape(dimvec), buf);
buf->Unref();
if (custom_device == nullptr) {
return new TFE_TensorHandle{tensorflow::TensorHandle::CreateLocalHandle(
std::move(t), device, device, context)};
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(
std::move(t), device, device, context));
} else {
return new TFE_TensorHandle{tensorflow::TensorHandle::CreateLocalHandle(
std::move(t), custom_device, context)};
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(
std::move(t), custom_device, context));
}
}
@ -1094,16 +1103,16 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
// bytes of the memory pointed to by the device pointer returned above.
size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return 0;
}
tensorflow::TensorHandle* handle =
tensorflow::TensorHandleFromInterface(h->handle);
if (handle->IsRemote()) {
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
status->status = tensorflow::errors::InvalidArgument(
"TFE_TensorHandleDeviceMemorySize may not be called on a remote tensor "
"handle.");
"TFE_TensorHandleDeviceMemorySize may not be called on a ",
handle->TypeString(), " tensor handle.");
return 0;
}
const tensorflow::Tensor* tensor;
@ -1116,12 +1125,14 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status) {
std::unique_ptr<TFE_Op> new_op(new TFE_Op{ctx->context->CreateOperation()});
status->status = new_op->operation->Reset(op_or_function_name, nullptr);
tensorflow::AbstractOperationInterface* new_op =
tensorflow::unwrap(ctx)->CreateOperation();
status->status = new_op->Reset(op_or_function_name, nullptr);
if (!status->status.ok()) {
new_op.reset();
new_op->Release();
new_op = nullptr;
}
return new_op.release();
return tensorflow::wrap(new_op);
}
void TFE_DeleteOp(TFE_Op* op) {
@ -1129,24 +1140,20 @@ void TFE_DeleteOp(TFE_Op* op) {
return;
}
if (op->operation) {
op->operation->Release();
}
delete op;
tensorflow::unwrap(op)->Release();
}
void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
status->status = op->operation->SetDeviceName(device_name);
status->status = tensorflow::unwrap(op)->SetDeviceName(device_name);
}
const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
return op->operation->DeviceName().c_str();
return tensorflow::unwrap(op)->DeviceName().c_str();
}
void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
#ifdef TENSORFLOW_EAGER_USE_XLA
tensorflow::Status s = op->operation->SetUseXla(enable);
tensorflow::Status s = tensorflow::unwrap(op)->SetUseXla(enable);
if (!s.ok()) {
LOG(ERROR) << "Could not enable XLA compilation for op: " << s;
}
@ -1157,18 +1164,13 @@ void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
}
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
status->status = op->operation->AddInput(input->handle);
status->status = tensorflow::unwrap(op)->AddInput(tensorflow::unwrap(input));
}
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
TF_Status* status) {
absl::FixedArray<tensorflow::AbstractTensorHandleInterface*> handles(
num_inputs);
for (int i = 0; i < num_inputs; ++i) {
handles[i] = inputs[i]->handle;
}
status->status =
op->operation->AddInputList({handles.data(), handles.size()});
status->status = tensorflow::unwrap(op)->AddInputList(
{tensorflow::unwrap(inputs), static_cast<size_t>(num_inputs)});
}
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
@ -1176,8 +1178,8 @@ TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
TF_AttrType ret = TF_ATTR_INT;
const tensorflow::AttrTypeMap* attr_types_;
bool is_function;
status->status = tensorflow::AttrTypeMapForOp(op->operation->Name().c_str(),
&attr_types_, &is_function);
status->status = tensorflow::AttrTypeMapForOp(
tensorflow::unwrap(op)->Name().c_str(), &attr_types_, &is_function);
if (!status->status.ok()) {
return ret;
}
@ -1203,7 +1205,7 @@ TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
size_t length) {
auto s = op->operation->SetAttrString(
auto s = tensorflow::unwrap(op)->SetAttrString(
attr_name, static_cast<const char*>(value), length);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
@ -1211,29 +1213,30 @@ void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
}
void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
auto s = op->operation->SetAttrInt(attr_name, value);
auto s = tensorflow::unwrap(op)->SetAttrInt(attr_name, value);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
auto s = op->operation->SetAttrFloat(attr_name, value);
auto s = tensorflow::unwrap(op)->SetAttrFloat(attr_name, value);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
auto s = op->operation->SetAttrBool(attr_name, (value == 0) ? false : true);
auto s = tensorflow::unwrap(op)->SetAttrBool(attr_name,
(value == 0) ? false : true);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
auto s = op->operation->SetAttrType(attr_name,
static_cast<tensorflow::DataType>(value));
auto s = tensorflow::unwrap(op)->SetAttrType(
attr_name, static_cast<tensorflow::DataType>(value));
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
@ -1241,12 +1244,14 @@ void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
const int num_dims, TF_Status* out_status) {
out_status->status = op->operation->SetAttrShape(attr_name, dims, num_dims);
out_status->status =
tensorflow::unwrap(op)->SetAttrShape(attr_name, dims, num_dims);
}
void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
const TFE_Op* value) {
auto s = op->operation->SetAttrFunction(attr_name, value->operation);
auto s = tensorflow::unwrap(op)->SetAttrFunction(
attr_name, tensorflow::unwrap(const_cast<TFE_Op*>(value)));
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
@ -1254,7 +1259,7 @@ void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
const char* data, size_t length) {
auto s = op->operation->SetAttrFunctionName(attr_name, data, length);
auto s = tensorflow::unwrap(op)->SetAttrFunctionName(attr_name, data, length);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
@ -1265,14 +1270,14 @@ void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
tensorflow::Tensor t;
status->status = TF_TensorToTensor(tensor, &t);
tensorflow::TensorInterface interface(t);
status->status = op->operation->SetAttrTensor(attr_name, &interface);
status->status = tensorflow::unwrap(op)->SetAttrTensor(attr_name, &interface);
}
void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
const void* const* values, const size_t* lengths,
int num_values) {
auto s =
op->operation->SetAttrStringList(attr_name, values, lengths, num_values);
auto s = tensorflow::unwrap(op)->SetAttrStringList(attr_name, values, lengths,
num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
@ -1280,7 +1285,8 @@ void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
const float* values, int num_values) {
auto s = op->operation->SetAttrFloatList(attr_name, values, num_values);
auto s =
tensorflow::unwrap(op)->SetAttrFloatList(attr_name, values, num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
@ -1288,7 +1294,8 @@ void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
const int64_t* values, int num_values) {
auto s = op->operation->SetAttrIntList(attr_name, values, num_values);
auto s =
tensorflow::unwrap(op)->SetAttrIntList(attr_name, values, num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
@ -1296,7 +1303,7 @@ void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
const TF_DataType* values, int num_values) {
auto s = op->operation->SetAttrTypeList(
auto s = tensorflow::unwrap(op)->SetAttrTypeList(
attr_name, reinterpret_cast<const tensorflow::DataType*>(values),
num_values);
if (!s.ok()) {
@ -1306,7 +1313,8 @@ void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
const unsigned char* values, int num_values) {
auto s = op->operation->SetAttrBoolList(attr_name, values, num_values);
auto s =
tensorflow::unwrap(op)->SetAttrBoolList(attr_name, values, num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
@ -1315,19 +1323,14 @@ void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
const int64_t** dims, const int* num_dims,
int num_values, TF_Status* out_status) {
out_status->status =
op->operation->SetAttrShapeList(attr_name, dims, num_dims, num_values);
out_status->status = tensorflow::unwrap(op)->SetAttrShapeList(
attr_name, dims, num_dims, num_values);
}
void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
const TFE_Op** value, int num_values) {
absl::FixedArray<const tensorflow::AbstractOperationInterface*> values(
num_values);
for (int i = 0; i < num_values; ++i) {
values[i] = value[i]->operation;
}
auto s = op->operation->SetAttrFunctionList(attr_name,
{values.data(), values.size()});
auto s = tensorflow::unwrap(op)->SetAttrFunctionList(
attr_name, {tensorflow::unwrap(value), static_cast<size_t>(num_values)});
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
@ -1342,12 +1345,13 @@ void TFE_OpSetAttrValueProto(const TFE_Op* op, const char* attr_name,
tensorflow::errors::InvalidArgument("Unparseable AttrValue proto");
return;
}
if (op == nullptr || op->operation == nullptr) {
if (op == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"Got a null or uninitialized `op` argument");
return;
}
tensorflow::EagerOperation* operation = OperationFromInterface(op->operation);
tensorflow::EagerOperation* operation =
OperationFromInterface(tensorflow::unwrap(const_cast<TFE_Op*>(op)));
operation->MutableAttrs()->Set(attr_name, attr_value);
}
@ -1355,7 +1359,7 @@ TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
const char* input_name,
TF_Status* status) {
int ret = -1;
status->status = op->operation->InputLength(input_name, &ret);
status->status = tensorflow::unwrap(op)->InputLength(input_name, &ret);
return ret;
}
@ -1363,71 +1367,29 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
const char* output_name,
TF_Status* status) {
int ret = -1;
status->status = op->operation->OutputLength(output_name, &ret);
status->status = tensorflow::unwrap(op)->OutputLength(output_name, &ret);
return ret;
}
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
absl::FixedArray<tensorflow::AbstractTensorHandleInterface*> handles(
*num_retvals);
status->status = op->operation->Execute(absl::MakeSpan(handles), num_retvals);
if (!status->status.ok()) {
return;
}
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = new TFE_TensorHandle{handles[i]};
}
status->status = tensorflow::unwrap(op)->Execute(
absl::MakeSpan(tensorflow::unwrap(retvals), *num_retvals), num_retvals);
}
TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
TFE_Context* ctx,
const char* device_name,
TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr;
}
tensorflow::TensorHandle* handle = nullptr;
tensorflow::Device* device;
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
status->status = context->FindDeviceFromName(device_name, &device);
if (!status->status.ok()) {
tensorflow::CustomDevice* dev;
status->status = context->FindCustomDeviceFromName(device_name, &dev);
if (status->status.ok()) {
status->status = dev->CopyTensorToDevice(
tensorflow::TensorHandleFromInterface(h->handle), &handle);
if (status->status.ok()) {
return new TFE_TensorHandle{handle};
}
}
return nullptr;
}
// Handle tensor handles currently in custom devices
const char* handle_device_name = h->handle->DeviceName(&status->status);
if (!status->status.ok()) {
return nullptr;
}
tensorflow::CustomDevice* dev;
status->status = context->FindCustomDeviceFromName(handle_device_name, &dev);
auto* result = tensorflow::unwrap(ctx)->CopyTensorHandleToDevice(
tensorflow::unwrap(h), device_name, &status->status);
if (status->status.ok()) {
status->status = dev->CopyTensorFromDevice(
tensorflow::TensorHandleFromInterface(h->handle), device_name, &handle);
if (status->status.ok()) {
return new TFE_TensorHandle{handle};
}
return nullptr;
}
// Handle regular case.
status->status = tensorflow::EagerCopyToDevice(
tensorflow::TensorHandleFromInterface(h->handle), context,
&context->Executor(), device, false, &handle);
if (status->status.ok()) {
return new TFE_TensorHandle{handle};
return tensorflow::wrap(result);
}
return nullptr;
}
@ -1442,39 +1404,39 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
return;
}
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->AddFunctionDef(function_def);
}
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->AddFunctionDef(function->fdef);
}
void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->RemoveFunction(name);
}
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
return context->FindFunctionDef(name) != nullptr;
}
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetShouldStoreGraphs(true);
}
void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetShouldStoreGraphs(false);
}
@ -1482,13 +1444,13 @@ void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
TF_Status* status) {
return new TFE_TensorHandle{tensorflow::TensorHandle::CreateLocalHandle(t)};
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(t));
}
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->Executor().WaitForAllPendingNodes();
if (!status->status.ok()) return;
tensorflow::mutex_lock ml(*context->MetadataMu());
@ -1510,26 +1472,23 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
} // namespace
void TFE_ContextStartStep(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->StartStep();
tensorflow::unwrap(ctx)->StartStep();
}
void TFE_ContextEndStep(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->EndStep();
tensorflow::unwrap(ctx)->EndStep();
}
void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs) {
tensorflow::EagerOperation* operation = OperationFromInterface(op->operation);
*attrs = TFE_OpAttrs(&operation->Attrs(), operation->Name().c_str());
const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op) {
return tensorflow::wrap(
&OperationFromInterface(tensorflow::unwrap(op))->Attrs());
}
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
tensorflow::AttrValueMap m;
attrs->attributes->FillAttrValueMap(&m);
tensorflow::EagerOperation* operation = OperationFromInterface(op->operation);
tensorflow::unwrap(attrs)->FillAttrValueMap(&m);
tensorflow::EagerOperation* operation =
OperationFromInterface(tensorflow::unwrap(op));
tensorflow::AttrBuilder* destination = operation->MutableAttrs();
for (const auto& attribute : m) {
destination->Set(attribute.first, attribute.second);
@ -1539,8 +1498,8 @@ void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,
TF_Status* status) {
tensorflow::NameAttrList name_and_attrs;
attrs->attributes->FillAttrValueMap(name_and_attrs.mutable_attr());
name_and_attrs.set_name(attrs->name);
tensorflow::unwrap(attrs)->FillAttrValueMap(name_and_attrs.mutable_attr());
name_and_attrs.set_name(tensorflow::unwrap(attrs)->op_name());
status->status = MessageToBuffer(name_and_attrs, buf);
}
@ -1617,33 +1576,34 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
const string& name() override { return name_; }
tensorflow::Status CopyTensorToDevice(
tensorflow::TensorHandle* tensor,
tensorflow::TensorHandle* handle,
tensorflow::TensorHandle** result) override {
tensor->Ref();
TFE_TensorHandle tensor_handle{tensor};
handle->Ref();
TF_Status status;
TFE_TensorHandle* result_handle =
device_.copy_tensor_to_device(context_, &tensor_handle, &status, info_);
tensor_handle.handle->Release();
TFE_TensorHandle* result_handle = device_.copy_tensor_to_device(
context_, tensorflow::wrap(handle), &status, info_);
handle->Release();
if (!status.status.ok()) return status.status;
*result = tensorflow::TensorHandleFromInterface(result_handle->handle);
*result = tensorflow::TensorHandleFromInterface(
tensorflow::unwrap(result_handle));
(*result)->Ref();
TFE_DeleteTensorHandle(result_handle);
return status.status;
}
tensorflow::Status CopyTensorFromDevice(
tensorflow::TensorHandle* tensor,
tensorflow::TensorHandle* handle,
const tensorflow::string& target_device_name,
tensorflow::TensorHandle** result) override {
TF_Status status;
tensor->Ref();
TFE_TensorHandle tensor_handle{tensor};
handle->Ref();
TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
context_, &tensor_handle, target_device_name.c_str(), &status, info_);
tensor_handle.handle->Release();
context_, tensorflow::wrap(handle), target_device_name.c_str(), &status,
info_);
handle->Release();
if (!status.status.ok()) return status.status;
*result = tensorflow::TensorHandleFromInterface(result_handle->handle);
*result = tensorflow::TensorHandleFromInterface(
tensorflow::unwrap(result_handle));
(*result)->Ref();
TFE_DeleteTensorHandle(result_handle);
return status.status;
@ -1656,16 +1616,17 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
inputs.reserve(op->Inputs().size());
for (int i = 0; i < op->Inputs().size(); ++i) {
op->Inputs()[i]->Ref();
inputs.push_back(new TFE_TensorHandle{op->Inputs()[i]});
inputs.push_back(tensorflow::wrap(op->Inputs()[i]));
}
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
TF_Status status;
TFE_OpAttrs attributes(&op->Attrs(), op->Name().c_str());
device_.execute(context_, inputs.size(), inputs.data(), op->Name().c_str(),
&attributes, num_retvals, outputs.data(), &status, info_);
wrap(&op->Attrs()), num_retvals, outputs.data(), &status,
info_);
if (status.status.ok()) {
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = tensorflow::TensorHandleFromInterface(outputs[i]->handle);
retvals[i] = tensorflow::TensorHandleFromInterface(
tensorflow::unwrap(outputs[i]));
retvals[i]->Ref();
TFE_DeleteTensorHandle(outputs[i]);
}
@ -1693,7 +1654,7 @@ void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
auto custom_device =
std::make_unique<CustomDeviceAPI>(ctx, device, device_info, device_name);
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status =
context->RegisterCustomDevice(device_name, std::move(custom_device));
}


@ -0,0 +1,313 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
namespace {
using ::tensorflow::string;
tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
tensorflow::ServerDef server_def;
server_def.set_protocol("grpc");
server_def.set_job_name(job_name);
server_def.set_task_index(0);
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
tensorflow::JobDef* job_def = cluster_def->add_job();
job_def->set_name(job_name);
for (int i = 0; i < num_tasks; i++) {
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->insert(
{i, tensorflow::strings::StrCat("localhost:", port)});
}
return server_def;
}
tensorflow::ServerDef GetServerDef(int num_tasks) {
return GetServerDef("localhost", num_tasks);
}
void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
const std::vector<float>& expected_values) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_Tensor* t = TFE_TensorHandleResolve(handle, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
std::unique_ptr<float[]> actual_values(new float[expected_values.size()]);
EXPECT_EQ(sizeof(float) * expected_values.size(), TF_TensorByteSize(t));
memcpy(actual_values.get(), TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
for (int i = 0; i < expected_values.size(); i++) {
EXPECT_EQ(expected_values[i], actual_values[i])
<< "Mismatch in expected values at (zero-based) index " << i;
}
}
void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
const char* remote_device_name,
const char* local_device_name) {
TF_Status* status = TF_NewStatus();
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h0_task0);
TFE_OpSetDevice(matmul, remote_device_name, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* retval_task0 =
TFE_TensorHandleCopyToDevice(retvals[0], ctx, local_device_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
CheckTFE_TensorHandleHasFloats(retval_task0, {7, 10, 15, 22});
TFE_DeleteTensorHandle(retval_task0);
TFE_DeleteTensorHandle(h0_task0);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteOp(matmul);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteExecutor(executor);
TF_DeleteStatus(status);
}
void TestRemoteExecuteChangeServerDef(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const char remote_device_name[] =
"/job:localhost/replica:0/task:1/device:CPU:0";
const char local_device_name[] =
"/job:localhost/replica:0/task:0/device:CPU:0";
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
// Update the server def with a new set of names (worker instead of
// localhost).
tensorflow::ServerDef updated_server_def = GetServerDef("worker", 2);
serialized = updated_server_def.SerializeAsString();
updated_server_def.set_task_index(1);
tensorflow::Status s = tensorflow::GrpcServer::Create(
updated_server_def, tensorflow::Env::Default(), &worker_server);
ASSERT_TRUE(s.ok()) << s.error_message();
ASSERT_TRUE(worker_server->Start().ok());
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Create a new tensor_handle.
TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle(ctx);
// Check that copying it to the old remote device (named localhost) fails.
TFE_TensorHandleCopyToDevice(h0_task0_new, ctx, remote_device_name, status);
EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Copying and executing on the new remote device works.
const char new_remote_device_name[] =
"/job:worker/replica:0/task:1/device:CPU:0";
const char new_local_device_name[] =
"/job:worker/replica:0/task:0/device:CPU:0";
auto* h0_task1_new = TFE_TensorHandleCopyToDevice(
h0_task0_new, ctx, new_remote_device_name, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(h0_task0_new);
TFE_DeleteTensorHandle(h0_task1_new);
CheckRemoteMatMulExecutesOK(ctx, new_remote_device_name,
new_local_device_name);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteExecutor(executor);
TF_DeleteStatus(status);
TFE_DeleteContext(ctx);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
}
TEST(CAPI, RemoteExecuteChangeServerDef) {
TestRemoteExecuteChangeServerDef(false);
}
TEST(CAPI, RemoteExecuteChangeServerDefAsync) {
TestRemoteExecuteChangeServerDef(true);
}
void TestRemoteExecuteUpdateServerDef(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const char local_device_name[] =
"/job:localhost/replica:0/task:0/device:CPU:0";
const char remote_device_name[] =
"/job:localhost/replica:0/task:1/device:CPU:0";
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
TFE_ContextUpdateServerDef(ctx, 0, serialized.data(), serialized.size(),
status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
}
TEST(CAPI, RemoteExecuteUpdateServerDef) {
TestRemoteExecuteUpdateServerDef(false);
}
TEST(CAPI, RemoteExecuteUpdateServerDefAsync) {
TestRemoteExecuteUpdateServerDef(true);
}
void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
// Fail fast on GetStatus requests so we get errors instead of timeouts
// when updating the cluster with a non-existent worker.
tensorflow::setenv("GRPC_FAIL_FAST", "TRUE", /*overwrite=*/1);
tensorflow::ServerDef server_def = GetServerDef(2);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const char local_device_name[] =
"/job:localhost/replica:0/task:0/device:CPU:0";
const char remote_device_name[] =
"/job:localhost/replica:0/task:1/device:CPU:0";
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
// Adding a non-existent remote worker to the cluster def. This should cause
// the UpdateServerDef call to fail.
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
tensorflow::JobDef* job_def = cluster_def->mutable_job(0);
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->insert(
{2, tensorflow::strings::StrCat("localhost:", port)});
string serialized_update = server_def.SerializeAsString();
TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
serialized_update.size(), status);
EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Even after the previously failed cluster update, another update and op
// execution should work fine as long as the provided server_def is valid.
TFE_ContextUpdateServerDef(ctx, 0, serialized.data(), serialized.size(),
status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
tensorflow::unsetenv("GRPC_FAIL_FAST");
}
TEST(CAPI, RemoteExecuteUpdateServerDefWithFailures) {
TestRemoteExecuteUpdateServerDefWithFailures(false);
}
TEST(CAPI, RemoteExecuteUpdateServerDefWithFailuresAsync) {
TestRemoteExecuteUpdateServerDefWithFailures(true);
}
} // namespace


@ -17,8 +17,11 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tfe_tensor_debug_info_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/platform/status.h"
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/jit/xla_device.h"
#endif // TENSORFLOW_EAGER_USE_XLA
@ -54,7 +57,8 @@ extern "C" {
TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
TFE_TensorHandle* h, TF_Status* status) {
tensorflow::TensorHandle* handle = TensorHandleFromInterface(h->handle);
tensorflow::TensorHandle* handle =
TensorHandleFromInterface(tensorflow::unwrap(h));
const tensorflow::Tensor* tensor;
status->status = handle->Tensor(&tensor);
if (!status->status.ok()) {


@ -19,6 +19,9 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
@ -34,9 +37,10 @@ using tensorflow::string;
void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status) {
if (op_to_reset) {
op_to_reset->operation->Clear();
status->status =
op_to_reset->operation->Reset(op_or_function_name, raw_device_name);
tensorflow::AbstractOperationInterface* op =
tensorflow::unwrap(op_to_reset);
op->Clear();
status->status = op->Reset(op_or_function_name, raw_device_name);
} else {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"op_to_reset should not be nullptr");
@ -45,13 +49,13 @@ void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetShouldStoreGraphs(true);
}
void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetShouldStoreGraphs(false);
}
@ -483,7 +487,7 @@ void TFE_ContextOptionsSetMirroringPolicy(TFE_ContextOptions* options,
void TFE_ContextSetThreadLocalMirroringPolicy(
TFE_Context* ctx, TFE_ContextMirroringPolicy policy) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetThreadLocalMirroringPolicy(
static_cast<tensorflow::ContextMirroringPolicy>(policy));
}
@ -494,7 +498,7 @@ void TFE_ContextSetThreadLocalMirroringPolicy(
extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
return static_cast<TFE_ContextMirroringPolicy>(context->GetMirroringPolicy());
}
@ -530,7 +534,7 @@ void TFE_OpSetCancellationManager(TFE_Op* op,
TFE_CancellationManager* cancellation_manager,
TF_Status* status) {
tensorflow::EagerOperation* operation =
tensorflow::OperationFromInterface(op->operation);
tensorflow::OperationFromInterface(tensorflow::unwrap(op));
operation->SetCancellationManager(
&cancellation_manager->cancellation_manager);
status->status = tensorflow::Status::OK();
@ -557,19 +561,19 @@ void TFE_ExecutorClearError(TFE_Executor* executor) {
void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetExecutorForThread(executor->executor());
}
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
return new TFE_Executor(&context->Executor());
}
void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
context->HostCPU()->parsed_name());
auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
@ -585,7 +589,7 @@ void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
TF_Buffer* buf, TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
auto* function_def = context->FindFunctionDef(function_name);
if (function_def == nullptr) {
status->status = tensorflow::errors::NotFound(
@ -611,13 +615,14 @@ TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_DataType dtype,
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
}
if (ctx == nullptr || ctx->context == nullptr) {
if (ctx == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid Context");
return nullptr;
}
tensorflow::AbstractTensorInterface* t = ctx->context->CreateTensor(
static_cast<tensorflow::DataType>(dtype), dimvec);
tensorflow::AbstractTensorInterface* t =
tensorflow::unwrap(ctx)->CreateTensor(
static_cast<tensorflow::DataType>(dtype), dimvec);
if (t == nullptr) {
status->status =
@ -630,5 +635,6 @@ TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_DataType dtype,
TFE_TensorHandle* TFE_NewTensorHandleFromTensor(TFE_Context* ctx, TF_Tensor* t,
TF_Status* status) {
return new TFE_TensorHandle{ctx->context->CreateLocalHandle(t->tensor)};
return tensorflow::wrap(
tensorflow::unwrap(ctx)->CreateLocalHandle(t->tensor));
}


@ -431,11 +431,9 @@ TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
// A reference to an op's name -> attribute mapping
typedef struct TFE_OpAttrs TFE_OpAttrs;
// Fetch a struct with a reference to information about attributes of `op`.
//
// The `attrs` struct does not own any memory, and `op` must outlive it.
TF_CAPI_EXPORT extern void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs);
// Fetch a reference to `op`'s attributes. The returned reference is only valid
// while `op` is alive.
const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op);
// Add attributes in `attrs` to `op`.
//
// Does not overwrite or update existing attributes, but adds new ones.
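To make the borrowing contract above concrete, here is a minimal usage sketch. It mirrors the accompanying unit test; the function name is illustrative and only a valid TFE_Context* and TF_Status* are assumed.

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"

// Copies the attributes of one op onto a freshly created op of the same type.
void CopyVarHandleAttrs(TFE_Context* ctx, TF_Status* status) {
  TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
  TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
  TFE_OpSetAttrShape(var_op, "shape", /*dims=*/nullptr, /*num_dims=*/0, status);

  // Borrowed reference: valid only while `var_op` is alive; never freed here.
  const TFE_OpAttrs* attributes = TFE_OpGetAttrs(var_op);

  TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status);
  TFE_OpSetAttrType(copy_op, "dtype", TF_FLOAT);
  // Adds attributes from `attributes`; the already-set "dtype" is not overwritten.
  TFE_OpAddAttrs(copy_op, attributes);

  TFE_DeleteOp(copy_op);
  TFE_DeleteOp(var_op);  // `attributes` must not be used past this point.
}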


@ -15,39 +15,17 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
#define TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
#include <algorithm>
#include <cstddef>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <vector>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/context_interface.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/c/eager/tfe_cancellation_manager_internal.h" // IWYU pragma: export
#include "tensorflow/c/eager/tfe_executor_internal.h" // IWYU pragma: export
#include "tensorflow/c/eager/tfe_monitoring_internal.h" // IWYU pragma: export
#include "tensorflow/c/eager/tfe_op_attrs_internal.h" // IWYU pragma: export
#include "tensorflow/c/eager/tfe_tensor_debug_info_internal.h" // IWYU pragma: export
// TODO(b/154564140): Move this to its own header. This requires splitting
// c_api_experimental.h
struct TFE_ContextOptions {
TF_SessionOptions session_options;
// true if async execution is enabled.
@ -61,199 +39,4 @@ struct TFE_ContextOptions {
bool use_tfrt = false;
};
// Wraps a pointer to a context implementation.
//
// WARNING: Since the underlying object could be ref-counted, a user of this
// interface must not destruct the underlying context object. Instead, call
// TFE_DeleteContext, which calls Release() on the context pointer and deletes
// the TFE_Context structure.
struct TFE_Context {
tensorflow::AbstractContextInterface* context;
};
// Wraps a pointer to a tensor handle implementation.
//
// WARNING: Since the underlying object could be ref-counted, a user of this
// interface must not destruct the underlying handle object. Instead, call
// TFE_DeleteTensorHandle, which calls Release() on the handle pointer and
// deletes the TFE_TensorHandle structure.
struct TFE_TensorHandle {
tensorflow::AbstractTensorHandleInterface* handle;
};
struct TFE_TensorDebugInfo {
explicit TFE_TensorDebugInfo(const std::vector<tensorflow::int64>& dims)
: dev_dims(dims) {}
// Fully-padded, minor-to-major.
std::vector<tensorflow::int64> dev_dims;
};
// Wraps a pointer to an operation implementation.
//
// WARNING: Since the underlying object could be ref-counted, a user of this
// interface must not destruct the underlying operation object. Instead, call
// TFE_DeleteOp, which calls Release() on the operation pointer and deletes
// the TFE_Op structure.
struct TFE_Op {
tensorflow::AbstractOperationInterface* operation;
};
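The three wrapper structs above share the same lifetime rule: release the wrapped, possibly ref-counted object through the corresponding TFE_Delete* call, never by deleting the interface pointer directly. A minimal illustrative sketch (not part of the original header; assumes a valid context and status, and uses "NoOp" simply because it takes no attributes):

void CreateAndReleaseOp(TFE_Context* ctx, TF_Status* status) {
  TFE_Op* op = TFE_NewOp(ctx, "NoOp", status);
  if (TF_GetCode(status) != TF_OK) return;
  // Correct: TFE_DeleteOp calls Release() on the wrapped operation object.
  TFE_DeleteOp(op);
  // Incorrect (per the warning above): `delete op->operation;` would bypass the
  // Release() call on the possibly ref-counted implementation.
}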
struct TFE_MonitoringCounterCell {
tensorflow::monitoring::CounterCell cell;
};
template <int NumLabels>
struct TFE_MonitoringCounter {
template <typename... LabelDesc>
TFE_MonitoringCounter(const char* name, const char* description,
LabelDesc&&... label) {
counter = absl::WrapUnique(tensorflow::monitoring::Counter<NumLabels>::New(
name, description, label...));
}
std::unique_ptr<tensorflow::monitoring::Counter<NumLabels>> counter;
};
struct TFE_MonitoringCounter0 : TFE_MonitoringCounter<0> {
using TFE_MonitoringCounter::TFE_MonitoringCounter;
};
struct TFE_MonitoringCounter1 : TFE_MonitoringCounter<1> {
using TFE_MonitoringCounter::TFE_MonitoringCounter;
};
struct TFE_MonitoringCounter2 : TFE_MonitoringCounter<2> {
using TFE_MonitoringCounter::TFE_MonitoringCounter;
};
struct TFE_MonitoringIntGaugeCell {
tensorflow::monitoring::GaugeCell<tensorflow::int64> cell;
};
struct TFE_MonitoringStringGaugeCell {
tensorflow::monitoring::GaugeCell<tensorflow::string> cell;
};
struct TFE_MonitoringBoolGaugeCell {
tensorflow::monitoring::GaugeCell<bool> cell;
};
template <typename ValueType, int NumLabels>
struct TFE_MonitoringGauge {
template <typename... LabelDesc>
TFE_MonitoringGauge(const char* name, const char* description,
LabelDesc&&... label) {
gauge = absl::WrapUnique(
tensorflow::monitoring::Gauge<ValueType, NumLabels>::New(
name, description, label...));
}
std::unique_ptr<tensorflow::monitoring::Gauge<ValueType, NumLabels>> gauge;
};
struct TFE_MonitoringIntGauge0 : TFE_MonitoringGauge<tensorflow::int64, 0> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringIntGauge1 : TFE_MonitoringGauge<tensorflow::int64, 1> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringIntGauge2 : TFE_MonitoringGauge<tensorflow::int64, 2> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringStringGauge0 : TFE_MonitoringGauge<tensorflow::string, 0> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringStringGauge1 : TFE_MonitoringGauge<tensorflow::string, 1> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringStringGauge2 : TFE_MonitoringGauge<tensorflow::string, 2> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringBoolGauge0 : TFE_MonitoringGauge<bool, 0> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringBoolGauge1 : TFE_MonitoringGauge<bool, 1> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringBoolGauge2 : TFE_MonitoringGauge<bool, 2> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringBuckets {
explicit TFE_MonitoringBuckets(
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
fn) {
create_buckets = fn;
}
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
create_buckets;
};
struct TFE_MonitoringSamplerCell {
tensorflow::monitoring::SamplerCell cell;
};
template <int NumLabels>
struct TFE_MonitoringSampler {
template <typename... LabelDesc>
TFE_MonitoringSampler(
const char* name,
std::unique_ptr<tensorflow::monitoring::Buckets> buckets,
const char* description, LabelDesc&&... label) {
sampler = absl::WrapUnique(tensorflow::monitoring::Sampler<NumLabels>::New(
{name, description, label...}, std::move(buckets)));
}
std::unique_ptr<tensorflow::monitoring::Sampler<NumLabels>> sampler;
};
struct TFE_MonitoringSampler0 : TFE_MonitoringSampler<0> {
using TFE_MonitoringSampler::TFE_MonitoringSampler;
};
struct TFE_MonitoringSampler1 : TFE_MonitoringSampler<1> {
using TFE_MonitoringSampler::TFE_MonitoringSampler;
};
struct TFE_MonitoringSampler2 : TFE_MonitoringSampler<2> {
using TFE_MonitoringSampler::TFE_MonitoringSampler;
};
namespace tensorflow {
// Set an AttrValue on the op. Doesn't handle the list types.
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
const tensorflow::AttrValue& default_value,
const char* attr_name, TF_Status* status);
} // namespace tensorflow
struct TFE_CancellationManager {
tensorflow::CancellationManager cancellation_manager;
};
struct TFE_Executor {
explicit TFE_Executor(bool async)
: owned_executor(new tensorflow::EagerExecutor(async)) {}
explicit TFE_Executor(tensorflow::EagerExecutor* executor)
: owned_executor(nullptr), unowned_executor(executor) {}
tensorflow::EagerExecutor* executor() {
return owned_executor == nullptr ? unowned_executor : owned_executor.get();
}
std::unique_ptr<tensorflow::EagerExecutor> owned_executor;
tensorflow::EagerExecutor* unowned_executor;
};
// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways
// that sometimes do not require serialization.
struct TFE_OpAttrs {
explicit TFE_OpAttrs() : name(nullptr), attributes(nullptr) {}
explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value,
const char* op_name)
: name(op_name), attributes(value) {}
const char* name;
const tensorflow::AttrBuilder* attributes;
};
#endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_


@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/platform/casts.h"
@ -167,7 +168,11 @@ string MatMulFunction() {
return def.SerializeAsString();
}
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) {
// If heavy_load_on_streaming_rpc is true, send some RPC requests before the
// one that creates a remote input, to simulate a scenario in which the remote
// input is not ready when we start running an op or a function.
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
bool heavy_load_on_streaming_rpc) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
@ -192,48 +197,64 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) {
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx);
std::vector<TFE_TensorHandle*> handles_task0;
if (heavy_load_on_streaming_rpc) {
// Send 50 tensor copy requests to simulate that some RPC requests have
// already been enqueued.
for (int i = 0; i < 50; ++i) {
handles_task0.push_back(TestMatrixTensorHandle(ctx));
}
}
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
std::vector<TFE_TensorHandle*> handles_task2;
for (auto* h_task0 : handles_task0) {
handles_task2.push_back(
TFE_TensorHandleCopyToDevice(h_task0, ctx, task2_name, status));
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
}
auto* h1_task2 =
TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* matmul = nullptr;
if (func) {
string function_def = MatMulFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
matmul = TFE_NewOp(ctx, "MatMulFunction", status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(matmul, h0_task0, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(matmul, h1_task2, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
} else {
// Handles are on task0 (local) and task2, but the op is on task1.
matmul = MatMulOp(ctx, h0_task0, h1_task2);
}
if (remote) {
TFE_OpSetDevice(matmul, task1_name, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
} else if (!async) {
// Set the local device to CPU to easily validate mirroring
string cpu_device_name;
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
auto remote_arg =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
// The input handles should never change since they have been mirrored.
ASSERT_FALSE(remote_arg->HasLocalMirror(nullptr));
}
@ -241,21 +262,22 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) {
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
// TODO(gjn): Add support for waiting on async local mirrors
if (!remote && !async) {
auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle);
auto remote_arg =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
// The input handles should never change since they have been mirrored.
ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr));
}
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteTensorHandle(retval_task0);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
@ -270,12 +292,18 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) {
TFE_DeleteTensorHandle(h1_task0);
TFE_DeleteTensorHandle(h1_task2);
TFE_DeleteTensorHandle(retvals[0]);
for (auto* h : handles_task0) {
TFE_DeleteTensorHandle(h);
}
for (auto* h : handles_task2) {
TFE_DeleteTensorHandle(h);
}
TFE_DeleteOp(matmul);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteExecutor(executor);
if (func) {
TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
@ -290,22 +318,37 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) {
}
TEST(CAPI, RemoteExecuteSilentCopies) {
TestRemoteExecuteSilentCopies(false, true, false);
TestRemoteExecuteSilentCopies(/*async=*/false, /*remote=*/true,
/*func=*/false,
/*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
TestRemoteExecuteSilentCopies(true, true, false);
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/true, /*func=*/false,
/*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) {
TestRemoteExecuteSilentCopies(true, true, true);
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/true, /*func=*/true,
/*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocal) {
TestRemoteExecuteSilentCopies(false, false, false);
TestRemoteExecuteSilentCopies(/*async=*/false, /*remote=*/false,
/*func=*/false,
/*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) {
TestRemoteExecuteSilentCopies(true, false, false);
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false,
/*func=*/false,
/*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) {
TestRemoteExecuteSilentCopies(true, false, true);
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, /*func=*/true,
/*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) {
// A remote input may not be ready when we start running a function. Test that
// the function execution waits until the remote input is ready.
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, /*func=*/true,
/*heavy_load_on_streaming_rpc=*/true);
}
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
@ -378,150 +421,4 @@ TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPC) {
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPCAsync) {
TestRemoteExecuteDeleteContextWithOutstandingRPC(true);
}
void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
const std::vector<float>& expected_values) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_Tensor* t = TFE_TensorHandleResolve(handle, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
std::unique_ptr<float[]> actual_values(new float[expected_values.size()]);
EXPECT_EQ(sizeof(float) * expected_values.size(), TF_TensorByteSize(t));
memcpy(actual_values.get(), TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
for (int i = 0; i < expected_values.size(); i++) {
EXPECT_EQ(expected_values[i], actual_values[i])
<< "Mismatch in expected values at (zero-based) index " << i;
}
}
void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
const char* remote_device_name,
const char* local_device_name) {
TF_Status* status = TF_NewStatus();
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h0_task0);
TFE_OpSetDevice(matmul, remote_device_name, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* retval_task0 =
TFE_TensorHandleCopyToDevice(retvals[0], ctx, local_device_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
CheckTFE_TensorHandleHasFloats(retval_task0, {7, 10, 15, 22});
TFE_DeleteTensorHandle(retval_task0);
TFE_DeleteTensorHandle(h0_task0);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteOp(matmul);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteExecutor(executor);
TF_DeleteStatus(status);
}
void TestRemoteExecuteChangeServerDef(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const char remote_device_name[] =
"/job:localhost/replica:0/task:1/device:CPU:0";
const char local_device_name[] =
"/job:localhost/replica:0/task:0/device:CPU:0";
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
// Update the server def with a new set of names (worker instead of
// localhost).
tensorflow::ServerDef updated_server_def = GetServerDef("worker", 2);
serialized = updated_server_def.SerializeAsString();
updated_server_def.set_task_index(1);
tensorflow::Status s = tensorflow::GrpcServer::Create(
updated_server_def, tensorflow::Env::Default(), &worker_server);
ASSERT_TRUE(s.ok()) << s.error_message();
ASSERT_TRUE(worker_server->Start().ok());
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Create a new tensor_handle.
TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle(ctx);
// Check that copying it to the old remote device (named localhost) fails.
TFE_TensorHandleCopyToDevice(h0_task0_new, ctx, remote_device_name, status);
EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Copying and executing on the new remote device works.
const char new_remote_device_name[] =
"/job:worker/replica:0/task:1/device:CPU:0";
const char new_local_device_name[] =
"/job:worker/replica:0/task:0/device:CPU:0";
auto* h0_task1_new = TFE_TensorHandleCopyToDevice(
h0_task0_new, ctx, new_remote_device_name, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(h0_task0_new);
TFE_DeleteTensorHandle(h0_task1_new);
CheckRemoteMatMulExecutesOK(ctx, new_remote_device_name,
new_local_device_name);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteExecutor(executor);
TF_DeleteStatus(status);
TFE_DeleteContext(ctx);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
}
TEST(CAPI, RemoteExecuteChangeServerDef) {
TestRemoteExecuteChangeServerDef(false);
}
TEST(CAPI, RemoteExecuteChangeServerDefAsync) {
TestRemoteExecuteChangeServerDef(true);
}
} // namespace


@ -27,6 +27,8 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/function.pb.h"
@ -416,8 +418,10 @@ void TensorHandleSilentCopy(bool async,
hcpu, ctx, gpu_device_name.c_str(), status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
auto cpu_arg = tensorflow::TensorHandleFromInterface(hcpu->handle);
auto gpu_arg = tensorflow::TensorHandleFromInterface(hgpu->handle);
auto cpu_arg =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(hcpu));
auto gpu_arg =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(hgpu));
auto gpu_device = absl::get<tensorflow::Device*>(gpu_arg->device());
ASSERT_FALSE(cpu_arg->HasLocalMirror(gpu_device));
@ -1346,7 +1350,7 @@ TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
tensorflow::AttrValueMap ExtractAttrs(TFE_Op* op) {
tensorflow::AttrValueMap attr_values;
tensorflow::EagerOperation* operation =
tensorflow::OperationFromInterface(op->operation);
tensorflow::OperationFromInterface(tensorflow::unwrap(op));
operation->Attrs().FillAttrValueMap(&attr_values);
return attr_values;
}
@ -1482,10 +1486,10 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) {
TFE_TensorHandle* inputs[] = {input1, input2};
TFE_OpAddInput(concatOp, dim, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
CHECK(concatOp->operation->OpDef());
CHECK(tensorflow::unwrap(concatOp)->OpDef());
TFE_OpAddInput(concatOp, inputs[0], status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_FALSE(concatOp->operation->OpDef())
EXPECT_FALSE(tensorflow::unwrap(concatOp)->OpDef())
<< "Inference context is still present";
TFE_OpAddInput(concatOp, inputs[1], status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@ -1577,7 +1581,7 @@ TEST(CAPI, TestTFE_OpGetInputAndOutputLengthsFailForUnknownArguments) {
TFE_DeleteContext(ctx);
}
TEST(CAPI, TestTFE_OpGetAttrs) {
TEST(CAPI, TestTFE_OpAddAttrs) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
@ -1587,12 +1591,11 @@ TEST(CAPI, TestTFE_OpGetAttrs) {
TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
TFE_OpAttrs attributes;
TFE_OpGetAttrs(var_op, &attributes);
const TFE_OpAttrs* attributes = TFE_OpGetAttrs(var_op);
TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status);
TFE_OpSetAttrType(copy_op, "dtype", TF_FLOAT);
TFE_OpAddAttrs(copy_op, &attributes);
TFE_OpAddAttrs(copy_op, attributes);
unsigned char is_list = 0;
ASSERT_EQ(TF_ATTR_TYPE,
TFE_OpGetAttrType(copy_op, "dtype", &is_list, status));
@ -1603,7 +1606,7 @@ TEST(CAPI, TestTFE_OpGetAttrs) {
tensorflow::AttrValueMap attr_values;
tensorflow::EagerOperation* op =
tensorflow::OperationFromInterface(copy_op->operation);
tensorflow::OperationFromInterface(tensorflow::unwrap(copy_op));
op->Attrs().FillAttrValueMap(&attr_values);
EXPECT_EQ(tensorflow::DT_FLOAT, attr_values.find("dtype")->second.type());
@ -1624,11 +1627,10 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
TFE_OpAttrs attributes;
TFE_OpGetAttrs(var_op, &attributes);
const TFE_OpAttrs* attributes = TFE_OpGetAttrs(var_op);
TF_Buffer* serialized_attr_values = TF_NewBuffer();
TFE_OpAttrsSerialize(&attributes, serialized_attr_values, status);
TFE_OpAttrsSerialize(attributes, serialized_attr_values, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::NameAttrList name_and_attrs;
ASSERT_TRUE(name_and_attrs.ParseFromArray(serialized_attr_values->data,
@ -1651,7 +1653,7 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) {
tensorflow::AttrValueMap attr_values;
tensorflow::EagerOperation* op =
tensorflow::OperationFromInterface(var_op_2->operation);
tensorflow::OperationFromInterface(tensorflow::unwrap(var_op_2));
op->Attrs().FillAttrValueMap(&attr_values);
EXPECT_EQ(tensorflow::DT_INT64, attr_values.find("dtype")->second.type());


@ -15,486 +15,78 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/strcat.h"
#include <vector>
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/platform/types.h"
using tensorflow::string;
using tensorflow::internal::OutputList;
using tensorflow::internal::unwrap;
// =============================================================================
// Unified Execution APIs for Eager and tracing backends.
// Public C API entry points
//
// These are only the generic entry points for the C API. This file does not
// have any visibility into the graph/eager implementation and is only providing
// C bindings to the abstract classes defined in the
// c_api_unified_experimental_internal.h header.
//
// =============================================================================
typedef void (*ExecuteOperation)(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs,
TF_OutputList* o, TF_ExecutionContext* ctx,
TF_Status* s);
struct TF_ExecutionContext {
// Needed to implement our own version of RTTI since dynamic_cast is not
// supported in mobile builds.
enum ExecutionContextKind { GraphContext, EagerContext };
explicit TF_ExecutionContext(ExecutionContextKind kind) : k(kind) {}
ExecutionContextKind getKind() const { return k; }
virtual void ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs,
TF_OutputList* o, TF_Status* s) = 0;
virtual TF_AbstractOp* CreateOperation() = 0;
virtual void RegisterFunction(TF_AbstractFunction* func, TF_Status* s) = 0;
virtual ~TF_ExecutionContext() {}
private:
const ExecutionContextKind k;
};
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete c; }
template <typename T, typename S>
T* dynamic_cast_helper(S source) {
if (source->getKind() != T::kKind) {
return nullptr;
}
return tensorflow::down_cast<T*>(source);
}
class TF_GraphContext;
class TF_EagerContext;
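For readers unfamiliar with this hand-rolled RTTI, a brief sketch of how the helper is meant to be used. This is illustrative only; it assumes, as elsewhere in this file, that TF_GraphContext is fully defined further below and declares `static constexpr ExecutionContextKind kKind = GraphContext;`.

void RequireGraphContext(TF_ExecutionContext* ctx, TF_Status* s) {
  // The downcast succeeds only when the runtime kind tag matches the target type.
  if (TF_GraphContext* graph_ctx = dynamic_cast_helper<TF_GraphContext>(ctx)) {
    (void)graph_ctx;  // Safe to use graph-specific members here.
  } else {
    TF_SetStatus(s, TF_INVALID_ARGUMENT, "Expected a graph execution context.");
  }
}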
struct TF_GraphTensor {
TF_Output output;
TF_GraphContext* ctx;
};
struct TF_AbstractTensor {
absl::variant<TFE_TensorHandle*, TF_GraphTensor*> t;
~TF_AbstractTensor() {
if (absl::holds_alternative<TFE_TensorHandle*>(t)) {
TFE_DeleteTensorHandle(absl::get<TFE_TensorHandle*>(t));
} else if (absl::holds_alternative<TF_GraphTensor*>(t)) {
delete absl::get<TF_GraphTensor*>(t);
}
}
};
struct TF_AbstractOp {
// Needed to implement our own version of RTTI since dynamic_cast is not
// supported in mobile builds.
enum AbstractOpKind { GraphOp, EagerOp };
explicit TF_AbstractOp(AbstractOpKind kind) : k(kind) {}
AbstractOpKind getKind() const { return k; }
virtual void SetOpType(const char* const op_type, TF_Status* s) = 0;
virtual void SetOpName(const char* const op_name, TF_Status* s) = 0;
virtual void SetAttrType(const char* const attr_name, TF_DataType value,
TF_Status* s) = 0;
virtual ~TF_AbstractOp() {}
private:
const AbstractOpKind k;
};
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete unwrap(c); }
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
return c->CreateOperation();
return wrap(unwrap(c)->CreateOperation());
}
void TF_DeleteAbstractOp(TF_AbstractOp* op) { delete op; }
void TF_DeleteAbstractOp(TF_AbstractOp* op) { delete unwrap(op); }
class TF_GraphOp : public TF_AbstractOp {
public:
explicit TF_GraphOp(TF_Graph* g) : TF_AbstractOp(kKind), g_(g) {}
void SetOpType(const char* const op_type, TF_Status* s) override {
if (op_) {
TF_SetStatus(
s, TF_FAILED_PRECONDITION,
absl::StrCat("SetOpType called on already built op.").c_str());
return;
}
if (op_name_ != nullptr) {
op_.reset(TF_NewOperation(g_, op_type, op_name_));
op_name_ = nullptr;
} else {
op_type_ = op_type;
}
}
void SetOpName(const char* const op_name, TF_Status* s) override {
if (op_) {
TF_SetStatus(
s, TF_FAILED_PRECONDITION,
absl::StrCat("SetOpName called on already built op.").c_str());
return;
}
if (op_type_ != nullptr) {
op_.reset(TF_NewOperation(g_, op_type_, op_name));
op_type_ = nullptr;
} else {
op_name_ = op_name;
}
}
void SetAttrType(const char* const attr_name, TF_DataType value,
TF_Status* s) override {
if (!op_) {
TF_SetStatus(
s, TF_FAILED_PRECONDITION,
"op_type and op_name must be specified before specifying attrs.");
return;
}
TF_SetAttrType(op_.get(), attr_name, value);
}
~TF_GraphOp() override {}
void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { delete unwrap(t); }
static constexpr AbstractOpKind kKind = GraphOp;
private:
friend class TF_GraphContext; // For access to op_.
TF_Graph* g_;
std::unique_ptr<TF_OperationDescription> op_;
// Hold `op_type` and `op_name` till both are available since we need both
// to build a graph operation.
const char* op_type_ = nullptr;
const char* op_name_ = nullptr;
};
class TF_EagerOp : public TF_AbstractOp {
public:
explicit TF_EagerOp(TFE_Context* ctx) : TF_AbstractOp(kKind), ctx_(ctx) {}
void SetOpType(const char* const op_type, TF_Status* s) override {
op_ = TFE_NewOp(ctx_, op_type, s);
}
void SetOpName(const char* const op_name, TF_Status* s) override {
// Name is ignored in eager mode.
}
void SetAttrType(const char* const attr_name, TF_DataType value,
TF_Status* s) override {
if (op_ == nullptr) {
TF_SetStatus(s, TF_FAILED_PRECONDITION,
"op_type must be specified before specifying attrs.");
return;
}
TFE_OpSetAttrType(op_, attr_name, value);
}
~TF_EagerOp() override { TFE_DeleteOp(op_); }
static constexpr AbstractOpKind kKind = EagerOp;
private:
friend class TF_EagerContext; // For access to op_.
TFE_Op* op_ = nullptr;
TFE_Context* ctx_;
};
bool IsEagerTensor(const TF_AbstractTensor* const t) {
return absl::holds_alternative<TFE_TensorHandle*>(t->t);
}
struct TF_OutputList {
std::vector<TF_AbstractTensor*> outputs;
int expected_num_outputs = -1;
};
struct TF_AbstractFunction {
TF_Function* func = nullptr;
~TF_AbstractFunction() { TF_DeleteFunction(func); }
};
class TF_EagerContext : public TF_ExecutionContext {
public:
TF_EagerContext() : TF_ExecutionContext(kKind) {}
void Build(TFE_ContextOptions* options, TF_Status* status) {
eager_ctx_ = TFE_NewContext(options, status);
}
TF_AbstractOp* CreateOperation() override {
// TODO(srbs): Should the lifetime of this op be tied to the context?
return new TF_EagerOp(eager_ctx_);
}
void ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_Status* s) override {
auto* eager_op = dynamic_cast_helper<TF_EagerOp>(op);
if (eager_op == nullptr) {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Unable to cast TF_AbstractOp to TF_EagerOp.");
return;
}
auto* tfe_op = eager_op->op_;
if (TF_GetCode(s) != TF_OK) return;
for (int i = 0; i < num_inputs; ++i) {
if (!IsEagerTensor(inputs[i])) {
TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor.");
return;
}
TFE_OpAddInput(tfe_op, absl::get<TFE_TensorHandle*>(inputs[i]->t), s);
if (TF_GetCode(s) != TF_OK) return;
}
if (o->expected_num_outputs == -1) {
string msg =
"The number of outputs must be provided in eager mode. Use "
"TF_OutputListSetNumOutputs.";
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return;
}
tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals;
int num_retvals = o->expected_num_outputs;
retvals.resize(num_retvals);
TFE_Execute(tfe_op, retvals.data(), &num_retvals, s);
if (TF_GetCode(s) != TF_OK) {
return;
}
o->outputs.clear();
o->outputs.reserve(num_retvals);
for (int i = 0; i < num_retvals; ++i) {
auto* t = new TF_AbstractTensor();
t->t = retvals[i];
o->outputs.push_back(t);
}
}
void RegisterFunction(TF_AbstractFunction* func, TF_Status* s) override {
TFE_ContextAddFunction(eager_ctx_, func->func, s);
}
~TF_EagerContext() override { TFE_DeleteContext(eager_ctx_); }
static constexpr ExecutionContextKind kKind = EagerContext;
private:
friend TFE_Context* TF_ExecutionContextGetTFEContext(
TF_ExecutionContext* ctx);
TFE_Context* eager_ctx_;
};
void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { delete t; }
TF_GraphContext* GetGraphContext(TF_AbstractTensor const* t) {
return absl::get<TF_GraphTensor*>(t->t)->ctx;
}
class TF_GraphContext : public TF_ExecutionContext {
public:
TF_GraphContext()
: TF_ExecutionContext(kKind), graph_(new TF_Graph(), TF_DeleteGraph) {}
TF_AbstractOp* CreateOperation() override {
// TODO(srbs): Should the lifetime of this op be tied to the context?
return new TF_GraphOp(graph_.get());
}
void ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_Status* s) override {
auto* graph_op = dynamic_cast_helper<TF_GraphOp>(op);
if (graph_op == nullptr) {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Unable to cast TF_AbstractOp to TF_GraphOp.");
return;
}
auto* tf_opdesc = graph_op->op_.release();
for (int i = 0; i < num_inputs; ++i) {
auto* input = inputs[i];
if (IsEagerTensor(input)) {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Capturing eager tensors is not supported yet.");
return;
} else {
if (GetGraphContext(input) != this) {
TF_SetStatus(
s, TF_INVALID_ARGUMENT,
"Capturing tensors from other graphs is not supported yet.");
return;
}
TF_AddInput(tf_opdesc, absl::get<TF_GraphTensor*>(input->t)->output);
}
}
auto* operation = TF_FinishOperation(tf_opdesc, s);
// TF_FinishOperation deletes `tf_opdesc` so clear its reference.
graph_op->op_ = nullptr;
if (TF_GetCode(s) != TF_OK) return;
int num_outputs = TF_OperationNumOutputs(operation);
o->outputs.clear();
o->outputs.reserve(num_outputs);
for (int i = 0; i < num_outputs; ++i) {
auto* t = new TF_AbstractTensor;
TF_GraphTensor* graph_t = new TF_GraphTensor;
graph_t->ctx = this;
graph_t->output = {operation, i};
t->t = graph_t;
o->outputs.push_back(t);
}
}
TF_Function* ToFunction(const char* fn_name, int num_inputs,
const TF_AbstractTensor* inputs, int num_outputs,
const TF_AbstractTensor* outputs,
TF_Status* status) const {
std::vector<TF_Output> graph_inputs;
graph_inputs.resize(num_inputs);
std::vector<TF_Output> graph_outputs;
graph_outputs.resize(num_outputs);
for (int i = 0; i < num_inputs; i++) {
graph_inputs[i] = absl::get<TF_GraphTensor*>(inputs[i].t)->output;
}
for (int i = 0; i < num_outputs; i++) {
graph_outputs[i] = absl::get<TF_GraphTensor*>(outputs[i].t)->output;
}
return TF_GraphToFunction(graph_.get(), fn_name, 0, -1, nullptr,
graph_inputs.size(), graph_inputs.data(),
graph_outputs.size(), graph_outputs.data(),
nullptr, nullptr, fn_name, status);
}
void RegisterFunction(TF_AbstractFunction* func, TF_Status* s) override {
TF_SetStatus(s, TF_UNIMPLEMENTED,
"Registering graph functions has not been implemented yet.");
}
~TF_GraphContext() override {}
static constexpr ExecutionContextKind kKind = GraphContext;
private:
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
};
struct TF_GraphContextOptions {};
struct TF_EagerContextOptions {
explicit TF_EagerContextOptions(TFE_ContextOptions* options)
: options(options) {}
TFE_ContextOptions* options; // Not owned.
};
struct TF_ExecutionContextOptions {
absl::variant<TF_GraphContextOptions*, TF_EagerContextOptions*> options;
~TF_ExecutionContextOptions() {
if (absl::holds_alternative<TF_GraphContextOptions*>(options)) {
delete absl::get<TF_GraphContextOptions*>(options);
} else if (absl::holds_alternative<TF_EagerContextOptions*>(options)) {
delete absl::get<TF_EagerContextOptions*>(options);
}
}
};
TF_ExecutionContextOptions* TF_NewGraphContextOptions() {
auto* options = new TF_ExecutionContextOptions();
options->options = new TF_GraphContextOptions();
return options;
}
void TF_DeleteExecutionContextOptions(TF_ExecutionContextOptions* options) {
delete options;
}
TF_ExecutionContextOptions* TF_NewEagerContextOptions(
TFE_ContextOptions* tfe_options) {
auto* options = new TF_ExecutionContextOptions();
options->options = new TF_EagerContextOptions(tfe_options);
return options;
}
TF_ExecutionContext* TF_NewExecutionContext(TF_ExecutionContextOptions* options,
TF_Status* s) {
if (absl::holds_alternative<TF_EagerContextOptions*>(options->options)) {
auto* ctx = new TF_EagerContext();
ctx->Build(absl::get<TF_EagerContextOptions*>(options->options)->options,
s);
return ctx;
} else {
return new TF_GraphContext();
}
}
TF_OutputList* TF_NewOutputList() { return new TF_OutputList; }
void TF_DeleteOutputList(TF_OutputList* o) { delete o; }
TF_OutputList* TF_NewOutputList() { return wrap(new OutputList); }
void TF_DeleteOutputList(TF_OutputList* o) { delete unwrap(o); }
void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs,
TF_Status* s) {
o->expected_num_outputs = num_outputs;
unwrap(o)->expected_num_outputs = num_outputs;
}
int TF_OutputListNumOutputs(TF_OutputList* o) {
return unwrap(o)->outputs.size();
}
int TF_OutputListNumOutputs(TF_OutputList* o) { return o->outputs.size(); }
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) {
return o->outputs[i];
return wrap(unwrap(o)->outputs[i]);
}
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
TF_Status* s) {
op->SetOpType(op_type, s);
unwrap(op)->SetOpType(op_type, s);
}
void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
TF_Status* s) {
op->SetOpName(op_name, s);
unwrap(op)->SetOpName(op_name, s);
}
void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
TF_DataType value, TF_Status* s) {
op->SetAttrType(attr_name, value, s);
unwrap(op)->SetAttrType(attr_name, value, s);
}
void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_ExecutionContext* ctx, TF_Status* s) {
ctx->ExecuteOperation(op, num_inputs, inputs, o, s);
unwrap(ctx)->ExecuteOperation(unwrap(op), num_inputs, &unwrap(*inputs),
unwrap(o), s);
}
TF_AbstractFunction* TF_ExecutionContextToFunction(
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
const TF_AbstractTensor* inputs, int num_outputs,
const TF_AbstractTensor* outputs, TF_Status* status) {
auto* graph_ctx = dynamic_cast_helper<const TF_GraphContext>(fn_body);
if (graph_ctx == nullptr) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"fn_body is not a TF_GraphContext.");
return nullptr;
}
TF_AbstractFunction* func = new TF_AbstractFunction;
func->func = graph_ctx->ToFunction(fn_name, num_inputs, inputs, num_outputs,
outputs, status);
return func;
void TF_DeleteAbstractFunction(TF_AbstractFunction* func) {
delete unwrap(func);
}
void TF_DeleteAbstractFunction(TF_AbstractFunction* func) { delete func; }
void TF_ExecutionContextRegisterFunction(TF_ExecutionContext* ctx,
TF_AbstractFunction* func,
TF_Status* s) {
ctx->RegisterFunction(func, s);
}
// Temporary APIs till we figure out how to create scalar-valued Eager
// tensors and how to get a value out of eager abstract tensors.
TF_AbstractTensor* TF_NewAbstractTensor() {
TF_AbstractTensor* t = new TF_AbstractTensor;
return t;
}
void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t,
TF_Status* s) {
at->t = t;
}
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
TF_Status* s) {
if (!absl::holds_alternative<TFE_TensorHandle*>(at->t)) {
string msg = absl::StrCat("Not an eager tensor handle.",
reinterpret_cast<uintptr_t>(at));
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return nullptr;
}
return absl::get<TFE_TensorHandle*>(at->t);
}
TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext* ctx) {
return dynamic_cast_helper<TF_EagerContext>(ctx)->eager_ctx_;
unwrap(ctx)->RegisterFunction(unwrap(func), s);
}

View File

@ -16,6 +16,7 @@ limitations under the License.
#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#ifdef __cplusplus
@ -34,26 +35,34 @@ extern "C" {
// E.g. it could know whether we're in eager mode or in graph mode, keeps track
// of gradient tapes, etc.
typedef struct TF_ExecutionContext TF_ExecutionContext;
// A TF_AbstractTensor is an input to an operation. E.g. it could be a union
// type of eager and graph tensors.
// type of eager and graph tensors. It is also the result of executing an
// operation.
typedef struct TF_AbstractTensor TF_AbstractTensor;
// A TF_AbstractOp is the metadata we need to execute an operation. E.g. this
// could contain the op type and other attributes.
typedef struct TF_AbstractOp TF_AbstractOp;
// `TF_ExecutionContextOptions` define what type of `TF_ExecutionContext` is
// created. It can be used to pass context specific params.
typedef struct TF_ExecutionContextOptions TF_ExecutionContextOptions;
void TF_DeleteExecutionContextOptions(TF_ExecutionContextOptions*);
// Stores a function representation that can be used for execution or for
// setting functional attributes of other composite ops e.g. control flow.
typedef struct TF_AbstractFunction TF_AbstractFunction;
// Creates a context for tracing the execution of operations into a function.
TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s);
// Creates a context for eager execution of operations.
TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions*,
TF_Status* s);
TF_ExecutionContext* TF_NewExecutionContext(TF_ExecutionContextOptions*,
TF_Status* s);
void TF_DeleteExecutionContext(TF_ExecutionContext*);
// Create an operation suitable to use with the provided context. The operation
// requires its type (e.g. "AddV2") to be set independently.
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* ctx);
void TF_DeleteAbstractOp(TF_AbstractOp*);
void TF_DeleteAbstractTensor(TF_AbstractTensor*);
// TODO(srbs): Add APIs for specifying attrs etc.
// `op_type` must outlive `op`.
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
@ -65,9 +74,16 @@ void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
TF_DataType value, TF_Status* s);
// TF_OutputList just lets us not specify the number of outputs of an operation
void TF_DeleteAbstractTensor(TF_AbstractTensor*);
// TF_OutputList holds the list of TF_AbstractTensor that results from executing
// an operation.
// It just lets us not specify the number of outputs of an operation
// beforehand. This forces a memory allocation in the runtime, which is bad, but
// it allows for generic code.
// TODO(aminim): the description above isn't clear with respect to
// TF_OutputListNumOutputs and the current eager implementation which requires
// the number of outputs to be set by the client.
typedef struct TF_OutputList TF_OutputList;
TF_OutputList* TF_NewOutputList();
void TF_DeleteOutputList(TF_OutputList* o);
@ -75,38 +91,38 @@ void TF_OutputListSetNumOutputs(TF_OutputList* o, int, TF_Status*);
int TF_OutputListNumOutputs(TF_OutputList* o);
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i);
// Stores a function representation that can be used for execution or for
// setting functional attributes of other composite ops e.g. control flow.
typedef struct TF_AbstractFunction TF_AbstractFunction;
TF_AbstractFunction* TF_ExecutionContextToFunction(
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
const TF_AbstractTensor* inputs, int num_outputs,
const TF_AbstractTensor* outputs, TF_Status* status);
void TF_DeleteAbstractFunction(TF_AbstractFunction*);
void TF_ExecutionContextRegisterFunction(TF_ExecutionContext*,
TF_AbstractFunction*, TF_Status*);
// TF_ExecuteOperation will, if in eager mode, execute, if in graph mode, maybe
// capture some inputs and then add a node in the graph, and after
// execution/node creation it'll go and record things that happened in any tape
// which happens to be active.
// capture some inputs and then add a node in the graph. The output tensors are
// returned through the provided TF_OutputList.
// Any active tape will observe the effects of this execution.
void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_ExecutionContext* ctx, TF_Status* s);
// -----------------------------------------------------------------------------
// APIs specific to Eager and graph modes
// -----------------------------------------------------------------------------
// Creates a new TF_AbstractFunction from the current tracing state in the
// context. The returned TF_AbstractFunction must be deleted by the client.
// TODO(aminim): clarify the contract on the state of the context after this
// call.
TF_AbstractFunction* TF_ExecutionContextToFunction(
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
const TF_AbstractTensor* inputs, int num_outputs,
const TF_AbstractTensor* outputs, TF_Status* status);
TF_ExecutionContextOptions* TF_NewGraphContextOptions();
TF_ExecutionContextOptions* TF_NewEagerContextOptions(TFE_ContextOptions*);
void TF_DeleteAbstractFunction(TF_AbstractFunction*);
// Register the function with the given context. This is particularly useful for
// making a function available to an eager context.
void TF_ExecutionContextRegisterFunction(TF_ExecutionContext*,
TF_AbstractFunction*, TF_Status*);
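// A minimal sketch (not part of this change) of how the declarations in this
// section compose: assuming `graph_ctx` is a tracing context in which `x`
// (input) and `y` (output) were produced via TF_ExecuteOperation, the trace
// becomes a function that an eager context can then call by its name.
// `opts` and `s` are assumed TFE_ContextOptions*/TF_Status*; error checking
// is elided.
//
//   TF_AbstractFunction* fn =
//       TF_ExecutionContextToFunction(graph_ctx, "my_fn", /*num_inputs=*/1, x,
//                                     /*num_outputs=*/1, y, s);
//   TF_ExecutionContext* eager_ctx = TF_NewEagerExecutionContext(opts, s);
//   TF_ExecutionContextRegisterFunction(eager_ctx, fn, s);
//   TF_DeleteAbstractFunction(fn);
//   // An eager TF_AbstractOp with op type "my_fn" can now execute the trace.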
// -----------------------------------------------------------------------------
// APIs specific to Eager modes
// -----------------------------------------------------------------------------
// Temporary APIs till we figure out how to create scalar-valued Eager
// tensors and how to get a value out of eager abstract tensors.
TF_AbstractTensor* TF_NewAbstractTensor();
void TF_AbstractTensorSetEagerTensor(
TF_AbstractTensor* at, TFE_TensorHandle* t,
TF_Status* s); // `at` takes ownership of `t`.
TF_AbstractTensor* TF_CreateAbstractTensorFromEagerTensor(TFE_TensorHandle* t,
TF_Status* s);
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
TF_Status* s);
TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext*);
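// A minimal sketch (not part of this change) of eager execution through the
// unified API, modeled on the unit tests: run one AddV2 on an existing scalar
// handle `t` (an assumed TFE_TensorHandle*, whose ownership is taken by the
// abstract tensor). Error checking is elided.
inline void ExampleEagerAdd(TFE_ContextOptions* opts, TFE_TensorHandle* t,
                            TF_Status* s) {
  TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, s);
  TF_AbstractTensor* at = TF_CreateAbstractTensorFromEagerTensor(t, s);
  TF_AbstractOp* add = TF_NewAbstractOp(ctx);
  TF_AbstractOpSetOpType(add, "AddV2", s);
  TF_AbstractTensor* inputs[] = {at, at};
  TF_OutputList* o = TF_NewOutputList();
  // Eager mode requires the expected number of outputs up front.
  TF_OutputListSetNumOutputs(o, 1, s);
  TF_ExecuteOperation(add, 2, inputs, o, ctx, s);
  TF_AbstractTensor* result = TF_OutputListGet(o, 0);
  TFE_TensorHandle* result_handle = TF_AbstractTensorGetEagerTensor(result, s);
  // ... read the value out of `result_handle`, e.g. via TFE_TensorHandleResolve ...
  TF_DeleteAbstractTensor(result);
  TF_DeleteOutputList(o);
  TF_DeleteAbstractOp(add);
  TF_DeleteAbstractTensor(at);
  TF_DeleteExecutionContext(ctx);
}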

View File

@ -0,0 +1,183 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <vector>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/types.h"
using tensorflow::string;
namespace tensorflow {
namespace internal {
// Simple wrapper over a TFE_TensorHandle
struct EagerTensor : public AbstractTensor {
TFE_TensorHandle* t = nullptr;
EagerTensor() : AbstractTensor(kKind) {}
explicit EagerTensor(TFE_TensorHandle* t) : AbstractTensor(kKind), t(t) {}
~EagerTensor() override { TFE_DeleteTensorHandle(t); }
static constexpr AbstractTensorKind kKind = kEagerTensor;
};
// Simple wrapper over a TFE_Op
class EagerOp : public AbstractOp {
public:
explicit EagerOp(TFE_Context* ctx) : AbstractOp(kKind), ctx_(ctx) {}
void SetOpType(const char* const op_type, TF_Status* s) override {
op_ = TFE_NewOp(ctx_, op_type, s);
}
void SetOpName(const char* const op_name, TF_Status* s) override {
// Name is ignored in eager mode.
}
void SetAttrType(const char* const attr_name, TF_DataType value,
TF_Status* s) override {
if (op_ == nullptr) {
TF_SetStatus(s, TF_FAILED_PRECONDITION,
"op_type must be specified before specifying attrs.");
return;
}
TFE_OpSetAttrType(op_, attr_name, value);
}
~EagerOp() override { TFE_DeleteOp(op_); }
static constexpr AbstractOpKind kKind = kEagerOp;
private:
friend class EagerContext; // For access to op_.
TFE_Op* op_ = nullptr;
TFE_Context* ctx_;
};
// Wraps a TFE_Context and dispatches EagerOps with EagerTensor inputs.
class EagerContext : public ExecutionContext {
public:
EagerContext() : ExecutionContext(kKind) {}
void Build(TFE_ContextOptions* options, TF_Status* status) {
eager_ctx_ = TFE_NewContext(options, status);
}
AbstractOp* CreateOperation() override {
// TODO(srbs): Should the lifetime of this op be tied to the context?
return new EagerOp(eager_ctx_);
}
void ExecuteOperation(AbstractOp* op, int num_inputs,
AbstractTensor* const* inputs, OutputList* o,
TF_Status* s) override {
auto* eager_op = dyncast<EagerOp>(op);
if (eager_op == nullptr) {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Unable to cast AbstractOp to TF_EagerOp.");
return;
}
auto* tfe_op = eager_op->op_;
if (TF_GetCode(s) != TF_OK) return;
for (int i = 0; i < num_inputs; ++i) {
auto* eager_tensor = dyncast<const EagerTensor>(inputs[i]);
if (!eager_tensor) {
TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor.");
return;
}
TFE_OpAddInput(tfe_op, eager_tensor->t, s);
if (TF_GetCode(s) != TF_OK) return;
}
if (o->expected_num_outputs == -1) {
string msg =
"The number of outputs must be provided in eager mode. Use "
"TF_OutputListSetNumOutputs.";
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return;
}
tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals;
int num_retvals = o->expected_num_outputs;
retvals.resize(num_retvals);
TFE_Execute(tfe_op, retvals.data(), &num_retvals, s);
if (TF_GetCode(s) != TF_OK) {
return;
}
o->outputs.clear();
o->outputs.reserve(num_retvals);
for (int i = 0; i < num_retvals; ++i) {
o->outputs.push_back(new EagerTensor(retvals[i]));
}
}
void RegisterFunction(AbstractFunction* afunc, TF_Status* s) override {
auto* func = afunc->GetTfFunction(s);
if (!func) {
return;
}
TFE_ContextAddFunction(eager_ctx_, func, s);
}
~EagerContext() override { TFE_DeleteContext(eager_ctx_); }
static constexpr ExecutionContextKind kKind = kEagerContext;
private:
friend TFE_Context* ::TF_ExecutionContextGetTFEContext(
TF_ExecutionContext* ctx);
TFE_Context* eager_ctx_;
};
} // namespace internal
} // namespace tensorflow
// =============================================================================
// Public C API entry points
// These are only the entry points specific to the Eager API.
// =============================================================================
using tensorflow::internal::dyncast;
using tensorflow::internal::unwrap;
TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions* options,
TF_Status* s) {
auto* ctx = new tensorflow::internal::EagerContext();
ctx->Build(options, s);
return wrap(ctx);
}
TF_AbstractTensor* TF_CreateAbstractTensorFromEagerTensor(TFE_TensorHandle* t,
TF_Status* s) {
return wrap(new tensorflow::internal::EagerTensor(t));
}
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
TF_Status* s) {
auto* eager_tensor = dyncast<tensorflow::internal::EagerTensor>(unwrap(at));
if (!eager_tensor) {
string msg = tensorflow::strings::StrCat("Not an eager tensor handle.",
reinterpret_cast<uintptr_t>(at));
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return nullptr;
}
return eager_tensor->t;
}
TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext* ctx) {
auto* eager_ctx = dyncast<tensorflow::internal::EagerContext>(unwrap(ctx));
if (!eager_ctx) return nullptr;
return eager_ctx->eager_ctx_;
}

View File

@ -0,0 +1,248 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <vector>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/types.h"
using tensorflow::string;
namespace tensorflow {
namespace internal {
class GraphContext;
// GraphTensor wraps a `TF_Output`, i.e. a pointer to TF_Operation and the index
// into the list of outputs for the operation.
struct GraphTensor : public AbstractTensor {
TF_Output output{};
GraphContext* ctx = nullptr;
GraphTensor() : AbstractTensor(kKind) {}
GraphTensor(TF_Output output, GraphContext* ctx)
: AbstractTensor(kKind), output(output), ctx(ctx) {}
static constexpr AbstractTensorKind kKind = kGraphTensor;
};
// GraphOp wraps and populates a TF_OperationDescription.
class GraphOp : public AbstractOp {
public:
explicit GraphOp(TF_Graph* g) : AbstractOp(kKind), g_(g) {}
void SetOpType(const char* const op_type, TF_Status* s) override {
if (op_) {
TF_SetStatus(
s, TF_FAILED_PRECONDITION,
strings::StrCat("SetOpType called on already built op.").c_str());
return;
}
if (op_name_ != nullptr) {
op_.reset(TF_NewOperation(g_, op_type, op_name_));
op_name_ = nullptr;
} else {
op_type_ = op_type;
}
}
void SetOpName(const char* const op_name, TF_Status* s) override {
if (op_) {
TF_SetStatus(
s, TF_FAILED_PRECONDITION,
strings::StrCat("SetOpName called on already built op.").c_str());
return;
}
if (op_type_ != nullptr) {
op_.reset(TF_NewOperation(g_, op_type_, op_name));
op_type_ = nullptr;
} else {
op_name_ = op_name;
}
}
void SetAttrType(const char* const attr_name, TF_DataType value,
TF_Status* s) override {
if (!op_) {
TF_SetStatus(
s, TF_FAILED_PRECONDITION,
"op_type and op_name must be specified before specifying attrs.");
return;
}
TF_SetAttrType(op_.get(), attr_name, value);
}
~GraphOp() override {}
static constexpr AbstractOpKind kKind = kGraphOp;
private:
friend class GraphContext; // For access to op_.
TF_Graph* g_;
std::unique_ptr<TF_OperationDescription> op_;
// Hold `op_type` and `op_name` till both are available since we need both
// to build a graph operation.
const char* op_type_ = nullptr;
const char* op_name_ = nullptr;
};
// GraphFunction is a thin wrapper over a TF_Function.
struct GraphFunction : public AbstractFunction {
TF_Function* func = nullptr;
GraphFunction() : AbstractFunction(kKind) {}
explicit GraphFunction(TF_Function* func)
: AbstractFunction(kKind), func(func) {}
~GraphFunction() override {
if (func) TF_DeleteFunction(func);
}
TF_Function* GetTfFunction(TF_Status* s) override { return func; }
static constexpr AbstractFunctionKind kKind = kGraphFunc;
};
// GraphContext wraps a TF_Graph and manages the "execution" of operations,
// i.e. adding them to the graph.
class GraphContext : public ExecutionContext {
public:
GraphContext()
: ExecutionContext(kKind), graph_(new TF_Graph(), TF_DeleteGraph) {}
AbstractOp* CreateOperation() override {
// TODO(srbs): Should the lifetime of this op be tied to the context?
return new GraphOp(graph_.get());
}
void ExecuteOperation(AbstractOp* op, int num_inputs,
AbstractTensor* const* inputs, OutputList* o,
TF_Status* s) override {
auto* graph_op = dyncast<GraphOp>(op);
if (graph_op == nullptr) {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Unable to cast AbstractOp to TF_GraphOp.");
return;
}
auto* tf_opdesc = graph_op->op_.release();
for (int i = 0; i < num_inputs; ++i) {
auto* graph_tensor = dyncast<GraphTensor>(inputs[i]);
if (!graph_tensor) {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Capturing eager tensors is not supported yet.");
return;
} else {
if (graph_tensor->ctx != this) {
TF_SetStatus(
s, TF_INVALID_ARGUMENT,
"Capturing tensors from other graphs is not supported yet.");
return;
}
TF_AddInput(tf_opdesc, graph_tensor->output);
}
}
auto* operation = TF_FinishOperation(tf_opdesc, s);
// TF_FinishOperation deletes `tf_opdesc` so clear its reference.
graph_op->op_ = nullptr;
if (TF_GetCode(s) != TF_OK) return;
int num_outputs = TF_OperationNumOutputs(operation);
o->outputs.clear();
o->outputs.reserve(num_outputs);
for (int i = 0; i < num_outputs; ++i) {
o->outputs.push_back(new GraphTensor({operation, i}, this));
}
}
TF_Function* ToFunction(const char* fn_name, int num_inputs,
const GraphTensor* inputs, int num_outputs,
const GraphTensor* outputs, TF_Status* status) const {
std::vector<TF_Output> graph_inputs;
graph_inputs.resize(num_inputs);
std::vector<TF_Output> graph_outputs;
graph_outputs.resize(num_outputs);
for (int i = 0; i < num_inputs; i++) {
graph_inputs[i] = inputs[i].output;
}
for (int i = 0; i < num_outputs; i++) {
graph_outputs[i] = outputs[i].output;
}
return TF_GraphToFunction(graph_.get(), fn_name, 0, -1, nullptr,
graph_inputs.size(), graph_inputs.data(),
graph_outputs.size(), graph_outputs.data(),
nullptr, nullptr, fn_name, status);
}
void RegisterFunction(AbstractFunction* func, TF_Status* s) override {
TF_SetStatus(s, TF_UNIMPLEMENTED,
"Registering graph functions has not been implemented yet.");
}
~GraphContext() override {}
static constexpr ExecutionContextKind kKind = kGraphContext;
private:
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
};
// Helper that converts the graph currently held in the context into a function.
static AbstractFunction* ExecutionContextToFunction(
const ExecutionContext* fn_body, const char* fn_name, int num_inputs,
const AbstractTensor* inputs, int num_outputs,
const AbstractTensor* outputs, TF_Status* status) {
auto* graph_ctx = dyncast<const GraphContext>(fn_body);
if (graph_ctx == nullptr) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"fn_body is not a TF_GraphContext.");
return nullptr;
}
auto* graph_inputs = dyncast<const GraphTensor>(inputs);
if (!graph_inputs) {
TF_SetStatus(status, TF_INVALID_ARGUMENT, "inputs aren't GraphTensors.");
return nullptr;
}
auto* graph_outputs = dyncast<const GraphTensor>(outputs);
if (!graph_outputs) {
TF_SetStatus(status, TF_INVALID_ARGUMENT, "outputs aren't GraphTensors.");
return nullptr;
}
GraphFunction* func = new GraphFunction;
func->func = graph_ctx->ToFunction(fn_name, num_inputs, graph_inputs,
num_outputs, graph_outputs, status);
return func;
}
} // namespace internal
} // namespace tensorflow
// =============================================================================
// Public C API entry points
// These are only the entry points specific to the Graph API.
// =============================================================================
using tensorflow::internal::unwrap;
TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s) {
return wrap(new tensorflow::internal::GraphContext());
}
TF_AbstractFunction* TF_ExecutionContextToFunction(
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
const TF_AbstractTensor* inputs, int num_outputs,
const TF_AbstractTensor* outputs, TF_Status* status) {
return wrap(ExecutionContextToFunction(unwrap(fn_body), fn_name, num_inputs,
unwrap(inputs), num_outputs,
unwrap(outputs), status));
}

View File

@ -0,0 +1,184 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_
#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_
#include <vector>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/platform/casts.h"
namespace tensorflow {
namespace internal {
// =============================================================================
// Implementation detail for the unified execution APIs for Eager and tracing
// backends (graph/MLIR).
//
// This defines a set of abstract classes that are intended to provide the
// functionality of the opaque C types exposed in the public APIs defined in the
// `c_api_unified_experimental.h` header.
// =============================================================================
// We can't depend on C++ RTTI, but we still want a safe dynamic_cast-like
// facility to provide diagnostics to the user when the API is misused.
// Instead we model RTTI by listing all the possible subclasses for each
// abstract base. Each subclass initializes the base class with the right
// `kind`, which lets the `dyncast` utility below provide an equivalent of
// `dynamic_cast`.
template <typename T, typename S>
T* dyncast(S source) {
if (source->getKind() != T::kKind) {
return nullptr;
}
return tensorflow::down_cast<T*>(source);
}
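// For example (illustrative, mirroring the call sites in the eager and graph
// .cc files, where GraphOp is defined):
//
//   AbstractOp* op = ...;
//   if (GraphOp* graph_op = dyncast<GraphOp>(op)) {
//     // `op` was created by a GraphContext; use graph_op here.
//   } else {
//     // Wrong kind: report the misuse through TF_Status.
//   }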
// Represents either an EagerTensor or a GraphTensor.
// This base class does not expose any public methods other than to distinguish
// which subclass it actually is. The user is responsible for using the right
// type of AbstractTensor in their context (do not pass an EagerTensor to a
// GraphContext and vice-versa).
class AbstractTensor {
protected:
enum AbstractTensorKind { kGraphTensor, kEagerTensor, kMLIRTensor };
explicit AbstractTensor(AbstractTensorKind kind) : kind_(kind) {}
public:
// Returns which subclass this instance is.
AbstractTensorKind getKind() const { return kind_; }
virtual ~AbstractTensor() = default;
private:
const AbstractTensorKind kind_;
};
// Represents the results of the execution of an operation.
struct OutputList {
std::vector<AbstractTensor*> outputs;
int expected_num_outputs = -1;
};
// Holds the result of tracing a function.
class AbstractFunction {
protected:
enum AbstractFunctionKind { kGraphFunc };
explicit AbstractFunction(AbstractFunctionKind kind) : kind_(kind) {}
public:
// Returns which subclass this instance is.
AbstractFunctionKind getKind() const { return kind_; }
virtual ~AbstractFunction() = default;
// Temporary API till we figure out the right abstraction for AbstractFunction.
// At the moment both Eager and Graph need access to a "TF_Function" object.
virtual TF_Function* GetTfFunction(TF_Status* s) = 0;
private:
const AbstractFunctionKind kind_;
};
// An abstract operation describes an operation by its type, name, and
// attributes. It can be "executed" by the context with some input tensors.
// It is allowed to reuse the same abstract operation for multiple executions
// on a given context, with the same or different input tensors.
class AbstractOp {
protected:
enum AbstractOpKind { kGraphOp, kEagerOp };
explicit AbstractOp(AbstractOpKind kind) : kind_(kind) {}
public:
// Returns which subclass this instance is.
AbstractOpKind getKind() const { return kind_; }
virtual ~AbstractOp() = default;
// Sets the type of the operation (for example `AddV2`).
virtual void SetOpType(const char* op_type, TF_Status* s) = 0;
// Sets the name of the operation: this is an optional identifier that is
// not intended to carry semantics and is preserved/propagated without
// guarantees.
virtual void SetOpName(const char* op_name, TF_Status* s) = 0;
// Add a `TypeAttribute` on the operation.
virtual void SetAttrType(const char* attr_name, TF_DataType value,
TF_Status* s) = 0;
private:
const AbstractOpKind kind_;
};
// This holds the context for the execution: dispatching operations either to an
// eager implementation or to a graph implementation.
struct ExecutionContext {
protected:
enum ExecutionContextKind { kGraphContext, kEagerContext };
explicit ExecutionContext(ExecutionContextKind kind) : k(kind) {}
public:
// Returns which subclass this instance is.
ExecutionContextKind getKind() const { return k; }
virtual ~ExecutionContext() = default;
// Executes the operation on the provided inputs and populates the OutputList
// with the results. The input tensors must match the current context.
// The effect of "executing" an operation depends on the context: in an Eager
// context it will dispatch it to the runtime for execution, while in a
// tracing context it will add the operation to the current function.
virtual void ExecuteOperation(AbstractOp* op, int num_inputs,
AbstractTensor* const* inputs, OutputList* o,
TF_Status* s) = 0;
// Creates an empty AbstractOp suitable for use with this context.
virtual AbstractOp* CreateOperation() = 0;
// Registers a function with this context; afterwards the function is
// available to be called/referenced by its name in this context.
virtual void RegisterFunction(AbstractFunction* func, TF_Status* s) = 0;
private:
const ExecutionContextKind k;
};
// Utilities to wrap/unwrap: these convert from the C opaque types to the
// C++ implementation, and back.
#define MAKE_WRAP_UNWRAP(C_TYPEDEF, CPP_CLASS) \
static inline CPP_CLASS* const& unwrap(C_TYPEDEF* const& o) { \
return reinterpret_cast<CPP_CLASS* const&>(o); \
} \
static inline const CPP_CLASS* const& unwrap(const C_TYPEDEF* const& o) { \
return reinterpret_cast<const CPP_CLASS* const&>(o); \
} \
static inline C_TYPEDEF* const& wrap(CPP_CLASS* const& o) { \
return reinterpret_cast<C_TYPEDEF* const&>(o); \
} \
static inline const C_TYPEDEF* const& wrap(const CPP_CLASS* const& o) { \
return reinterpret_cast<const C_TYPEDEF* const&>(o); \
}
MAKE_WRAP_UNWRAP(TF_ExecutionContext, ExecutionContext)
MAKE_WRAP_UNWRAP(TF_AbstractFunction, AbstractFunction)
MAKE_WRAP_UNWRAP(TF_AbstractTensor, AbstractTensor)
MAKE_WRAP_UNWRAP(TF_AbstractOp, AbstractOp)
MAKE_WRAP_UNWRAP(TF_OutputList, OutputList)
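// For example (illustrative, matching existing call sites such as
// TF_NewOutputList/TF_DeleteOutputList in c_api_unified_experimental.cc):
//
//   TF_OutputList* c_list = wrap(new OutputList);  // C++ impl -> opaque C type
//   OutputList* impl = unwrap(c_list);             // opaque C type -> C++ impl
//   delete unwrap(c_list);                         // as done by TF_DeleteOutputList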
} // namespace internal
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_

View File

@ -15,17 +15,14 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include <string.h>
#include <memory>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/cc/profiler/profiler.h"
#include "tensorflow/core/lib/monitoring/collection_registry.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/str_util.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
using tensorflow::string;
@ -36,8 +33,7 @@ TEST(UnifedCAPI, TestBasicEager) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContextOptions* options = TF_NewEagerContextOptions(opts);
TF_ExecutionContext* ctx = TF_NewExecutionContext(options, status.get());
TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
@ -46,8 +42,8 @@ TEST(UnifedCAPI, TestBasicEager) {
// Build an abstract input tensor.
TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx);
TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f);
TF_AbstractTensor* at = TF_NewAbstractTensor();
TF_AbstractTensorSetEagerTensor(at, t, status.get());
TF_AbstractTensor* at =
TF_CreateAbstractTensorFromEagerTensor(t, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract operation.
@ -83,15 +79,12 @@ TEST(UnifedCAPI, TestBasicEager) {
TF_DeleteAbstractTensor(result);
TF_DeleteOutputList(o);
TF_DeleteExecutionContext(ctx);
TF_DeleteExecutionContextOptions(options);
}
TEST(UnifedCAPI, TestBasicGraph) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContextOptions* options = TF_NewGraphContextOptions();
TF_ExecutionContext* graph_ctx =
TF_NewExecutionContext(options, status.get());
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph.
@ -143,10 +136,8 @@ TEST(UnifedCAPI, TestBasicGraph) {
// Build eager context.
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContextOptions* eager_ctx_options =
TF_NewEagerContextOptions(opts);
TF_ExecutionContext* eager_execution_ctx =
TF_NewExecutionContext(eager_ctx_options, status.get());
TF_NewEagerExecutionContext(opts, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
@ -158,11 +149,11 @@ TEST(UnifedCAPI, TestBasicGraph) {
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract input tensor.
TF_AbstractTensor* input_t = TF_NewAbstractTensor();
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(eager_execution_ctx);
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f);
TF_AbstractTensorSetEagerTensor(input_t, input_eager, status.get());
TF_AbstractTensor* input_t =
TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_OutputListSetNumOutputs(add_outputs, 1, status.get());
@ -191,16 +182,13 @@ TEST(UnifedCAPI, TestBasicGraph) {
TF_DeleteExecutionContext(graph_ctx);
TF_DeleteExecutionContext(eager_execution_ctx);
TF_DeleteExecutionContextOptions(eager_ctx_options);
TF_DeleteExecutionContextOptions(options);
}
TEST(UnifedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContextOptions* options = TF_NewEagerContextOptions(opts);
TF_ExecutionContext* ctx = TF_NewExecutionContext(options, status.get());
TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
@ -210,15 +198,12 @@ TEST(UnifedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
TF_DeleteExecutionContext(ctx);
TF_DeleteExecutionContextOptions(options);
}
TEST(UnifedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContextOptions* options = TF_NewGraphContextOptions();
TF_ExecutionContext* graph_ctx =
TF_NewExecutionContext(options, status.get());
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph.
@ -234,15 +219,12 @@ TEST(UnifedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
TF_DeleteAbstractOp(placeholder_op);
TF_DeleteExecutionContext(graph_ctx);
TF_DeleteExecutionContextOptions(options);
}
TEST(UnifedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContextOptions* options = TF_NewGraphContextOptions();
TF_ExecutionContext* graph_ctx =
TF_NewExecutionContext(options, status.get());
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph.
@ -258,7 +240,6 @@ TEST(UnifedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
TF_DeleteAbstractOp(placeholder_op);
TF_DeleteExecutionContext(graph_ctx);
TF_DeleteExecutionContextOptions(options);
}
TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) {
@ -266,8 +247,7 @@ TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContextOptions* options = TF_NewEagerContextOptions(opts);
TF_ExecutionContext* ctx = TF_NewExecutionContext(options, status.get());
TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
@ -281,8 +261,8 @@ TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) {
// Build an abstract input tensor.
TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx);
TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f);
TF_AbstractTensor* at = TF_NewAbstractTensor();
TF_AbstractTensorSetEagerTensor(at, t, status.get());
TF_AbstractTensor* at =
TF_CreateAbstractTensorFromEagerTensor(t, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
@ -292,9 +272,7 @@ TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) {
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build a Graph context.
TF_ExecutionContextOptions* graph_options = TF_NewGraphContextOptions();
TF_ExecutionContext* graph_ctx =
TF_NewExecutionContext(graph_options, status.get());
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Execute eager op using graph context.
@ -307,17 +285,13 @@ TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) {
TF_DeleteOutputList(o);
TF_DeleteExecutionContext(ctx);
TF_DeleteExecutionContextOptions(options);
TF_DeleteExecutionContext(graph_ctx);
TF_DeleteExecutionContextOptions(graph_options);
}
TEST(UnifedCAPI, TestExecutingGraphOpInEagerModeRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContextOptions* options = TF_NewGraphContextOptions();
TF_ExecutionContext* graph_ctx =
TF_NewExecutionContext(options, status.get());
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph.
@ -355,10 +329,8 @@ TEST(UnifedCAPI, TestExecutingGraphOpInEagerModeRaises) {
// Build eager context.
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContextOptions* eager_ctx_options =
TF_NewEagerContextOptions(opts);
TF_ExecutionContext* eager_execution_ctx =
TF_NewExecutionContext(eager_ctx_options, status.get());
TF_NewEagerExecutionContext(opts, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
@ -374,8 +346,6 @@ TEST(UnifedCAPI, TestExecutingGraphOpInEagerModeRaises) {
TF_DeleteOutputList(placeholder_outputs);
TF_DeleteExecutionContext(graph_ctx);
TF_DeleteExecutionContext(eager_execution_ctx);
TF_DeleteExecutionContextOptions(eager_ctx_options);
TF_DeleteExecutionContextOptions(options);
}
} // namespace

View File

@ -17,9 +17,11 @@ limitations under the License.
#include <vector>
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/types.pb.h"
@ -60,13 +62,31 @@ class AbstractContextInterface {
// Create a handle to wrap and manage a Tensor
virtual AbstractTensorHandleInterface* CreateLocalHandle(
AbstractTensorInterface* t) = 0;
// Copy the handle to another device.
virtual AbstractTensorHandleInterface* CopyTensorHandleToDevice(
AbstractTensorHandleInterface* handle, const char* device_name,
Status* status) = 0;
// Create an operation to perform op execution
virtual AbstractOperationInterface* CreateOperation() = 0;
// Load a SavedModelAPI object from the given directory and tags
virtual std::unique_ptr<SavedModelAPI> LoadSavedModelAPI(
const std::string& directory,
const absl::optional<std::unordered_set<std::string>>& tags,
tensorflow::Status* status) = 0;
// List attributes of available devices
virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
virtual void ClearCachesAndThreadExecutors() = 0;
// Initialize the step resource container for a training step. This is used
// in the current TF runtime. For TFRT, it is used by the fallback op handler.
virtual void StartStep() = 0;
// Destroy the step resource container for a training step.
virtual void EndStep() = 0;
protected:
virtual ~AbstractContextInterface() {}
};

View File

@ -16,8 +16,10 @@ limitations under the License.
#include "tensorflow/c/eager/dlpack.h"
#include "include/dlpack/dlpack.h" // from @dlpack
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_reference.h"
@ -41,15 +43,15 @@ struct TfDlManagedTensorCtx {
// Gets tensor from eager tensor handle.
const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr;
}
tensorflow::TensorHandle* handle =
tensorflow::TensorHandleFromInterface(h->handle);
if (handle->IsRemote()) {
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
if (handle->Type() != TensorHandle::LOCAL) {
status->status = tensorflow::errors::InvalidArgument(
"DLPack doesn't support remote tensor");
"DLPack doesn't support ", handle->TypeString(), " tensor");
return nullptr;
}
const tensorflow::Tensor* tensor;
@ -107,7 +109,7 @@ DLDataType GetDlDataType(TF_DataType data_type, TF_Status* status) {
// Gets DLPack's DLContext from eager tensor handle.
DLContext GetDlContext(TFE_TensorHandle* h, TF_Status* status) {
DLContext ctx;
const char* device_name = h->handle->DeviceName(&status->status);
const char* device_name = tensorflow::unwrap(h)->DeviceName(&status->status);
DeviceNameUtils::ParsedName parsed_name;
tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name);
std::string device_type = parsed_name.type;

View File

@ -7,10 +7,26 @@ package(
licenses = ["notice"], # Apache 2.0
)
# Currently pybind extension shared objects must use only C API headers since
# the C API has static initializers duplicated in the Python bindings. So we
# need a second rule that omits .cc files, in
# tensorflow/python:_pywrap_parallel_device.
filegroup(
name = "headers",
srcs = ["parallel_device.h"],
visibility = ["//tensorflow/python:__pkg__"],
)
filegroup(
name = "sources",
srcs = ["parallel_device.cc"],
visibility = ["//tensorflow/python:__pkg__"],
)
cc_library(
name = "parallel_device",
srcs = ["parallel_device.cc"],
hdrs = ["parallel_device.h"],
srcs = [":sources"],
hdrs = [":headers"],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c/eager:c_api",

View File

@ -574,23 +574,21 @@ void DeleteParallelDevice(void* device_info) {
} // namespace
void RegisterParallelDevice(TFE_Context* context, const char* device_name,
const char** underlying_devices,
int num_underlying_devices, TF_Status* status) {
TFE_CustomDevice custom_device;
custom_device.copy_tensor_to_device = &CopyToParallelDevice;
custom_device.copy_tensor_from_device = &CopyTensorFromParallelDevice;
custom_device.delete_device = &DeleteParallelDevice;
custom_device.execute = &ParallelDeviceExecute;
void AllocateParallelDevice(const char* device_name,
const char* const* underlying_devices,
int num_underlying_devices,
TFE_CustomDevice* device, void** device_info) {
device->copy_tensor_to_device = &CopyToParallelDevice;
device->copy_tensor_from_device = &CopyTensorFromParallelDevice;
device->delete_device = &DeleteParallelDevice;
device->execute = &ParallelDeviceExecute;
std::vector<std::string> underlying_devices_vector;
underlying_devices_vector.reserve(num_underlying_devices);
for (int device_index = 0; device_index < num_underlying_devices;
++device_index) {
underlying_devices_vector.push_back(underlying_devices[device_index]);
}
ParallelDevice* d =
new ParallelDevice(device_name, underlying_devices_vector);
TFE_RegisterCustomDevice(context, custom_device, device_name, d, status);
*device_info = new ParallelDevice(device_name, underlying_devices_vector);
}
} // namespace eager

View File

@ -16,12 +16,14 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
namespace tensorflow {
namespace eager {
// Register a parallel device named `device_name` which forwards operations to
// Allocate a parallel device named `device_name` which forwards operations to
// `underlying_devices`, maintaining "parallel tensors" with components placed
// on each underlying device.
//
@ -50,11 +52,12 @@ namespace eager {
// TPUReplicatedOutput(input=x, num_replicas=2)` un-packs the parallel tensor
// into its components.
//
// `context` owns the parallel device. `underlying_devices` must stay valid
// while the parallel device is in use.
void RegisterParallelDevice(TFE_Context* context, const char* device_name,
const char** underlying_devices,
int num_underlying_devices, TF_Status* status);
// The filled `device` struct and the allocated `device_info` struct may be
// passed to TFE_RegisterCustomDevice. The `device_name` arguments must match.
void AllocateParallelDevice(const char* device_name,
const char* const* underlying_devices,
int num_underlying_devices,
TFE_CustomDevice* device, void** device_info);
} // namespace eager
} // namespace tensorflow

View File

@ -288,6 +288,19 @@ void AssertScalarFloatEq(TFE_TensorHandle* handle, float expected_value) {
*static_cast<float*>(TF_TensorData(value_zero.get())));
}
template <std::size_t num_devices>
void RegisterParallelDevice(
TFE_Context* context, const char* device_name,
const std::array<const char*, num_devices>& underlying_devices,
TF_Status* status) {
TFE_CustomDevice device;
void* device_info;
tensorflow::eager::AllocateParallelDevice(
device_name, underlying_devices.data(), underlying_devices.size(),
&device, &device_info);
TFE_RegisterCustomDevice(context, device, device_name, device_info, status);
}
// Create and modify a variable placed on a parallel device which composes
// `first_device` and `second_device`.
void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
@ -297,9 +310,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
TF_NewStatus(), TF_DeleteStatus);
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::array<const char*, 2> underlying_devices{first_device, second_device};
tensorflow::eager::RegisterParallelDevice(
context, device_name, underlying_devices.data(),
underlying_devices.size(), status.get());
RegisterParallelDevice(context, device_name, underlying_devices,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a variable handle (uninitialized to start) placed on the parallel
@ -456,16 +468,14 @@ TEST(PARALLEL_DEVICE, TestExplicitCopies) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::vector<const char*> underlying_devices;
const char* first_device_name =
"/job:localhost/replica:0/task:0/device:CPU:0";
underlying_devices.push_back(first_device_name);
const char* second_device_name =
"/job:localhost/replica:0/task:0/device:CPU:1";
underlying_devices.push_back(second_device_name);
tensorflow::eager::RegisterParallelDevice(
context.get(), device_name, underlying_devices.data(),
underlying_devices.size(), status.get());
std::array<const char*, 2> underlying_devices{first_device_name,
second_device_name};
RegisterParallelDevice(context.get(), device_name, underlying_devices,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TensorHandlePtr cpu_value(FloatTensorHandle(3., status.get()));
@ -524,12 +534,11 @@ TEST(PARALLEL_DEVICE, TestDifferentShapes) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::vector<const char*> underlying_devices;
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
tensorflow::eager::RegisterParallelDevice(
context.get(), device_name, underlying_devices.data(),
underlying_devices.size(), status.get());
std::array<const char*, 2> underlying_devices{
"/job:localhost/replica:0/task:0/device:CPU:0",
"/job:localhost/replica:0/task:0/device:CPU:1"};
RegisterParallelDevice(context.get(), device_name, underlying_devices,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create two vectors with different lengths
@ -570,24 +579,22 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
// Create a parallel device with two CPUs
const char* first_device_name =
"/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::vector<const char*> first_underlying_devices{
std::array<const char*, 2> first_underlying_devices{
"/job:localhost/replica:0/task:0/device:CPU:0",
"/job:localhost/replica:0/task:0/device:CPU:1"};
tensorflow::eager::RegisterParallelDevice(
context.get(), first_device_name, first_underlying_devices.data(),
first_underlying_devices.size(), status.get());
RegisterParallelDevice(context.get(), first_device_name,
first_underlying_devices, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a second parallel device with the first parallel device and one
// additional CPU.
const char* second_device_name =
"/job:localhost/replica:0/task:0/device:CUSTOM:1";
std::vector<const char*> second_underlying_devices{
std::array<const char*, 2> second_underlying_devices{
"/job:localhost/replica:0/task:0/device:CUSTOM:0",
"/job:localhost/replica:0/task:0/device:CPU:2"};
tensorflow::eager::RegisterParallelDevice(
context.get(), second_device_name, second_underlying_devices.data(),
second_underlying_devices.size(), status.get());
RegisterParallelDevice(context.get(), second_device_name,
second_underlying_devices, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a tensor on the first parallel device
@ -656,11 +663,10 @@ TEST(PARALLEL_DEVICE, TestInvalidPacking) {
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::vector<const char*> underlying_devices;
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
tensorflow::eager::RegisterParallelDevice(
context.get(), device_name, underlying_devices.data(),
underlying_devices.size(), status.get());
std::array<const char*, 1> underlying_devices{
"/job:localhost/replica:0/task:0/device:CPU:0"};
RegisterParallelDevice(context.get(), device_name, underlying_devices,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
@ -775,12 +781,11 @@ TEST(PARALLEL_DEVICE, TestCollective) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::vector<const char*> underlying_devices;
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
tensorflow::eager::RegisterParallelDevice(
context.get(), device_name, underlying_devices.data(),
underlying_devices.size(), status.get());
std::array<const char*, 2> underlying_devices{
"/job:localhost/replica:0/task:0/device:CPU:0",
"/job:localhost/replica:0/task:0/device:CPU:1"};
RegisterParallelDevice(context.get(), device_name, underlying_devices,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a tensor on the parallel device
@ -867,12 +872,11 @@ TEST(PARALLEL_DEVICE, TestFunction) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::vector<const char*> underlying_devices;
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
tensorflow::eager::RegisterParallelDevice(
context.get(), device_name, underlying_devices.data(),
underlying_devices.size(), status.get());
std::array<const char*, 2> underlying_devices{
"/job:localhost/replica:0/task:0/device:CPU:0",
"/job:localhost/replica:0/task:0/device:CPU:1"};
RegisterParallelDevice(context.get(), device_name, underlying_devices,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const char* function_name = "test_reduce_mul";

View File

@ -0,0 +1,24 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_
#define TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_
#include "tensorflow/core/framework/cancellation.h"
struct TFE_CancellationManager {
tensorflow::CancellationManager cancellation_manager;
};
#endif // TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_

View File

@ -0,0 +1,35 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_
#define TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/context_interface.h"
// Wraps a pointer to a context implementation.
//
// WARNING: Since the underlying object could be ref-counted, a user of this
// interface cannot destruct the underlying context object. Instead, call
// TFE_DeleteContext, which calls Release() on the context pointer and deletes
// the TFE_Context structure.
typedef struct TFE_Context TFE_Context;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractContextInterface, TFE_Context);
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_
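For orientation, a hedged sketch of how the generated conversion functions are typically used by C API implementations elsewhere in this diff; the helper names here are hypothetical:

```c++
#include "tensorflow/c/eager/context_interface.h"
#include "tensorflow/c/eager/tfe_context_internal.h"

// unwrap() is generated by DEFINE_CONVERSION_FUNCTIONS above: it turns the
// opaque C handle back into the C++ interface pointer.
tensorflow::AbstractContextInterface* GetContextInterface(TFE_Context* ctx) {
  return tensorflow::unwrap(ctx);
}

// wrap() performs the inverse conversion.
TFE_Context* AsOpaqueHandle(tensorflow::AbstractContextInterface* ctx) {
  return tensorflow::wrap(ctx);
}
```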

View File

@ -0,0 +1,37 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TFE_EXECUTOR_INTERNAL_H_
#define TENSORFLOW_C_EAGER_TFE_EXECUTOR_INTERNAL_H_
#include <memory>
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
struct TFE_Executor {
explicit TFE_Executor(bool async)
: owned_executor(new tensorflow::EagerExecutor(async)) {}
explicit TFE_Executor(tensorflow::EagerExecutor* executor)
: owned_executor(nullptr), unowned_executor(executor) {}
tensorflow::EagerExecutor* executor() {
return owned_executor == nullptr ? unowned_executor : owned_executor.get();
}
std::unique_ptr<tensorflow::EagerExecutor> owned_executor;
tensorflow::EagerExecutor* unowned_executor;
};
#endif // TENSORFLOW_C_EAGER_TFE_EXECUTOR_INTERNAL_H_
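A small illustrative sketch of the owned-versus-borrowed distinction above; the function name is hypothetical and `borrowed` stands in for an executor owned elsewhere:

```c++
#include "tensorflow/c/eager/tfe_executor_internal.h"

void ExecutorOwnershipSketch(tensorflow::EagerExecutor* borrowed) {
  TFE_Executor owned(/*async=*/true);  // owns a newly created EagerExecutor
  TFE_Executor unowned(borrowed);      // borrows; the caller keeps ownership
  // Both expose the same accessor regardless of ownership.
  tensorflow::EagerExecutor* a = owned.executor();
  tensorflow::EagerExecutor* b = unowned.executor();
  (void)a;
  (void)b;
}
```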

View File

@ -0,0 +1,146 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TFE_MONITORING_INTERNAL_H_
#define TENSORFLOW_C_EAGER_TFE_MONITORING_INTERNAL_H_
#include <functional>
#include <memory>
#include <string>
#include "absl/memory/memory.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/platform/types.h"
struct TFE_MonitoringCounterCell {
tensorflow::monitoring::CounterCell cell;
};
template <int NumLabels>
struct TFE_MonitoringCounter {
template <typename... LabelDesc>
TFE_MonitoringCounter(const char* name, const char* description,
LabelDesc&&... label) {
counter = absl::WrapUnique(tensorflow::monitoring::Counter<NumLabels>::New(
name, description, label...));
}
std::unique_ptr<tensorflow::monitoring::Counter<NumLabels>> counter;
};
struct TFE_MonitoringCounter0 : TFE_MonitoringCounter<0> {
using TFE_MonitoringCounter::TFE_MonitoringCounter;
};
struct TFE_MonitoringCounter1 : TFE_MonitoringCounter<1> {
using TFE_MonitoringCounter::TFE_MonitoringCounter;
};
struct TFE_MonitoringCounter2 : TFE_MonitoringCounter<2> {
using TFE_MonitoringCounter::TFE_MonitoringCounter;
};
struct TFE_MonitoringIntGaugeCell {
tensorflow::monitoring::GaugeCell<tensorflow::int64> cell;
};
struct TFE_MonitoringStringGaugeCell {
tensorflow::monitoring::GaugeCell<tensorflow::string> cell;
};
struct TFE_MonitoringBoolGaugeCell {
tensorflow::monitoring::GaugeCell<bool> cell;
};
template <typename ValueType, int NumLabels>
struct TFE_MonitoringGauge {
template <typename... LabelDesc>
TFE_MonitoringGauge(const char* name, const char* description,
LabelDesc&&... label) {
gauge = absl::WrapUnique(
tensorflow::monitoring::Gauge<ValueType, NumLabels>::New(
name, description, label...));
}
std::unique_ptr<tensorflow::monitoring::Gauge<ValueType, NumLabels>> gauge;
};
struct TFE_MonitoringIntGauge0 : TFE_MonitoringGauge<tensorflow::int64, 0> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringIntGauge1 : TFE_MonitoringGauge<tensorflow::int64, 1> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringIntGauge2 : TFE_MonitoringGauge<tensorflow::int64, 2> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringStringGauge0 : TFE_MonitoringGauge<tensorflow::string, 0> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringStringGauge1 : TFE_MonitoringGauge<tensorflow::string, 1> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringStringGauge2 : TFE_MonitoringGauge<tensorflow::string, 2> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringBoolGauge0 : TFE_MonitoringGauge<bool, 0> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringBoolGauge1 : TFE_MonitoringGauge<bool, 1> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringBoolGauge2 : TFE_MonitoringGauge<bool, 2> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringBuckets {
explicit TFE_MonitoringBuckets(
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
fn) {
create_buckets = fn;
}
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
create_buckets;
};
struct TFE_MonitoringSamplerCell {
tensorflow::monitoring::SamplerCell cell;
};
template <int NumLabels>
struct TFE_MonitoringSampler {
template <typename... LabelDesc>
TFE_MonitoringSampler(
const char* name,
std::unique_ptr<tensorflow::monitoring::Buckets> buckets,
const char* description, LabelDesc&&... label) {
sampler = absl::WrapUnique(tensorflow::monitoring::Sampler<NumLabels>::New(
{name, description, label...}, std::move(buckets)));
}
std::unique_ptr<tensorflow::monitoring::Sampler<NumLabels>> sampler;
};
struct TFE_MonitoringSampler0 : TFE_MonitoringSampler<0> {
using TFE_MonitoringSampler::TFE_MonitoringSampler;
};
struct TFE_MonitoringSampler1 : TFE_MonitoringSampler<1> {
using TFE_MonitoringSampler::TFE_MonitoringSampler;
};
struct TFE_MonitoringSampler2 : TFE_MonitoringSampler<2> {
using TFE_MonitoringSampler::TFE_MonitoringSampler;
};
#endif // TENSORFLOW_C_EAGER_TFE_MONITORING_INTERNAL_H_
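A hedged sketch of how one of these wrappers might be exercised, assuming the core monitoring library's `GetCell`/`IncrementBy` behavior; the metric name, description, and label values are illustrative:

```c++
#include "tensorflow/c/eager/tfe_monitoring_internal.h"

void CounterSketch() {
  // One string label, as in TFE_MonitoringCounter1 above.
  TFE_MonitoringCounter1 counter("/tensorflow/example/counter",
                                 "An illustrative counter.", "method");
  // The wrapped Counter<1> hands out per-label cells.
  tensorflow::monitoring::CounterCell* cell =
      counter.counter->GetCell("GetFunction");
  cell->IncrementBy(1);
}
```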

View File

@ -0,0 +1,39 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_
#define TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/framework/attr_value.pb.h"
// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways
// that sometimes do not require serialization.
typedef struct TFE_OpAttrs TFE_OpAttrs;
typedef struct TFE_Context TFE_Context;
typedef struct TFE_Op TFE_Op;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AttrBuilder, TFE_OpAttrs);
// Set an AttrValue on the op. Doesn't handle the list types.
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
const tensorflow::AttrValue& default_value,
const char* attr_name, TF_Status* status);
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_
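A hedged sketch of a call to `SetOpAttrValueScalar` declared above; the attribute name and value are illustrative, and existing `TFE_Context*`/`TFE_Op*`/`TF_Status*` handles are assumed:

```c++
#include "tensorflow/c/eager/tfe_op_attrs_internal.h"
#include "tensorflow/core/framework/attr_value.pb.h"

void SetScalarAttrSketch(TFE_Context* ctx, TFE_Op* op, TF_Status* status) {
  tensorflow::AttrValue value;
  value.set_f(2.0f);  // a scalar float attribute
  tensorflow::SetOpAttrValueScalar(ctx, op, value, "scale", status);
}
```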

View File

@ -0,0 +1,36 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_
#define TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/operation_interface.h"
// Wraps a pointer to an operation implementation.
//
// WARNING: Since the underlying object could be ref-counted, a user of this
// interface cannot destruct the underlying operation object. Instead, call
// TFE_DeleteOp, which calls Release() on the operation pointer and deletes
// the TFE_Op structure.
typedef struct TFE_Op TFE_Op;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface, TFE_Op);
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface*, TFE_Op*);
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_

View File

@ -0,0 +1,30 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TFE_TENSOR_DEBUG_INFO_INTERNAL_H_
#define TENSORFLOW_C_EAGER_TFE_TENSOR_DEBUG_INFO_INTERNAL_H_
#include <vector>
#include "tensorflow/core/platform/types.h"
struct TFE_TensorDebugInfo {
explicit TFE_TensorDebugInfo(const std::vector<tensorflow::int64>& dims)
: dev_dims(dims) {}
// Fully-padded, minor-to-major.
std::vector<tensorflow::int64> dev_dims;
};
#endif // TENSORFLOW_C_EAGER_TFE_TENSOR_DEBUG_INFO_INTERNAL_H_

View File

@ -0,0 +1,38 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_
#define TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
// Wraps a pointer to a tensor handle implementation.
//
// WARNING: Since the underlying object could be ref-counted, a user of this
// interface cannot destruct the underlying handle object. Instead, call
// TFE_DeleteTensorHandle, which calls Release() on the handle pointer and
// deletes the TFE_TensorHandle structure.
typedef struct TFE_TensorHandle TFE_TensorHandle;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface,
TFE_TensorHandle);
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface*,
TFE_TensorHandle*);
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_

View File

@ -85,17 +85,36 @@ class ModularFileSystemTest : public ::testing::TestWithParam<std::string> {
const std::string test_name = tensorflow::str_util::StringReplace(
::testing::UnitTest::GetInstance()->current_test_info()->name(), "/",
"_", /*replace_all=*/true);
root_dir_ = tensorflow::io::JoinPath(
::testing::TempDir(),
tensorflow::strings::StrCat("tf_fs_", rng_val_, "_", test_name));
if (!cloud_path_.empty()) {
// We have to join the path for non-local filesystems manually, to make sure
// that this test will run on Windows, since `tensorflow::io::JoinPath`
// behaves differently there. `tmp_dir_` should be something like
// `path/to/tmp/dir/`. After joining the path, we will have
// `/path/to/tmp/dir/tf_fs_rng_name/`.
root_dir_ = tensorflow::strings::StrCat(
"/", tmp_dir_,
tensorflow::strings::StrCat("tf_fs_", rng_val_, "_", test_name), "/");
} else {
root_dir_ = tensorflow::io::JoinPath(
tmp_dir_,
tensorflow::strings::StrCat("tf_fs_", rng_val_, "_", test_name));
}
if (!GetParam().empty()) {
root_dir_ = tensorflow::strings::StrCat(GetParam(), "://", cloud_path_,
root_dir_);
}
env_ = Env::Default();
}
void SetUp() override {
if (mkdir(root_dir_.c_str(), 0755) != 0) {
int error_code = errno;
GTEST_SKIP() << "Cannot create working directory: "
<< tensorflow::IOError(root_dir_, error_code);
FileSystem* fs = nullptr;
Status s = env_->GetFileSystemForFile(root_dir_, &fs);
if (fs == nullptr || !s.ok())
GTEST_SKIP() << "No filesystem registered: " << s;
s = fs->CreateDir(root_dir_);
if (!s.ok()) {
GTEST_SKIP() << "Cannot create working directory: " << s;
}
}
@ -115,9 +134,10 @@ class ModularFileSystemTest : public ::testing::TestWithParam<std::string> {
std::string GetURIForPath(StringPiece path) {
const std::string translated_name =
tensorflow::io::JoinPath(root_dir_, path);
if (GetParam().empty()) return translated_name;
return tensorflow::strings::StrCat(GetParam(), "://", translated_name);
// The scheme was already handled in the `ModularFileSystemTest()`
// constructor: when `GetParam()` is not empty, `root_dir_` already starts
// with `GetParam() + "://"`.
return translated_name;
}
// Converts absolute paths to paths relative to root_dir_.
@ -133,15 +153,28 @@ class ModularFileSystemTest : public ::testing::TestWithParam<std::string> {
rng_val_ = distribution(gen);
}
static void SetCloudPath(const std::string& cloud_path) {
cloud_path_ = cloud_path;
if (cloud_path_.back() == '/') cloud_path_.pop_back();
}
static void SetTmpDir(const std::string& tmp_dir) {
tmp_dir_ = tmp_dir.empty() ? ::testing::TempDir() : tmp_dir;
}
protected:
Env* env_;
private:
std::string root_dir_;
static int rng_val_;
static std::string cloud_path_;
static std::string tmp_dir_;
};
int ModularFileSystemTest::rng_val_;
std::string ModularFileSystemTest::cloud_path_;
std::string ModularFileSystemTest::tmp_dir_;
// As some of the implementations might be missing, the tests should still pass
// if the returned `Status` signals the unimplemented state.
@ -1729,6 +1762,20 @@ static bool GetURIScheme(const std::string& scheme) {
return true;
}
// This function is used for cloud filesystems.
// `S3` and `GCS` require `root_dir_` to contain a bucket name;
// `HDFS` requires `root_dir_` to contain a namenode.
// `root_dir_ = scheme + "://" + cloud_path_ + root_dir_`
static bool SetCloudPath(const std::string& cloud_path_) {
ModularFileSystemTest::SetCloudPath(cloud_path_);
return true;
}
static bool SetTmpDir(const std::string& tmp_dir_) {
ModularFileSystemTest::SetTmpDir(tmp_dir_);
return true;
}
} // namespace
} // namespace tensorflow
@ -1741,7 +1788,12 @@ GTEST_API_ int main(int argc, char** argv) {
tensorflow::Flag("dso", tensorflow::LoadDSO, "",
"Path to shared object to load"),
tensorflow::Flag("scheme", tensorflow::GetURIScheme, "",
"URI scheme to test")};
"URI scheme to test"),
tensorflow::Flag("cloud_path", tensorflow::SetCloudPath, "",
"Path for cloud filesystem (namenode for hdfs, "
"bucketname for s3/gcs)"),
tensorflow::Flag("tmp_dir", tensorflow::SetTmpDir, "",
"Temporary directory to store test data.")};
if (!tensorflow::Flags::Parse(&argc, argv, flag_list)) {
std::cout << tensorflow::Flags::Usage(argv[0], flag_list);
return -1;

View File

@ -0,0 +1,66 @@
# TensorFlow C SavedModel API
## Overview
These are the new experimental C SavedModel APIs for loading and running
SavedModels in a TF2-idiomatic fashion. See
[RFC 207](https://github.com/tensorflow/community/pull/207) for additional
context.
The directory structure is as follows:
```none
saved_model/
public/
internal/
core/
```
## saved_model/public
`saved_model/public` is intended to house *only the public headers* of the
SavedModel C API.
These headers:
1. declare opaque C types (like `TF_SavedModel`),
2. declare the functions that operate on these types (like `TF_LoadSavedModel`).
Once they leave experimental, these APIs should be considered stable for use
by external clients.
These headers are in a separate directory to make it obvious to clients which
headers they should depend on, and which headers are implementation details.
Separating these public headers by directory also allows future programmatic
checks to ensure that TF public headers only `#include` other public TF headers.
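As a rough, non-authoritative illustration of the intended usage (the directory, tag, and function names below are made up), a client is expected to go through only these opaque types and free functions:

```c++
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
#include "tensorflow/c/tf_status.h"

// Hedged sketch; `ctx` is an already-created TFE_Context.
void LoadAndLookUp(TFE_Context* ctx) {
  TF_Status* status = TF_NewStatus();
  const char* tags[] = {"serve"};
  TF_SavedModel* model =
      TF_LoadSavedModelWithTags("/path/to/saved_model", ctx, tags, 1, status);
  if (TF_GetCode(status) == TF_OK) {
    TF_ConcreteFunction* fn =
        TF_GetSavedModelConcreteFunction(model, "my_function", status);
    (void)fn;  // Functions are returned from the model; not deleted separately.
  }
  TF_DeleteSavedModel(model);
  TF_DeleteStatus(status);
}
```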
## saved_model/internal
`saved_model/internal` is the "glue" between the C API and the internal C++
implementation.
Its role is to:
1. implement the C API functions declared in `saved_model/public`
2. define the C API types declared in `saved_model/public`
The files fulfilling 1. are named `*.cc` (e.g. `concrete_function.cc`), while
the files fulfilling 2. are `*_type.h` (e.g. `concrete_function_type.h`).
The headers exposing the internal implementation of the opaque C types are only
visible to other implementors of the C API. This is similar to how other
TF C API implementations use `tf_status_internal.h` (to extract the underlying
`tensorflow::Status`). All other targets in this directory are private.
## saved_model/core
`saved_model/core` contains pure C++ "Classes" underlying the C API types
in `saved_model/public/`. These are implementation
details subject to change, and have limited visibility to implementors only.
This is the bottom-most layer of the `C++ -> C -> C++` sandwich.

View File

@ -0,0 +1,85 @@
# Experimental SavedModel C APIs for TensorFlow. See RFC
# https://github.com/tensorflow/community/pull/207
# Targets in this directory are pure C++ "Classes" underlying the C API types
# under tf/c/experimental/saved_model/public/. They are subject to change and
# have visibility limited to Tensorflow's implementation only.
package(
default_visibility = [
"//tensorflow/c:__subpackages__",
"//tensorflow/c/experimental/saved_model/internal:__pkg__",
"//tensorflow/core:__subpackages__",
],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "concrete_function",
srcs = [
"concrete_function.cc",
],
hdrs = [
"concrete_function.h",
],
deps = [
":function_metadata",
"//tensorflow/c/eager:operation_interface",
"//tensorflow/c/eager:tensor_handle_interface",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "function_metadata",
hdrs = [
"function_metadata.h",
],
)
cc_library(
name = "saved_model_api",
hdrs = [
"saved_model_api.h",
],
deps = [
":concrete_function",
"//tensorflow/core:lib",
],
)
cc_library(
name = "tf_saved_model_impl",
srcs = [
"tf_saved_model_impl.cc",
],
hdrs = ["tf_saved_model_impl.h"],
deps = [
":concrete_function",
":saved_model_api",
"//tensorflow/core:lib",
"@com_google_absl//absl/types:optional",
],
)
cc_library(
name = "pywrap_required_hdrs",
textual_hdrs = [
"concrete_function.h",
"function_metadata.h",
"saved_model_api.h",
],
visibility = ["//tensorflow/python:__pkg__"],
)
filegroup(
name = "mobile_srcs_only_runtime",
srcs = [
"concrete_function.cc",
"concrete_function.h",
"function_metadata.h",
"saved_model_api.h",
"tf_saved_model_impl.cc",
"tf_saved_model_impl.h",
],
visibility = ["//tensorflow/core:__pkg__"],
)

View File

@ -0,0 +1,32 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
namespace tensorflow {
const std::vector<tensorflow::AbstractTensorHandleInterface*>&
ConcreteFunction::GetCaptures() const {
return captures_;
}
const FunctionMetadata& ConcreteFunction::GetFunctionMetadata() const {
return metadata_;
}
} // namespace tensorflow

View File

@ -0,0 +1,55 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_
#include <vector>
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/core/framework/function.pb.h"
namespace tensorflow {
// Note that ConcreteFunctions' lifetimes are effectively bound
// to the SavedModel they are loaded from, since they retain pointers
// to the TensorHandles owned by the SavedModel, and the FunctionDef
// of the SavedModel.
// Note(bmzhao): This class is only TEMPORARILY virtual, as a way to unblock
// TFRT integration with TF Serving. Do not add more virtual implementations of
// this class. Eventually we want to remove this virtual base class indirection
// and have only a single implementation.
class ConcreteFunction {
public:
virtual ~ConcreteFunction() = 0;
// This method returns the "Call" Op used to execute the function.
virtual AbstractOperationInterface* GetCallOp() = 0;
const std::vector<tensorflow::AbstractTensorHandleInterface*>& GetCaptures()
const;
const FunctionMetadata& GetFunctionMetadata() const;
private:
FunctionMetadata metadata_;
std::vector<tensorflow::AbstractTensorHandleInterface*> captures_;
FunctionDef* function_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_

View File

@ -0,0 +1,27 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_FUNCTION_METADATA_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_FUNCTION_METADATA_H_
namespace tensorflow {
class FunctionMetadata {
// TODO(bmzhao): Fill in with fields as necessary
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_FUNCTION_METADATA_H_

View File

@ -0,0 +1,55 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_API_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_API_H_
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
// Note(bmzhao): This class is only TEMPORARILY virtual, as a way to unblock
// TFRT integration with TF Serving. Do not add more virtual implementations of
// this class. Eventually we want to remove this virtual base class indirection
// and have only a single implementation.
class SavedModelAPI {
public:
// Retrieve a function from the TF2 SavedModel, using the "path" to a function
// in a TF2 SavedModel.
// Note: `function` is a double pointer, so that implementations are
// able to return a pointer to an internal member.
virtual Status GetFunction(const std::string& function_path,
ConcreteFunction** function) = 0;
// Retrieve a function from a SavedModel, using the key of the
// SignatureDef map:
// https://github.com/tensorflow/tensorflow/blob/69b08900b1e991d84bce31f3b404f5ed768f339f/tensorflow/core/protobuf/meta_graph.proto#L89
virtual Status GetSignatureDefFunction(const std::string& signature_def_key,
ConcreteFunction** function) = 0;
virtual std::vector<ConcreteFunction*> ListFunctions() = 0;
virtual ~SavedModelAPI() = default;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_API_H_
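For orientation, a hedged sketch of a caller of this interface; the function path is illustrative, and `model` would be some concrete implementation (such as `TFSavedModelAPIImpl` later in this diff):

```c++
#include <string>

#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/core/platform/status.h"

// Hypothetical helper: look up a function by its object-graph path.
tensorflow::Status LookUpFunction(tensorflow::SavedModelAPI* model,
                                  tensorflow::ConcreteFunction** out) {
  return model->GetFunction("my_model.my_function", out);
}
```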

View File

@ -0,0 +1,60 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_impl.h"
#include <string>
#include <unordered_set>
#include <vector>
#include "absl/types/optional.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
Status TFSavedModelAPIImpl::GetFunction(const std::string& function_path,
ConcreteFunction** function) {
// TODO(bmzhao): Add support for retrieving a function.
return errors::Unimplemented(
"Retrieving functions is unimplemented currently");
}
Status TFSavedModelAPIImpl::GetSignatureDefFunction(
const std::string& signature_def_key, ConcreteFunction** function) {
// TODO(bmzhao): Add support for retrieving a signaturedef function.
return errors::Unimplemented(
"Retrieving functions is unimplemented currently");
}
std::vector<ConcreteFunction*> TFSavedModelAPIImpl::ListFunctions() {
std::vector<ConcreteFunction*> result;
result.reserve(functions_.size());
for (ConcreteFunction& function : functions_) {
result.push_back(&function);
}
return result;
}
Status TFSavedModelAPIImpl::Load(
const std::string& directory,
const absl::optional<std::unordered_set<std::string>>& tags,
TFSavedModelAPIImpl* out) {
// TODO(bmzhao): Add support for loading a TFSavedModelImpl.
return errors::Unimplemented(
"TFSavedModelAPIImpl loading is unimplemented currently");
}
} // namespace tensorflow

View File

@ -0,0 +1,55 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_
#include <string>
#include <unordered_set>
#include <vector>
#include "absl/types/optional.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
class TFSavedModelAPIImpl : public SavedModelAPI {
public:
TFSavedModelAPIImpl() = default;
Status GetFunction(const std::string& function_path,
ConcreteFunction** function) override;
Status GetSignatureDefFunction(const std::string& signature_def_key,
ConcreteFunction** function) override;
static Status Load(
const std::string& directory,
const absl::optional<std::unordered_set<std::string>>& tags,
TFSavedModelAPIImpl* out);
std::vector<ConcreteFunction*> ListFunctions() override;
~TFSavedModelAPIImpl() override = default;
private:
std::vector<ConcreteFunction> functions_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_

View File

@ -0,0 +1,181 @@
# Experimental Implementation of SavedModel C APIs for TensorFlow. See RFC
# https://github.com/tensorflow/community/pull/207
# External clients should not worry about this directory; all contents are implementation details.
# Code in this directory is intended to form the glue between the C API and the internal C++
# implementation by
# 1. mapping C API calls onto corresponding methods of C++ objects
# 2. mapping opaque C types onto C++ classes
# Note(bmzhao): The *.cc files in this directory form the direct implementation of the
# C API functions exposed in tf/c/experimental/saved_model/public/.
# Note(bmzhao): All *type.h files in this directory are the internal definitions of
# the opaque C types. These headers should only be visible to internal tensorflow
# implementors.
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
"tf_copts",
)
package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "concrete_function",
srcs = [
"concrete_function.cc",
],
hdrs = [
"//tensorflow/c/experimental/saved_model/public:concrete_function.h",
],
copts = tf_copts(),
# TODO(bmzhao): Remove this as we refactor C API to granular targets,
# so that we can depend on c/eager/c_api_unified_experimental.h.
features = ["-layering_check"],
visibility = [
"//tensorflow/c/experimental/saved_model/public:__pkg__",
],
deps = [
":concrete_function_type",
":function_metadata",
":function_metadata_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_internal",
"//tensorflow/c/eager:tfe_op_internal",
"//tensorflow/c/experimental/saved_model/core:concrete_function",
"//tensorflow/c/experimental/saved_model/core:function_metadata",
],
)
cc_library(
name = "concrete_function_list",
srcs = [
"concrete_function_list.cc",
],
hdrs = [
"//tensorflow/c/experimental/saved_model/public:concrete_function_list.h",
],
copts = tf_copts(),
visibility = [
"//tensorflow/c/experimental/saved_model/public:__pkg__",
],
deps = [
":concrete_function",
":concrete_function_list_type",
":concrete_function_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c/experimental/saved_model/core:concrete_function",
],
)
cc_library(
name = "concrete_function_list_type",
hdrs = [
"concrete_function_list_type.h",
],
deps = [
"//tensorflow/c/experimental/saved_model/core:concrete_function",
],
)
cc_library(
name = "concrete_function_type",
hdrs = [
"concrete_function_type.h",
],
deps = [
"//tensorflow/c:conversion_macros",
"//tensorflow/c/experimental/saved_model/core:concrete_function",
],
)
cc_library(
name = "function_metadata",
srcs = [
"function_metadata.cc",
],
hdrs = [
"//tensorflow/c/experimental/saved_model/public:function_metadata.h",
],
copts = tf_copts(),
visibility = [
"//tensorflow/c/experimental/saved_model/public:__pkg__",
],
deps = [
":function_metadata_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c/experimental/saved_model/core:function_metadata",
],
)
cc_library(
name = "function_metadata_type",
hdrs = [
"function_metadata_type.h",
],
deps = [
"//tensorflow/c:conversion_macros",
"//tensorflow/c/experimental/saved_model/core:function_metadata",
],
)
cc_library(
name = "saved_model_api",
srcs = [
"saved_model_api.cc",
],
hdrs = [
"//tensorflow/c/experimental/saved_model/public:saved_model_api.h",
],
copts = tf_copts(),
visibility = [
"//tensorflow/c/experimental/saved_model/public:__pkg__",
],
deps = [
":concrete_function",
":concrete_function_list",
":concrete_function_list_type",
":concrete_function_type",
":saved_model_api_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_internal",
"//tensorflow/c/eager:tfe_context_internal",
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
"//tensorflow/core:lib",
"@com_google_absl//absl/types:optional",
],
)
cc_library(
name = "saved_model_api_type",
hdrs = [
"saved_model_api_type.h",
],
deps = [
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
],
)
tf_cc_test(
name = "saved_model_api_test",
size = "small",
srcs = [
"saved_model_api_test.cc",
],
data = [
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
],
deps = [
"//tensorflow/c:tf_status",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/c/experimental/saved_model/public:saved_model_api",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)

View File

@ -0,0 +1,42 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
#include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h"
extern "C" {
TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(TF_ConcreteFunction* func) {
return tensorflow::wrap(const_cast<tensorflow::FunctionMetadata*>(
&tensorflow::unwrap(func)->GetFunctionMetadata()));
}
TF_OutputList* TF_ConcreteFunctionGetCaptures(TF_ConcreteFunction* func) {
// TODO(bmzhao): Refactor TF_OutputList struct definition into a separate
// internal header, and implement this function.
return nullptr;
}
TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func) {
return tensorflow::wrap(tensorflow::unwrap(func)->GetCallOp());
}
} // end extern "C"

View File

@ -0,0 +1,37 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stddef.h>
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
extern "C" {
size_t TF_ConcreteFunctionListNumOutputs(TF_ConcreteFunctionList* list) {
return list->list.size();
}
TF_ConcreteFunction* TF_ConcreteFunctionListGet(TF_ConcreteFunctionList* list,
int i) {
return tensorflow::wrap(list->list[i]);
}
void TF_DeleteConcreteFunctionList(TF_ConcreteFunctionList* list) {
delete list;
}
} // end extern "C"
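For orientation, a hedged sketch of walking such a list from client code, using `TF_ListSavedModelFunctions` from the SavedModel C API elsewhere in this diff; the function name is hypothetical:

```c++
#include <stddef.h>

#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"

// Walk every concrete function in a loaded SavedModel, then release the list.
void VisitAllFunctionsSketch(TF_SavedModel* model) {
  TF_ConcreteFunctionList* list = TF_ListSavedModelFunctions(model);
  size_t num_functions = TF_ConcreteFunctionListNumOutputs(list);
  for (size_t i = 0; i < num_functions; ++i) {
    TF_ConcreteFunction* fn =
        TF_ConcreteFunctionListGet(list, static_cast<int>(i));
    (void)fn;  // e.g. query metadata via TF_ConcreteFunctionGetMetadata(fn).
  }
  TF_DeleteConcreteFunctionList(list);
}
```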

View File

@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_
#include <vector>
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
// Internal structures used by the SavedModel C API. These are likely to change
// and should not be depended on.
struct TF_ConcreteFunctionList {
std::vector<tensorflow::ConcreteFunction*> list;
};
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_

View File

@ -0,0 +1,36 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
// Internal structures used by the SavedModel C API. These are likely to change
// and should not be depended on.
// It doesn't make sense to wrap tensorflow::ConcreteFunction* in a separate
// struct, since the lifetime of the struct and the raw pointer it wraps would
// be different. Therefore TF_ConcreteFunction* = tensorflow::ConcreteFunction*.
typedef struct TF_ConcreteFunction TF_ConcreteFunction;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ConcreteFunction, TF_ConcreteFunction)
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%{
#include "tensorflow/lite/experimental/kernels/hashtable_ops.h"
%}
%include "tensorflow/lite/experimental/kernels/hashtable_ops.h"

View File

#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h"
// TODO(bmzhao): Add getter functions here as necessary.

View File

@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
typedef struct TF_FunctionMetadata TF_FunctionMetadata;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::FunctionMetadata, TF_FunctionMetadata)
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_

View File

@ -0,0 +1,97 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
#include <memory>
#include <string>
#include <unordered_set>
#include "absl/types/optional.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/platform/status.h"
extern "C" {
TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx,
TF_Status* status) {
std::string saved_model_dir(dirname);
std::unique_ptr<tensorflow::SavedModelAPI> result =
tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, absl::nullopt,
&status->status);
if (!status->status.ok()) {
return nullptr;
}
return new TF_SavedModel{std::move(result)};
}
TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
const char* const* tags, int tags_len,
TF_Status* status) {
std::string saved_model_dir(dirname);
std::unordered_set<std::string> tagset;
for (int i = 0; i < tags_len; ++i) {
tagset.insert(std::string(tags[i]));
}
std::unique_ptr<tensorflow::SavedModelAPI> result =
tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, std::move(tagset),
&status->status);
if (!status->status.ok()) {
return nullptr;
}
return new TF_SavedModel{std::move(result)};
}
void TF_DeleteSavedModel(TF_SavedModel* model) { delete model; }
TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model,
const char* function_path,
TF_Status* status) {
tensorflow::ConcreteFunction* result = nullptr;
tensorflow::Status get_function_status =
model->saved_model->GetFunction(function_path, &result);
status->status.Update(get_function_status);
if (!get_function_status.ok()) {
return nullptr;
}
return tensorflow::wrap(result);
}
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
TF_SavedModel* model, const char* signature_def_key, TF_Status* status) {
tensorflow::ConcreteFunction* result = nullptr;
tensorflow::Status get_function_status =
model->saved_model->GetSignatureDefFunction(signature_def_key, &result);
status->status.Update(get_function_status);
if (!get_function_status.ok()) {
return nullptr;
}
return tensorflow::wrap(result);
}
TF_ConcreteFunctionList* TF_ListSavedModelFunctions(TF_SavedModel* model) {
return new TF_ConcreteFunctionList{model->saved_model->ListFunctions()};
}
} // end extern "C"

View File

@ -0,0 +1,109 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
#include <string>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/test.h"
namespace {
constexpr char kTestData[] = "cc/saved_model/testdata";
const char* kServeTag[] = {"serve"};
std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
return tensorflow::io::JoinPath(tensorflow::testing::TensorFlowSrcRoot(),
kTestData, saved_model_dir);
}
// This value-parameterized test allows us to test both TFRT
// and non-TFRT runtimes.
// https://github.com/google/googletest/blob/dcc92d0ab6c4ce022162a23566d44f673251eee4/googletest/docs/advanced.md#value-parameterized-tests
class CSavedModelAPITest : public ::testing::TestWithParam<bool> {};
TEST_P(CSavedModelAPITest, LoadsSavedModelWithTags) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
bool use_tfrt = GetParam();
if (use_tfrt) {
TFE_DeleteContextOptions(opts);
TF_DeleteStatus(status);
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
}
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
TFE_Context* ctx = TFE_NewContext(opts, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
TF_SavedModel* saved_model =
TF_LoadSavedModelWithTags(model_dir.c_str(), ctx, kServeTag, 1, status);
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
// That unblocks writing other tests that require a TF_SavedModel*,
// like loading a ConcreteFunction. This test at least checks that the
// C API builds and can be minimally run.
EXPECT_EQ(TF_GetCode(status), TF_UNIMPLEMENTED);
TF_DeleteSavedModel(saved_model);
TF_DeleteStatus(status);
TFE_DeleteContext(ctx);
}
TEST_P(CSavedModelAPITest, LoadsSavedModel) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
bool use_tfrt = GetParam();
if (use_tfrt) {
TFE_DeleteContextOptions(opts);
TF_DeleteStatus(status);
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
}
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
TFE_Context* ctx = TFE_NewContext(opts, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
TF_SavedModel* saved_model =
TF_LoadSavedModel(model_dir.c_str(), ctx, status);
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
// That unblocks writing other tests that require a TF_SavedModel*,
// like loading a ConcreteFunction. This test at least checks that the
// C API builds and can be minimally run.
EXPECT_EQ(TF_GetCode(status), TF_UNIMPLEMENTED);
TF_DeleteSavedModel(saved_model);
TF_DeleteStatus(status);
TFE_DeleteContext(ctx);
}
INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticSavedModelTests, CSavedModelAPITest,
::testing::Bool());
} // namespace

View File

@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_
#include <memory>
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
// Internal structures used by the SavedModel C API. These are likely to change
// and should not be depended on.
struct TF_SavedModel {
std::unique_ptr<tensorflow::SavedModelAPI> saved_model;
};
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_

View File

@ -0,0 +1,63 @@
# Experimental SavedModel C APIs for TensorFlow.
# See RFC https://github.com/tensorflow/community/pull/207
# All headers are on the public surface of Tensorflow's C API.
# Once moved out of experimental, these will be stable.
# The idea behind a separate public/ directory is to make apparent
# which headers are part of TF's public interface (and which headers
# are implementation details). This structure allows us to also perform future
# programmatic checks that all "public" headers only include other "public"
# headers.
package(
# This is intentionally public
default_visibility = [
"//visibility:public",
],
licenses = ["notice"], # Apache 2.0
)
# TODO(bmzhao): Remove these exports_files and rules, swap with cc_public_library instead.
# cc_public_library would allow us to separate the header dep graph from the header+srcs dep graph.
exports_files(
[
"concrete_function.h",
"concrete_function_list.h",
"function_metadata.h",
"saved_model_api.h",
],
visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"],
)
# The purpose of this header is to provide insulation against
# future changes where we rename/move a public header, without
# forcing all clients to change their "#includes".
cc_library(
name = "c_saved_model_api",
hdrs = ["c_saved_model_api.h"],
deps = [
":concrete_function",
":concrete_function_list",
":function_metadata",
":saved_model_api",
],
)
alias(
name = "concrete_function",
actual = "//tensorflow/c/experimental/saved_model/internal:concrete_function",
)
alias(
name = "concrete_function_list",
actual = "//tensorflow/c/experimental/saved_model/internal:concrete_function_list",
)
alias(
name = "function_metadata",
actual = "//tensorflow/c/experimental/saved_model/internal:function_metadata",
)
alias(
name = "saved_model_api",
actual = "//tensorflow/c/experimental/saved_model/internal:saved_model_api",
)

View File

@ -0,0 +1,26 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_
// IWYU pragma: begin_exports
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
// IWYU pragma: end_exports
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_

View File

@ -0,0 +1,50 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
// An opaque type that corresponds to a Function loaded from a SavedModel.
// TODO(bmzhao): Work together w/srbs@ to make sure this composes w/the
// C++ Unified Eager/Graph API's AbstractFunction
typedef struct TF_ConcreteFunction TF_ConcreteFunction;
// Returns FunctionMetadata associated with `func`. Metadata's lifetime is
// bound to `func`, which is bound to the TF_SavedModel it was loaded from.
TF_CAPI_EXPORT extern TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(
TF_ConcreteFunction* func);
// Returns a list of TensorHandles implicitly captured by this function.
TF_CAPI_EXPORT extern TF_OutputList* TF_ConcreteFunctionGetCaptures(
TF_ConcreteFunction* func);
// Returns a TFE_Op suitable for executing this function.
TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionGetCallOp(
TF_ConcreteFunction* func);
#ifdef __cplusplus
} // end extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_

View File

@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
#include <stddef.h>
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
// An opaque type that acts like a list of TF_ConcreteFunction pointers.
typedef struct TF_ConcreteFunctionList TF_ConcreteFunctionList;
// Returns the size of `list`.
TF_CAPI_EXPORT size_t
TF_ConcreteFunctionListSize(TF_ConcreteFunctionList* list);
// Returns the `i`th TF_ConcreteFunction in the list.
TF_CAPI_EXPORT TF_ConcreteFunction* TF_ConcreteFunctionListGet(
TF_ConcreteFunctionList* list, int i);
// Deletes `list`.
TF_CAPI_EXPORT void TF_DeleteConcreteFunctionList(
TF_ConcreteFunctionList* list);
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_

View File

@ -0,0 +1,35 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_FUNCTION_METADATA_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_FUNCTION_METADATA_H_
#include "tensorflow/c/c_api_macros.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
// An opaque type used to store any metadata associated with a function.
typedef struct TF_FunctionMetadata TF_FunctionMetadata;
// TODO(bmzhao): Add getters for fields as we determine what metadata
// we want to expose.
#ifdef __cplusplus
} // end extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_FUNCTION_METADATA_H_

View File

@ -0,0 +1,108 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SAVED_MODEL_API_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SAVED_MODEL_API_H_
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
#include "tensorflow/c/tf_status.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
// An opaque type representing a Tensorflow "SavedModel"
// (https://www.tensorflow.org/guide/saved_model) that we always pass by pointer
// to achieve ABI stability.
typedef struct TF_SavedModel TF_SavedModel;
// Load a SavedModel from `dirname`. We expect the SavedModel to contain a
// single Metagraph (as for those exported from TF2's `tf.saved_model.save`).
//
// Params:
// dirname - A directory filepath that the SavedModel is at.
// ctx - A TFE_Context containing optional load/TF runtime options.
// `ctx` must outlive the returned TF_SavedModel pointer.
// status - Set to OK on success and an appropriate error on failure.
// Returns:
// If status is not OK, returns nullptr. Otherwise, returns a newly created
// TF_SavedModel instance. It must be deleted by calling TF_DeleteSavedModel.
TF_CAPI_EXPORT extern TF_SavedModel* TF_LoadSavedModel(const char* dirname,
TFE_Context* ctx,
TF_Status* status);
// Load a SavedModel from `dirname`.
//
// Params:
// dirname - A directory filepath that the SavedModel is at.
// ctx - A TFE_Context containing optional load/TF runtime options.
// `ctx` must outlive the returned TF_SavedModel pointer.
// tags - char* array of SavedModel tags. We will load the metagraph matching
// the tags.
// tags_len - number of elements in the `tags` array.
// status - Set to OK on success and an appropriate error on failure.
// Returns:
// If status is not OK, returns nullptr. Otherwise, returns a newly created
// TF_SavedModel instance. It must be deleted by calling TF_DeleteSavedModel.
TF_CAPI_EXPORT extern TF_SavedModel* TF_LoadSavedModelWithTags(
const char* dirname, TFE_Context* ctx, const char* const* tags,
int tags_len, TF_Status* status);
// Deletes a TF_SavedModel, and frees any resources owned by it.
TF_CAPI_EXPORT extern void TF_DeleteSavedModel(TF_SavedModel* model);
// Retrieve a function from the TF2 SavedModel via function path.
//
// Params:
// model - The TF2 SavedModel to load a function from.
// function_path - A string containing the path from the root saved python
// object to a tf.function method.
// TODO(bmzhao): Add a detailed example of this with a
// python tf.module before moving this out of experimental.
// status - Set to OK on success and an appropriate error on failure.
// Returns:
// If status is not OK, returns nullptr. Otherwise, returns a
// TF_ConcreteFunction instance. The lifetime of this instance is
// "conceptually" bound to `model`. Once `model` is deleted, all
// `TF_ConcreteFunctions` retrieved from it are invalid, and have been deleted.
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(
TF_SavedModel* model, const char* function_path, TF_Status* status);
// Retrieve a function from the TF SavedModel via a SignatureDef key.
//
// Params:
// model - The SavedModel to load a function from.
// signature_def_key - The string key of the SignatureDef map of a SavedModel:
// https://github.com/tensorflow/tensorflow/blob/69b08900b1e991d84bce31f3b404f5ed768f339f/tensorflow/core/protobuf/meta_graph.proto#L89
// status - Set to OK on success and an appropriate error on failure.
// Returns:
// If status is not OK, returns nullptr. Otherwise, returns a
// TF_ConcreteFunction instance. Once `model` is deleted, all
// `TF_ConcreteFunctions` retrieved from it are invalid, and have been deleted.
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
TF_SavedModel* model, const char* signature_def_key, TF_Status* status);
// Returns a list of all ConcreteFunctions stored in this SavedModel.
// The lifetime of the returned list is bound to `model`.
TF_CAPI_EXPORT extern TF_ConcreteFunctionList* TF_ListSavedModelFunctions(
TF_SavedModel* model);
#ifdef __cplusplus
} // end extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SAVED_MODEL_API_H_
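As a rough, hypothetical usage sketch (not part of this change): the intended call sequence is load, query, then delete. The function name, context, and directory below are illustrative only, and loading currently reports TF_UNIMPLEMENTED, as the tests earlier in this change show.

// Hypothetical sketch of the C SavedModel API call sequence. Error handling
// is abbreviated; TF_DeleteSavedModel(nullptr) is a no-op.
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/experimental/saved_model/public/c_saved_model_api.h"
#include "tensorflow/c/tf_status.h"

void LoadAndInspect(TFE_Context* ctx, const char* dir) {
  TF_Status* status = TF_NewStatus();
  const char* tags[] = {"serve"};
  TF_SavedModel* model = TF_LoadSavedModelWithTags(dir, ctx, tags, 1, status);
  if (TF_GetCode(status) == TF_OK) {
    TF_ConcreteFunctionList* fns = TF_ListSavedModelFunctions(model);
    // ... inspect via TF_ConcreteFunctionListSize / TF_ConcreteFunctionListGet ...
    TF_DeleteConcreteFunctionList(fns);
  }
  TF_DeleteSavedModel(model);
  TF_DeleteStatus(status);
}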

View File

@ -182,6 +182,7 @@ cc_library_with_android_deps(
deps = [
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:ops",

View File

@ -0,0 +1,63 @@
# Experimental C++ APIs for TensorFlow.
# New TF C++ APIs under the tensorflow::cc namespace aim to guarantee ABI stability.
# Users are expected to compile against public c++ headers, and link against
# libtensorflow (https://www.tensorflow.org/install/lang_c).
# We aim to achieve ABI stability in new C++ APIs by only using types
# on the API surface that:
# 1. Have a header-only implementation
# 2. Are std:: types
# 3. Wrap an opaque C type
package(
# This is intentionally public
default_visibility = [
"//visibility:public",
],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "runtime",
hdrs = [
"runtime.h",
],
deps = [
":status",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
],
)
cc_library(
name = "runtime_builder",
hdrs = [
"runtime_builder.h",
],
deps = [
":runtime",
":status",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
],
)
cc_library(
name = "status",
hdrs = [
"status.h",
],
deps = [
"//tensorflow/c:tf_status",
],
)
cc_library(
name = "tensor",
hdrs = [
"tensor.h",
],
deps = [
"//tensorflow/c:tf_datatype",
"//tensorflow/c:tf_tensor",
],
)

View File

@ -0,0 +1,68 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_
#include <memory>
#include "tensorflow/c/eager/c_api_experimental.h"
namespace tensorflow {
namespace cc {
// Runtime represents an opaque instance of a Tensorflow runtime, with its own
// resources, threadpools, etc. Clients are expected to construct a Runtime
// object through tensorflow::cc::RuntimeBuilder::Build, after setting any
// relevant configuration options. Many Tensorflow functions take a reference to
// the runtime as an argument (e.g., tensorflow::cc::SavedModelAPI::Load), and
// may have different implementations depending on the runtime. For many of
// these Runtime-attached objects (such as tensorflow::cc::TensorHandle), the
// Runtime must outlive these objects.
class Runtime {
public:
// Runtime is movable, but not copyable.
Runtime(Runtime&&) = default;
Runtime& operator=(Runtime&&) = default;
private:
friend class RuntimeBuilder;
friend class SavedModelAPI;
// Wraps a TFE_Context. Takes ownership of ctx.
explicit Runtime(TFE_Context* ctx) : ctx_(ctx) {}
// Deletes the currently wrapped TFE_Context, swaps it with ctx,
// and takes ownership of ctx.
void Reset(TFE_Context* ctx) { ctx_.reset(ctx); }
// Returns the TFE_Context that this object wraps. This object
// retains ownership of the pointer.
TFE_Context* GetTFEContext() const { return ctx_.get(); }
// Runtime is not copyable
Runtime(const Runtime&) = delete;
Runtime& operator=(const Runtime&) = delete;
struct TFEContextDeleter {
void operator()(TFE_Context* p) const { TFE_DeleteContext(p); }
};
std::unique_ptr<TFE_Context, TFEContextDeleter> ctx_;
};
} // namespace cc
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_
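Because the Runtime must outlive every runtime-attached object, one hypothetical way a client could keep the ordering straight (not part of this change; SavedModelAPI here refers to the wrapper introduced later in this commit) is to group both in a struct whose members are declared in construction order, so the Runtime is destroyed last:

// Hypothetical sketch: members are destroyed in reverse declaration order,
// so `model` (runtime-attached) is torn down before `runtime`.
#include <memory>
#include "tensorflow/cc/experimental/base/public/runtime.h"
#include "tensorflow/cc/saved_model/experimental/public/saved_model_api.h"

struct LoadedModel {
  std::unique_ptr<tensorflow::cc::Runtime> runtime;      // declared first, destroyed last
  std::unique_ptr<tensorflow::cc::SavedModelAPI> model;  // destroyed before runtime
};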

View File

@ -0,0 +1,84 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_
#include <memory>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/cc/experimental/base/public/runtime.h"
#include "tensorflow/cc/experimental/base/public/status.h"
namespace tensorflow {
namespace cc {
// RuntimeBuilder is a builder used to construct a tensorflow::cc::Runtime.
// Use this to set configuration options, like threadpool size, etc.
class RuntimeBuilder {
public:
RuntimeBuilder() : options_(TFE_NewContextOptions()) {}
// If `use_tfrt` is true, we will use the new Tensorflow Runtime
// (https://blog.tensorflow.org/2020/04/tfrt-new-tensorflow-runtime.html) as
// our runtime implementation.
RuntimeBuilder& SetUseTFRT(bool use_tfrt);
// Build a Tensorflow Runtime.
//
// Params:
// status - Set to OK on success and an appropriate error on failure.
// Returns:
// If status is not OK, returns nullptr. Otherwise, returns a
// unique_ptr<tensorflow::cc::Runtime>.
std::unique_ptr<Runtime> Build(Status* status);
// RuntimeBuilder is movable, but not copyable.
RuntimeBuilder(RuntimeBuilder&&) = default;
RuntimeBuilder& operator=(RuntimeBuilder&&) = default;
private:
// RuntimeBuilder is not copyable
RuntimeBuilder(const RuntimeBuilder&) = delete;
RuntimeBuilder& operator=(const RuntimeBuilder&) = delete;
struct TFEContextOptionsDeleter {
void operator()(TFE_ContextOptions* p) const {
TFE_DeleteContextOptions(p);
}
};
std::unique_ptr<TFE_ContextOptions, TFEContextOptionsDeleter> options_;
};
inline RuntimeBuilder& RuntimeBuilder::SetUseTFRT(bool use_tfrt) {
TFE_ContextOptionsSetTfrt(options_.get(), use_tfrt);
return *this;
}
inline std::unique_ptr<Runtime> RuntimeBuilder::Build(Status* status) {
TFE_Context* result = TFE_NewContext(options_.get(), status->GetTFStatus());
if (!status->ok()) {
return nullptr;
}
// We can't use std::make_unique here because of its interaction with a
// private constructor: https://abseil.io/tips/134
return std::unique_ptr<Runtime>(new Runtime(result));
}
} // namespace cc
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_
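A minimal, hypothetical usage sketch of the builder (not part of this change); the SetUseTFRT(false) call is shown only to illustrate the toggle and simply keeps the classic runtime:

// Builds a Runtime with default options and surfaces failures via Status.
#include <memory>
#include "tensorflow/cc/experimental/base/public/runtime.h"
#include "tensorflow/cc/experimental/base/public/runtime_builder.h"
#include "tensorflow/cc/experimental/base/public/status.h"

std::unique_ptr<tensorflow::cc::Runtime> BuildDefaultRuntime() {
  tensorflow::cc::Status status;
  tensorflow::cc::RuntimeBuilder builder;
  builder.SetUseTFRT(false);  // stay on the classic runtime
  std::unique_ptr<tensorflow::cc::Runtime> runtime = builder.Build(&status);
  if (!status.ok()) {
    // In real code, surface status.message() to the caller.
    return nullptr;
  }
  return runtime;
}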

View File

@ -0,0 +1,93 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_
#include <memory>
#include <string>
#include "tensorflow/c/tf_status.h"
namespace tensorflow {
namespace cc {
// Status is a wrapper around an error code and an optional error message.
// The set of error codes are defined here:
// https://github.com/tensorflow/tensorflow/blob/08931c1e3e9eb2e26230502d678408e66730826c/tensorflow/c/tf_status.h#L39-L60
// Many Tensorflow APIs return a Status, or take a Status as an out parameter.
// Clients should check for status.ok() after calling these APIs, and either
// handle or propagate the error appropriately.
// TODO(bmzhao): Add a detailed code example before moving out of experimental.
class Status {
public:
// Create a success status
Status() : status_(TF_NewStatus()) {}
// Return the status code
TF_Code code() const;
// Returns the error message in Status.
std::string message() const;
// Returns true if the Status is OK (i.e. code() == TF_OK).
bool ok() const;
// Record <code, msg> in Status. Any previous information is lost.
// A common use is to clear a status: SetStatus(TF_OK, "");
void SetStatus(TF_Code code, const std::string& msg);
// Status is movable, but not copyable.
Status(Status&&) = default;
Status& operator=(Status&&) = default;
private:
friend class RuntimeBuilder;
friend class Runtime;
friend class SavedModelAPI;
// Wraps a TF_Status*, and takes ownership of it.
explicit Status(TF_Status* status) : status_(status) {}
// Status is not copyable
Status(const Status&) = delete;
Status& operator=(const Status&) = delete;
// Returns the TF_Status that this object wraps. This object
// retains ownership of the pointer.
TF_Status* GetTFStatus() const { return status_.get(); }
struct TFStatusDeleter {
void operator()(TF_Status* p) const { TF_DeleteStatus(p); }
};
std::unique_ptr<TF_Status, TFStatusDeleter> status_;
};
inline TF_Code Status::code() const { return TF_GetCode(status_.get()); }
inline std::string Status::message() const {
return std::string(TF_Message(status_.get()));
}
inline bool Status::ok() const { return code() == TF_OK; }
inline void Status::SetStatus(TF_Code code, const std::string& msg) {
TF_SetStatus(status_.get(), code, msg.c_str());
}
} // namespace cc
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_
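A small hypothetical sketch of the out-parameter pattern the comment describes (not part of this change; the helper name and error text are illustrative):

// Records an error into `status` if `value` is negative; otherwise clears it.
#include "tensorflow/cc/experimental/base/public/status.h"

void CheckNonNegative(int value, tensorflow::cc::Status* status) {
  if (value < 0) {
    status->SetStatus(TF_INVALID_ARGUMENT, "value must be non-negative");
    return;
  }
  status->SetStatus(TF_OK, "");  // common idiom for clearing a status
}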

View File

@ -0,0 +1,117 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_
#include <stddef.h>
#include <stdint.h>
#include <memory>
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_tensor.h"
namespace tensorflow {
namespace cc {
// Tensor represents an n-dimensional array of values.
class Tensor {
public:
// TODO(bmzhao): Add a factory function that constructs a Tensor from a char
// buffer, with an options struct (to specify the buffer's layout, device?,
// whether to create a TFRT or TF tensor, whether we should take ownership of
// the memory, etc). This requires extending TF_NewTensor with an options
// struct:
// https://github.com/tensorflow/tensorflow/blob/3c520614a3c056d56afdc79b59979b9b0087f8b9/tensorflow/c/tf_tensor.h#L77-L80
// TODO(bmzhao): In the case we construct a tensor from non-owned memory,
// we should offer a way to deep copy the tensor into a new tensor, which
// owns the underlying memory. This could be a .deepcopy()/clone() method.
// TODO(bmzhao): In the future, we want to relax the non-copyability
// constraint. To do so, we can add a C API function that acts like CopyFrom:
// https://github.com/tensorflow/tensorflow/blob/08931c1e3e9eb2e26230502d678408e66730826c/tensorflow/core/framework/tensor.h#L301-L311
// Tensor is movable, but not copyable
Tensor(Tensor&&) = default;
Tensor& operator=(Tensor&&) = default;
// Returns the number of dimensions in the tensor. Can be -1, which represents
// unknown rank.
int dims() const;
// Returns the number of elements in dimension `d`.
// REQUIRES: `0 <= d < dims()`
int64_t dim_size(int d) const;
// Returns a pointer to the underlying data buffer.
void* data() const;
// Returns the data type of the tensor.
TF_DataType dtype() const;
// Returns the number of elements in the tensor. For a tensor with a partially
// defined shape, -1 means not fully defined.
int64_t num_elements() const;
// Returns the size of the underlying data in bytes.
size_t num_bytes() const;
private:
friend class TensorHandle;
friend class Runtime;
// Wraps a TF_Tensor. Takes ownership of handle.
explicit Tensor(TF_Tensor* tensor) : tensor_(tensor) {}
// Tensor is not copyable
Tensor(const Tensor&) = delete;
Tensor& operator=(const Tensor&) = delete;
// Returns the underlying TF_Tensor that this object wraps.
// This object retains ownership of the pointer.
TF_Tensor* GetTFTensor() const { return tensor_.get(); }
struct TFTensorDeleter {
void operator()(TF_Tensor* p) const { TF_DeleteTensor(p); }
};
std::unique_ptr<TF_Tensor, TFTensorDeleter> tensor_;
};
inline void* Tensor::data() const { return TF_TensorData(tensor_.get()); }
inline int Tensor::dims() const { return TF_NumDims(tensor_.get()); }
inline int64_t Tensor::dim_size(int d) const {
return TF_Dim(tensor_.get(), d);
}
inline TF_DataType Tensor::dtype() const {
return TF_TensorType(tensor_.get());
}
inline int64_t Tensor::num_elements() const {
return TF_TensorElementCount(tensor_.get());
}
inline size_t Tensor::num_bytes() const {
return TF_TensorByteSize(tensor_.get());
}
} // namespace cc
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_
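The accessors above are enough for read-only inspection. A hypothetical helper (not part of this change, and assuming a Tensor obtained from some runtime-attached API) could summarize a tensor like so:

// Prints dtype, rank, byte size, and per-dimension sizes of `t`.
#include <cstdio>
#include "tensorflow/cc/experimental/base/public/tensor.h"

void PrintTensorInfo(const tensorflow::cc::Tensor& t) {
  std::printf("dtype=%d rank=%d bytes=%zu dims=[", static_cast<int>(t.dtype()),
              t.dims(), t.num_bytes());
  for (int d = 0; d < t.dims(); ++d) {  // dims() may be -1 (unknown rank)
    std::printf(d == 0 ? "%lld" : ", %lld",
                static_cast<long long>(t.dim_size(d)));
  }
  std::printf("]\n");
}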

View File

@ -13,19 +13,20 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/framework/gradients.h"
#include <deque>
#include <vector>
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/framework/while_gradients.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/while_context.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/macros.h"

View File

@ -24,7 +24,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"

View File

@ -0,0 +1,58 @@
# Experimental C++ SavedModel Header Only APIs. See RFC
# https://github.com/tensorflow/community/pull/207
package(
# This is intentionally public
default_visibility = [
"//visibility:public",
],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "concrete_function",
hdrs = [
"concrete_function.h",
],
deps = [
":function_metadata",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/experimental/saved_model/public:concrete_function",
"//tensorflow/cc/experimental/base/public:status",
],
)
cc_library(
name = "concrete_function_list",
hdrs = [
"concrete_function_list.h",
],
deps = [
":concrete_function",
"//tensorflow/c/experimental/saved_model/public:concrete_function_list",
],
)
cc_library(
name = "function_metadata",
hdrs = [
"function_metadata.h",
],
deps = [
"//tensorflow/c/experimental/saved_model/public:function_metadata",
],
)
cc_library(
name = "saved_model_api",
hdrs = [
"saved_model_api.h",
],
deps = [
":concrete_function",
":concrete_function_list",
"//tensorflow/c/experimental/saved_model/public:saved_model_api",
"//tensorflow/cc/experimental/base/public:runtime",
"//tensorflow/cc/experimental/base/public:status",
],
)

View File

@ -0,0 +1,59 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_
#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_
#include <vector>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/cc/experimental/base/public/status.h"
#include "tensorflow/cc/saved_model/experimental/public/function_metadata.h"
namespace tensorflow {
namespace cc {
// ConcreteFunction is an executable "function" loaded from a SavedModelAPI.
class ConcreteFunction final {
public:
// TODO(bmzhao): Adding ConcreteFunction::Run in subsequent CL, since
// it depends on tensorflow::cc::Tensor and tensorflow::cc::TensorHandle
// Returns FunctionMetadata associated with this ConcreteFunction.
const FunctionMetadata* GetFunctionMetadata();
private:
friend class SavedModelAPI;
friend class ConcreteFunctionList;
// TODO(bmzhao): Consider adding a macro for wrapping/unwrapping
// when moving out of experimental.
static ConcreteFunction* wrap(TF_ConcreteFunction* p) {
return reinterpret_cast<ConcreteFunction*>(p);
}
static TF_ConcreteFunction* unwrap(ConcreteFunction* p) {
return reinterpret_cast<TF_ConcreteFunction*>(p);
}
};
inline const FunctionMetadata* ConcreteFunction::GetFunctionMetadata() {
return FunctionMetadata::wrap(TF_ConcreteFunctionGetMetadata(unwrap(this)));
}
} // namespace cc
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_

View File

@ -0,0 +1,61 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
#include <vector>
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
#include "tensorflow/cc/saved_model/experimental/public/concrete_function.h"
namespace tensorflow {
namespace cc {
// ConcreteFunctionList helps convert an opaque pointer to a list of
// ConcreteFunction pointers into a std::vector<ConcreteFunction*>.
class ConcreteFunctionList {
public:
// Converts this object to a std::vector<ConcreteFunction*>
std::vector<ConcreteFunction*> ToVector();
private:
friend class SavedModelAPI;
// Wraps a TF_ConcreteFunctionList. Takes ownership of list.
explicit ConcreteFunctionList(TF_ConcreteFunctionList* list) : list_(list) {}
struct TFConcreteFunctionListDeleter {
void operator()(TF_ConcreteFunctionList* p) const {
TF_DeleteConcreteFunctionList(p);
}
};
std::unique_ptr<TF_ConcreteFunctionList, TFConcreteFunctionListDeleter> list_;
};
inline std::vector<ConcreteFunction*> ConcreteFunctionList::ToVector() {
int size = TF_ConcreteFunctionListSize(list_.get());
std::vector<ConcreteFunction*> result;
result.reserve(size);
for (int i = 0; i < size; ++i) {
result.push_back(
ConcreteFunction::wrap(TF_ConcreteFunctionListGet(list_.get(), i)));
}
return result;
}
} // namespace cc
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_

View File

@ -0,0 +1,45 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_
#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_
#include <memory>
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
namespace tensorflow {
namespace cc {
// FunctionMetadata stores additional function information, including
// optional signaturedef feeds/fetches (for TF1-based ConcreteFunctions),
// a valid function path (for TF2-based ConcreteFunctions), and
// the types + number of inputs and outputs.
class FunctionMetadata final {
// TODO(bmzhao): Add getters here as necessary.
private:
friend class ConcreteFunction;
static FunctionMetadata* wrap(TF_FunctionMetadata* p) {
return reinterpret_cast<FunctionMetadata*>(p);
}
static TF_FunctionMetadata* unwrap(FunctionMetadata* p) {
return reinterpret_cast<TF_FunctionMetadata*>(p);
}
};
} // namespace cc
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_

View File

@ -0,0 +1,160 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_
#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
#include "tensorflow/cc/experimental/base/public/runtime.h"
#include "tensorflow/cc/experimental/base/public/status.h"
#include "tensorflow/cc/saved_model/experimental/public/concrete_function.h"
#include "tensorflow/cc/saved_model/experimental/public/concrete_function_list.h"
namespace tensorflow {
namespace cc {
// SavedModelAPI offers a way to load Tensorflow Saved Models
// (https://www.tensorflow.org/guide/saved_model) and execute saved
// tf.functions or legacy SignatureDefs in a TF2-idiomatic fashion.
// See RFC 207
// (https://github.com/tensorflow/community/blob/master/rfcs/20200218-tf-c-saved-model.md)
// TODO(bmzhao): Add an e2e example here, once ConcreteFunction::Run is added.
class SavedModelAPI {
public:
// Load a SavedModel from `dirname`.
//
// Params:
// saved_model_path - A directory filepath that the SavedModel is at.
// runtime - A runtime used to load SavedModelAPI. `runtime` must outlive the
// returned SavedModelAPI object.
// tags - Optional set of tags. If tags = nullptr, we expect the SavedModel
// to contain a single Metagraph (as for those exported from TF2's
// `tf.saved_model.save`). If tags != nullptr, we load the metagraph
// matching the tags:
// https://github.com/tensorflow/tensorflow/blob/428cdeda09aef81e958eeb274b83d27ad635b57b/tensorflow/core/protobuf/meta_graph.proto#L50-L56
// status - Set to OK on success and an appropriate error on failure.
// Returns:
// If status is not OK, returns nullptr.
static std::unique_ptr<SavedModelAPI> Load(
const std::string& saved_model_path, const Runtime& runtime,
Status* status, const std::unordered_set<std::string>* tags = nullptr);
// Retrieve a function from the TF2 SavedModel via function path.
//
// Params:
// function_path - A string containing the path from the root saved python
// object to a tf.function method.
// status - Set to OK on success and an appropriate error on failure.
// Returns:
// If status is not OK, returns nullptr. Otherwise, returns a
// tensorflow::cc::ConcreteFunction pointer. The lifetime of this pointer
// is bound to SavedModelAPI it was loaded from.
ConcreteFunction* GetConcreteFunction(const std::string& function_path,
Status* status);
// Retrieve a function from the TF SavedModel via a SignatureDef key.
//
// Params:
// signature_def_key - String key of SignatureDef map of a SavedModel:
// https://github.com/tensorflow/tensorflow/blob/69b08900b1e991d84bce31f3b404f5ed768f339f/tensorflow/core/protobuf/meta_graph.proto#L89
// status - Set to OK on success and an appropriate error on failure.
// Returns:
// If status is not OK, returns nullptr. Otherwise, returns a
// tensorflow::cc::ConcreteFunction pointer. The lifetime of this pointer
// is bound to SavedModelAPI it was loaded from.
ConcreteFunction* GetSignatureDefFunction(const std::string& signature_def_key,
Status* status);
// Lists all ConcreteFunctions available from the SavedModel.
std::vector<ConcreteFunction*> ListFunctions();
// SavedModelAPI is movable, but not copyable.
SavedModelAPI(SavedModelAPI&&) = default;
SavedModelAPI& operator=(SavedModelAPI&&) = default;
private:
SavedModelAPI(const SavedModelAPI&) = delete;
SavedModelAPI& operator=(const SavedModelAPI&) = delete;
explicit SavedModelAPI(TF_SavedModel* model) : saved_model_(model) {}
struct TFSavedModelDeleter {
void operator()(TF_SavedModel* p) const { TF_DeleteSavedModel(p); }
};
std::unique_ptr<TF_SavedModel, TFSavedModelDeleter> saved_model_;
};
inline std::unique_ptr<SavedModelAPI> SavedModelAPI::Load(
const std::string& saved_model_path, const Runtime& runtime, Status* status,
const std::unordered_set<std::string>* tags) {
TF_SavedModel* saved_model = nullptr;
if (tags == nullptr) {
saved_model =
TF_LoadSavedModel(saved_model_path.c_str(), runtime.GetTFEContext(),
status->GetTFStatus());
} else {
std::vector<const char*> tags_vector;
tags_vector.reserve(tags->size());
for (const std::string& tag : *tags) {
tags_vector.push_back(tag.c_str());
}
saved_model = TF_LoadSavedModelWithTags(
saved_model_path.c_str(), runtime.GetTFEContext(), tags_vector.data(),
tags_vector.size(), status->GetTFStatus());
}
if (!status->ok()) {
return nullptr;
}
// We can't use std::make_unique here because of its interaction with a
// private constructor: https://abseil.io/tips/134
return std::unique_ptr<SavedModelAPI>(new SavedModelAPI(saved_model));
}
inline ConcreteFunction* SavedModelAPI::GetConcreteFunction(
const std::string& function_path, Status* status) {
TF_ConcreteFunction* function = TF_GetSavedModelConcreteFunction(
saved_model_.get(), function_path.c_str(), status->GetTFStatus());
if (!status->ok()) {
return nullptr;
}
return ConcreteFunction::wrap(function);
}
inline ConcreteFunction* SavedModelAPI::GetSignatureDefFunction(
const std::string& signature_def_key, Status* status) {
TF_ConcreteFunction* function = TF_GetSavedModelSignatureDefFunction(
saved_model_.get(), signature_def_key.c_str(), status->GetTFStatus());
if (!status->ok()) {
return nullptr;
}
return ConcreteFunction::wrap(function);
}
inline std::vector<ConcreteFunction*> SavedModelAPI::ListFunctions() {
ConcreteFunctionList list(TF_ListSavedModelFunctions(saved_model_.get()));
return list.ToVector();
}
} // namespace cc
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_
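Putting the pieces together, a hypothetical end-to-end sketch (not part of this change; the function name and tag set are illustrative, and Load currently reports TF_UNIMPLEMENTED, as the tests below show) might be:

// Builds a Runtime, loads a SavedModel with the "serve" tag, and lists its functions.
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "tensorflow/cc/experimental/base/public/runtime.h"
#include "tensorflow/cc/experimental/base/public/runtime_builder.h"
#include "tensorflow/cc/experimental/base/public/status.h"
#include "tensorflow/cc/saved_model/experimental/public/saved_model_api.h"

void LoadAndListFunctions(const std::string& model_dir) {
  tensorflow::cc::Status status;
  tensorflow::cc::RuntimeBuilder builder;
  std::unique_ptr<tensorflow::cc::Runtime> runtime = builder.Build(&status);
  if (!status.ok()) return;

  std::unordered_set<std::string> tags = {"serve"};
  std::unique_ptr<tensorflow::cc::SavedModelAPI> model =
      tensorflow::cc::SavedModelAPI::Load(model_dir, *runtime, &status, &tags);
  if (!status.ok()) return;  // currently TF_UNIMPLEMENTED

  std::vector<tensorflow::cc::ConcreteFunction*> functions =
      model->ListFunctions();
  (void)functions;  // FunctionMetadata getters are still TODO, so just list them.
}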

View File

@ -0,0 +1,22 @@
# Tests for the C++ header-only SavedModelAPI.
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
package(
licenses = ["notice"], # Apache 2.0
)
tf_cc_test(
name = "saved_model_api_test",
srcs = [
"saved_model_api_test.cc",
],
deps = [
"//tensorflow/cc/experimental/base/public:runtime",
"//tensorflow/cc/experimental/base/public:runtime_builder",
"//tensorflow/cc/experimental/base/public:status",
"//tensorflow/cc/saved_model/experimental/public:saved_model_api",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)

View File

@ -0,0 +1,97 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/saved_model/experimental/public/saved_model_api.h"
#include <memory>
#include <string>
#include <unordered_set>
#include "tensorflow/cc/experimental/base/public/runtime.h"
#include "tensorflow/cc/experimental/base/public/runtime_builder.h"
#include "tensorflow/cc/experimental/base/public/status.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
constexpr char kTestData[] = "cc/saved_model/testdata";
std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
return tensorflow::io::JoinPath(tensorflow::testing::TensorFlowSrcRoot(),
kTestData, saved_model_dir);
}
// This value-parameterized test allows us to test both TFRT
// and non-TFRT runtimes.
// https://github.com/google/googletest/blob/dcc92d0ab6c4ce022162a23566d44f673251eee4/googletest/docs/advanced.md#value-parameterized-tests
class CPPSavedModelAPITest : public ::testing::TestWithParam<bool> {};
TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) {
cc::Status status;
cc::RuntimeBuilder builder;
bool use_tfrt = GetParam();
if (use_tfrt) {
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
}
builder.SetUseTFRT(use_tfrt);
std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
ASSERT_TRUE(status.ok()) << status.message();
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
std::unordered_set<std::string> tags = {"serve"};
std::unique_ptr<cc::SavedModelAPI> model =
cc::SavedModelAPI::Load(model_dir, *runtime, &status, &tags);
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
// That unblocks writing other tests that require a TF_SavedModel*,
// like loading a ConcreteFunction. This test at least checks that the
// C API builds and can be minimally run.
EXPECT_EQ(status.code(), TF_UNIMPLEMENTED);
}
TEST_P(CPPSavedModelAPITest, LoadsSavedModel) {
cc::Status status;
cc::RuntimeBuilder builder;
bool use_tfrt = GetParam();
if (use_tfrt) {
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
}
builder.SetUseTFRT(use_tfrt);
std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
ASSERT_TRUE(status.ok()) << status.message();
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
std::unique_ptr<cc::SavedModelAPI> model =
cc::SavedModelAPI::Load(model_dir, *runtime, &status);
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
// That unblocks writing other tests that require a TF_SavedModel*,
// like loading a ConcreteFunction. This test at least checks that the
// C API builds and can be minimally run.
EXPECT_EQ(status.code(), TF_UNIMPLEMENTED);
}
INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticCPPSavedModelTests,
CPPSavedModelAPITest, ::testing::Bool());
} // namespace
} // namespace tensorflow

View File

@ -19,12 +19,16 @@ limitations under the License.
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf_internal.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/core/protobuf/saver.pb.h"
@ -65,12 +69,39 @@ uint64 GetLatencyMicroseconds(const uint64 start_microseconds) {
return end_microseconds - start_microseconds;
}
// Ensure that constant tensors loaded from the saved model have valid shape.
// Also ensure that constant nodes have a value assigned to them.
// TODO(b/154763635): this is temporary and will be replaced with a better audit
static Status ValidateSavedTensors(const GraphDef& graph_def) {
for (const auto& node : graph_def.node()) {
const auto node_iterator = node.attr().find("value");
if (node_iterator != node.attr().end()) {
AttrValue node_value = node_iterator->second;
if (node_value.has_tensor()) {
const PartialTensorShape node_shape(node_value.tensor().tensor_shape());
if (node_shape.num_elements() < 0) {
return errors::FailedPrecondition(
"Saved model contains node \"", node.name(), "\" (op \"",
node.op(), "\") which initializes from a tensor with ",
node_shape.num_elements(), " elements");
}
}
} else if (node.op() == "Const") {
return errors::FailedPrecondition(
"Saved model contains node \"", node.name(),
"\" which is a constant tensor but no value has been provided");
}
}
return Status::OK();
}
Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def,
const SessionOptions& session_options,
std::unique_ptr<Session>* session) {
Session* session_p = nullptr;
TF_RETURN_IF_ERROR(NewSession(session_options, &session_p));
session->reset(session_p);
TF_RETURN_IF_ERROR(ValidateSavedTensors(meta_graph_def.graph_def()));
return (*session)->Create(meta_graph_def.graph_def());
}

View File

@ -40,6 +40,10 @@ constexpr char kTestDataInitOpV2[] =
"cc/saved_model/testdata/half_plus_two_v2/00000123";
constexpr char kTestDataV2DebugInfo[] =
"cc/saved_model/testdata/x_plus_y_v2_debuginfo";
constexpr char kTestFuzzGeneratedNegativeShape[] =
"cc/saved_model/testdata/fuzz_generated/negative_shape";
constexpr char kTestFuzzGeneratedConstWithNoValue[] =
"cc/saved_model/testdata/fuzz_generated/const_with_no_value";
class LoaderTest : public ::testing::Test {
protected:
@ -256,5 +260,29 @@ TEST_F(LoaderTest, SavedModelV2DebugInfo) {
EXPECT_NE(bundle.debug_info.get(), nullptr);
}
TEST_F(LoaderTest, NegativeShapeDimension) {
SavedModelBundle bundle;
RunOptions run_options;
SessionOptions session_options;
const string export_dir = io::JoinPath(testing::TensorFlowSrcRoot(),
kTestFuzzGeneratedNegativeShape);
Status st = LoadSavedModel(session_options, run_options, export_dir,
{kSavedModelTagServe}, &bundle);
EXPECT_FALSE(st.ok());
}
TEST_F(LoaderTest, ConstNoValue) {
SavedModelBundle bundle;
RunOptions run_options;
SessionOptions session_options;
const string export_dir = io::JoinPath(testing::TensorFlowSrcRoot(),
kTestFuzzGeneratedConstWithNoValue);
Status st = LoadSavedModel(session_options, run_options, export_dir,
{kSavedModelTagServe}, &bundle);
EXPECT_FALSE(st.ok());
}
} // namespace
} // namespace tensorflow

Binary file not shown.

View File

@ -38,7 +38,7 @@ namespace benchmark {
struct Options {
// kDefaultMicros specifies the default time to run the benchmark, and is used
// if neither max_iters nor max_micros is set.
static const int64 kDefaultMicros = 3000000;
static constexpr int64 kDefaultMicros = 3000000;
int64 max_iters = 0; // Maximum iterations to run, ignored if <= 0.
int64 max_micros = 0; // Maximum microseconds to run, ignored if <= 0.

View File

@ -38,6 +38,7 @@ def tf_library(
tfcompile_tool = "//tensorflow/compiler/aot:tfcompile",
include_standard_runtime_deps = True,
enable_xla_hlo_profiling = False,
enable_tracemes = False,
mlir_components = "None",
deps = None,
tags = []):
@ -89,6 +90,9 @@ def tf_library(
enable_xla_hlo_profiling: Enable XLA HLO profiling in the generated
program, and emit metadata that lets us pretty-print the gathered
profile counters.
enable_tracemes: Tell tfcompile to generate calls to
TraceMe::Activity{Start|End} around HLO instructions that can be used by
Xprof to construct profiler timelines.
mlir_components: When the value is "None", no components use MLIR. When
the value is "Bridge", use MLIR to translate GraphDef to HLO.
deps: a list of deps to include on the build rules for the generated
@ -190,6 +194,11 @@ def tf_library(
else:
profiling_flag = ""
if enable_tracemes:
traceme_flag = "--xla_cpu_enable_xprof_traceme=true"
else:
traceme_flag = "--xla_cpu_enable_xprof_traceme=false"
mlir_flag = "--mlir_components=" + mlir_components
srcs = [tfcompile_graph, config]
@ -218,7 +227,7 @@ def tf_library(
" --out_header=$(@D)/" + header_file +
" --out_metadata_object=$(@D)/" + metadata_object_file +
" --out_function_object=$(@D)/" + function_object_file +
" " + flags + " " + profiling_flag + " " + mlir_flag
" " + flags + " " + profiling_flag + " " + mlir_flag + " " + traceme_flag
),
tools = [tfcompile_tool],
visibility = visibility,

View File

@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/memory_types.h"
@ -41,7 +42,6 @@ limitations under the License.
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/public/version.h"

View File

@ -18,12 +18,12 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/test_util.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_def_builder_util.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"

View File

@ -46,6 +46,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/graph_def_util.h"
@ -55,7 +56,6 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h"

View File

@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/function.h"
@ -45,7 +46,6 @@ limitations under the License.
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h"

View File

@ -21,12 +21,12 @@ limitations under the License.
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/graph_def_builder_util.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

View File

@ -25,12 +25,11 @@ limitations under the License.
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

View File

@ -28,9 +28,9 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"

View File

@ -21,8 +21,8 @@ limitations under the License.
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h"
#include "tensorflow/compiler/tf2xla/test_util.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"

View File

@ -24,9 +24,9 @@ limitations under the License.
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/test_util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"

View File

@ -358,13 +358,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
      ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_,
      resources_, constants_, /*lazy=*/false, &client, &variables, &kernel,
      &executable);
  if (!s.ok() && (platform_info_.device_type().type_string() == DEVICE_CPU ||
                  platform_info_.device_type().type_string() == DEVICE_GPU)) {
    // Suggest auto jit if the failure was with GPU or CPU.
    errors::AppendToMessage(&s,
                            xla::status_macros::kPossibleAutoJitAlternative);
  }
  OP_REQUIRES_OK(ctx, s);
}

View File

@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/memory_types.h"
@ -49,7 +50,6 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h"
@ -2078,6 +2078,8 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"XlaSend",
"XlaSharding",
"XlaSort",
"XlaSpmdFullToShardShape",
"XlaSpmdShardToFullShape",
"XlaSvd",
"XlaWhile",
"_Arg",

Some files were not shown because too many files have changed in this diff.