Merge branch 'upstream/master' into interface_16x8
This commit is contained in:
commit
0391c064f5
49
.bazelrc
49
.bazelrc
@ -356,9 +356,10 @@ build:rbe_linux --linkopt=-lm
|
||||
build:rbe_cpu_linux --config=rbe_linux
|
||||
build:rbe_cpu_linux --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain"
|
||||
build:rbe_cpu_linux --extra_toolchains="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:cc-toolchain-k8"
|
||||
build:rbe_cpu_linux --extra_execution_platforms"=@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
|
||||
build:rbe_cpu_linux --host_platform="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
|
||||
build:rbe_cpu_linux --platforms="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
|
||||
build:rbe_cpu_linux --extra_execution_platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform"
|
||||
build:rbe_cpu_linux --extra_execution_platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform"
|
||||
build:rbe_cpu_linux --host_platform="@ubuntu16.04-manylinux2010-py3_config_platform//:platform"
|
||||
build:rbe_cpu_linux --platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform"
|
||||
|
||||
build:rbe_linux_cuda_base --config=rbe_linux
|
||||
build:rbe_linux_cuda_base --repo_env=TF_NEED_TENSORRT=1
|
||||
@ -380,17 +381,37 @@ build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-gcc7_
|
||||
build:rbe_linux_cuda_nvcc --define=using_cuda_nvcc=true
|
||||
test:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
|
||||
|
||||
build:rbe_linux_cuda_clang --config=rbe_linux_cuda_base
|
||||
build:rbe_linux_cuda_clang --crosstool_top="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
|
||||
build:rbe_linux_cuda_clang --extra_toolchains="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
|
||||
build:rbe_linux_cuda_clang --extra_execution_platforms="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_clang --host_platform="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_clang --platforms="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_clang --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
|
||||
build:rbe_linux_cuda_clang --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
|
||||
build:rbe_linux_cuda_clang --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
|
||||
build:rbe_linux_cuda_clang --define=using_cuda_clang=true
|
||||
test:rbe_linux_cuda_clang --config=rbe_linux_cuda_base
|
||||
build:rbe_linux_cuda_nvcc_base --config=rbe_linux_cuda_base
|
||||
build:rbe_linux_cuda_nvcc_base --crosstool_top="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
|
||||
build:rbe_linux_cuda_nvcc_base --extra_toolchains="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
|
||||
build:rbe_linux_cuda_nvcc_base --extra_execution_platforms="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc_base --host_platform="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc_base --platforms="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
|
||||
build:rbe_linux_cuda_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
|
||||
build:rbe_linux_cuda_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
|
||||
build:rbe_linux_cuda_nvcc_base --define=using_cuda_nvcc=true
|
||||
build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
|
||||
build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
|
||||
build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
|
||||
build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
|
||||
build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
|
||||
|
||||
build:rbe_linux_cuda_clang_base --config=rbe_linux_cuda_base
|
||||
build:rbe_linux_cuda_clang_base --crosstool_top="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
|
||||
build:rbe_linux_cuda_clang_base --extra_toolchains="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
|
||||
build:rbe_linux_cuda_clang_base --extra_execution_platforms="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_clang_base --host_platform="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_clang_base --platforms="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_clang_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
|
||||
build:rbe_linux_cuda_clang_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
|
||||
build:rbe_linux_cuda_clang_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
|
||||
build:rbe_linux_cuda_clang_base --define=using_cuda_clang=true
|
||||
build:rbe_linux_cuda_clang_py27 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
|
||||
build:rbe_linux_cuda_clang_py35 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
|
||||
build:rbe_linux_cuda_clang_py36 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
|
||||
build:rbe_linux_cuda_clang_py37 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
|
||||
build:rbe_linux_cuda_clang_py38 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
|
||||
|
||||
common:rbe_gpu_linux --config=rbe_linux_cuda_nvcc
|
||||
|
||||
|
@ -38,6 +38,9 @@ state what is wrong:
|
||||
- Producing correct results, but the model is slower than expected (model generated from old converter)
|
||||
|
||||
|
||||
**RNN conversion support**
|
||||
If converting TF RNN to TFLite fused RNN ops, please prefix [RNN] in the title.
|
||||
|
||||
**Any other info / logs**
|
||||
|
||||
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.
|
||||
|
39
.github/stale.yml
vendored
Normal file
39
.github/stale.yml
vendored
Normal file
@ -0,0 +1,39 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
#
|
||||
# THIS IS A GENERATED DOCKERFILE.
|
||||
#
|
||||
# This file was assembled from multiple pieces, whose use is documented
|
||||
# throughout. Please refer to the TensorFlow dockerfiles documentation
|
||||
# for more information.
|
||||
|
||||
# Number of days of inactivity before an Issue or Pull Request becomes stale
|
||||
daysUntilStale: 7
|
||||
# Number of days of inactivity before a stale Issue or Pull Request is closed
|
||||
daysUntilClose: 7
|
||||
# Issues or Pull Requests with these labels will never be considered stale. Set to `[]` to disable
|
||||
onlyLabels:
|
||||
- stat:awaiting response
|
||||
# Comment to post when marking as stale. Set to `false` to disable
|
||||
markComment: >
|
||||
This issue has been automatically marked as stale because it has not had
|
||||
recent activity. It will be closed if no further activity occurs. Thank you.
|
||||
# Comment to post when removing the stale label. Set to `false` to disable
|
||||
unmarkComment: false
|
||||
closeComment: >
|
||||
Closing as stale. Please reopen if you'd like to work on this further.
|
||||
limitPerRun: 30
|
||||
# Limit to only `issues` or `pulls`
|
||||
only: issues
|
10
configure.py
10
configure.py
@ -1171,14 +1171,16 @@ def system_specific_test_config(environ_cp):
|
||||
test_only_filters = ['-oss_serial']
|
||||
if is_windows():
|
||||
test_and_build_filters.append('-no_windows')
|
||||
if environ_cp.get('TF_NEED_CUDA', None) == '1':
|
||||
if ((environ_cp.get('TF_NEED_CUDA', None) == '1') or
|
||||
(environ_cp.get('TF_NEED_ROCM', None) == '1')):
|
||||
test_and_build_filters += ['-no_windows_gpu', '-no_gpu']
|
||||
else:
|
||||
test_and_build_filters.append('-gpu')
|
||||
elif is_macos():
|
||||
test_and_build_filters += ['-gpu', '-nomac', '-no_mac']
|
||||
elif is_linux():
|
||||
if environ_cp.get('TF_NEED_CUDA', None) == '1':
|
||||
if ((environ_cp.get('TF_NEED_CUDA', None) == '1') or
|
||||
(environ_cp.get('TF_NEED_ROCM', None) == '1')):
|
||||
test_and_build_filters.append('-no_gpu')
|
||||
write_to_bazelrc('test --test_env=LD_LIBRARY_PATH')
|
||||
else:
|
||||
@ -1416,6 +1418,10 @@ def main():
|
||||
write_action_env_to_bazelrc('LD_LIBRARY_PATH',
|
||||
environ_cp.get('LD_LIBRARY_PATH'))
|
||||
|
||||
if (environ_cp.get('TF_NEED_ROCM') == '1' and environ_cp.get('ROCM_PATH')):
|
||||
write_action_env_to_bazelrc('ROCM_PATH', environ_cp.get('ROCM_PATH'))
|
||||
write_action_env_to_bazelrc('ROCM_ROOT', environ_cp.get('ROCM_PATH'))
|
||||
|
||||
environ_cp['TF_NEED_CUDA'] = str(
|
||||
int(get_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False)))
|
||||
if (environ_cp.get('TF_NEED_CUDA') == '1' and
|
||||
|
@ -523,6 +523,12 @@ package_group(
|
||||
],
|
||||
)
|
||||
|
||||
package_group(name = "ndarray_tensor_allow_list")
|
||||
|
||||
# Packages that use composite tensors or dispatch.
|
||||
# TODO(b/154762408) Remove this package group once it's no longer needed.
|
||||
package_group(name = "composite_tensor_whitelist")
|
||||
|
||||
filegroup(
|
||||
name = "intel_binary_blob",
|
||||
data = if_mkl_ml(
|
||||
|
@ -116,7 +116,7 @@ from tensorflow.python.lib.io import file_io as _fi
|
||||
|
||||
# Get sitepackages directories for the python installation.
|
||||
_site_packages_dirs = []
|
||||
_site_packages_dirs += [_site.USER_SITE]
|
||||
_site_packages_dirs += [] if _site.USER_SITE is None else [_site.USER_SITE]
|
||||
_site_packages_dirs += [_p for _p in _sys.path if 'site-packages' in _p]
|
||||
if 'getsitepackages' in dir(_site):
|
||||
_site_packages_dirs += _site.getsitepackages()
|
||||
|
@ -126,7 +126,7 @@ from tensorflow.python.lib.io import file_io as _fi
|
||||
|
||||
# Get sitepackages directories for the python installation.
|
||||
_site_packages_dirs = []
|
||||
_site_packages_dirs += [_site.USER_SITE]
|
||||
_site_packages_dirs += [] if _site.USER_SITE is None else [_site.USER_SITE]
|
||||
_site_packages_dirs += [_p for _p in _sys.path if 'site-packages' in _p]
|
||||
if 'getsitepackages' in dir(_site):
|
||||
_site_packages_dirs += _site.getsitepackages()
|
||||
|
@ -58,6 +58,7 @@ filegroup(
|
||||
name = "pywrap_required_hdrs",
|
||||
srcs = [
|
||||
"c_api_internal.h",
|
||||
"conversion_macros.h",
|
||||
"python_api.h",
|
||||
"tensor_interface.h",
|
||||
"tf_status_helper.h",
|
||||
@ -86,6 +87,13 @@ tf_cuda_library(
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
],
|
||||
"//tensorflow:chromiumos": [
|
||||
":tf_attrtype",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/platform:platform",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":tf_attrtype",
|
||||
"//tensorflow/core:core_cpu",
|
||||
@ -118,6 +126,13 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "c_api_macros",
|
||||
hdrs = ["c_api_macros.h"],
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "c_api",
|
||||
hdrs = [
|
||||
@ -327,6 +342,9 @@ tf_cuda_library(
|
||||
":checkpoint_reader",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_internal",
|
||||
"//tensorflow/c/eager:tfe_context_internal",
|
||||
"//tensorflow/c/eager:tfe_op_internal",
|
||||
"//tensorflow/c/eager:tfe_tensorhandle_internal",
|
||||
"//tensorflow/compiler/jit:flags",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
@ -517,12 +535,12 @@ tf_cuda_cc_test(
|
||||
":test_op1.so",
|
||||
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
|
||||
],
|
||||
kernels = [":test_op_kernel"],
|
||||
linkopts = select({
|
||||
"//tensorflow:macos": ["-headerpad_max_install_names"],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
tags = [
|
||||
"no_windows", # TODO(b/155444728)
|
||||
"noasan",
|
||||
],
|
||||
# We must ensure that the dependencies can be dynamically linked since
|
||||
@ -531,6 +549,7 @@ tf_cuda_cc_test(
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_test_util",
|
||||
":test_op_kernel",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:grad_ops",
|
||||
"//tensorflow/cc/saved_model:signature_constants",
|
||||
@ -597,6 +616,7 @@ tf_cc_test(
|
||||
":c_api",
|
||||
":c_api_internal",
|
||||
":c_test_util",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
@ -721,3 +741,11 @@ tf_cuda_library(
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "conversion_macros",
|
||||
hdrs = [
|
||||
"conversion_macros.h",
|
||||
],
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
||||
|
@ -39,6 +39,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/eval_const_tensor.h"
|
||||
#include "tensorflow/core/common_runtime/graph_constructor.h"
|
||||
#include "tensorflow/core/common_runtime/shape_refiner.h"
|
||||
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||
#include "tensorflow/core/framework/kernel_def.pb.h"
|
||||
@ -53,7 +54,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/graph/validate.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
|
@ -21,6 +21,9 @@ limitations under the License.
|
||||
#include "tensorflow/c/checkpoint_reader.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
@ -686,8 +689,7 @@ TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type,
|
||||
std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len);
|
||||
|
||||
status->status = tensorflow::Status::OK();
|
||||
return new TFE_TensorHandle{
|
||||
tensorflow::TensorHandle::CreateLocalHandle(tensor)};
|
||||
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
|
||||
}
|
||||
|
||||
namespace {
|
||||
@ -708,7 +710,7 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
|
||||
|
||||
// New server created for new server_def. Unused if updating server_def.
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
tensorflow::GrpcServer* grpc_server =
|
||||
dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
|
||||
if (grpc_server == nullptr) {
|
||||
@ -822,14 +824,13 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
|
||||
|
||||
const int num_inputs = input_shapes->num_items;
|
||||
NodeDef node_def;
|
||||
node_def.set_name(tfe_op->operation->Name());
|
||||
node_def.set_op(tfe_op->operation->Name());
|
||||
tensorflow::AbstractOperationInterface* op = tensorflow::unwrap(tfe_op);
|
||||
node_def.set_name(op->Name());
|
||||
node_def.set_op(op->Name());
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
node_def.add_input("dummy_input");
|
||||
}
|
||||
OperationFromInterface(tfe_op->operation)
|
||||
->Attrs()
|
||||
.FillAttrValueMap(node_def.mutable_attr());
|
||||
OperationFromInterface(op)->Attrs().FillAttrValueMap(node_def.mutable_attr());
|
||||
|
||||
const tensorflow::OpRegistrationData* op_reg_data;
|
||||
status->status =
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/c_test_util.h"
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
|
@ -38,7 +38,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/common_runtime/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
33
tensorflow/c/c_api_macros.h
Normal file
33
tensorflow/c/c_api_macros.h
Normal file
@ -0,0 +1,33 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_C_API_MACROS_H_
|
||||
#define TENSORFLOW_C_C_API_MACROS_H_
|
||||
|
||||
#ifdef SWIG
|
||||
#define TF_CAPI_EXPORT
|
||||
#else
|
||||
#if defined(_WIN32)
|
||||
#ifdef TF_COMPILE_LIBRARY
|
||||
#define TF_CAPI_EXPORT __declspec(dllexport)
|
||||
#else
|
||||
#define TF_CAPI_EXPORT __declspec(dllimport)
|
||||
#endif // TF_COMPILE_LIBRARY
|
||||
#else
|
||||
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
|
||||
#endif // _WIN32
|
||||
#endif // SWIG
|
||||
|
||||
#endif // TENSORFLOW_C_C_API_MACROS_H_
|
33
tensorflow/c/conversion_macros.h
Normal file
33
tensorflow/c/conversion_macros.h
Normal file
@ -0,0 +1,33 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_CONVERSION_MACROS_H_
|
||||
#define TENSORFLOW_C_CONVERSION_MACROS_H_
|
||||
|
||||
#define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper) \
|
||||
inline cpp_impl *unwrap(wrapper *w) { \
|
||||
return reinterpret_cast<cpp_impl *>(w); \
|
||||
} \
|
||||
\
|
||||
inline const cpp_impl *unwrap(const wrapper *w) { \
|
||||
return reinterpret_cast<const cpp_impl *>(w); \
|
||||
} \
|
||||
\
|
||||
inline wrapper *wrap(cpp_impl *i) { return reinterpret_cast<wrapper *>(i); } \
|
||||
inline const wrapper *wrap(const cpp_impl *i) { \
|
||||
return reinterpret_cast<const wrapper *>(i); \
|
||||
}
|
||||
|
||||
#endif // TENSORFLOW_C_CONVERSION_MACROS_H_
|
@ -16,6 +16,7 @@ load(
|
||||
"//tensorflow/core/platform:build_config_root.bzl",
|
||||
"tf_cuda_tests_tags",
|
||||
)
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
@ -41,12 +42,20 @@ tf_cuda_library(
|
||||
":context_interface",
|
||||
":operation_interface",
|
||||
":tensor_handle_interface",
|
||||
":tfe_context_internal",
|
||||
":tfe_cancellation_manager_internal",
|
||||
":tfe_executor_internal",
|
||||
":tfe_monitoring_internal",
|
||||
":tfe_op_attrs_internal",
|
||||
":tfe_op_internal",
|
||||
":tfe_tensor_debug_info_internal",
|
||||
":tfe_tensorhandle_internal",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:fixed_array",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@com_google_absl//absl/types:variant",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_api_internal",
|
||||
"//tensorflow/c:tf_status_internal",
|
||||
"//tensorflow/c:tf_tensor_internal",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core/common_runtime/eager:attr_builder",
|
||||
@ -100,6 +109,11 @@ filegroup(
|
||||
"dlpack.h",
|
||||
"operation_interface.h",
|
||||
"tensor_handle_interface.h",
|
||||
"tfe_cancellation_manager_internal.h",
|
||||
"tfe_executor_internal.h",
|
||||
"tfe_monitoring_internal.h",
|
||||
"tfe_op_attrs_internal.h",
|
||||
"tfe_tensor_debug_info_internal.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/core:__pkg__",
|
||||
@ -107,33 +121,27 @@ filegroup(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
cc_library(
|
||||
name = "c_api_internal",
|
||||
srcs = [
|
||||
hdrs = [
|
||||
"c_api_experimental.h",
|
||||
"c_api_unified_experimental.h",
|
||||
"c_api_internal.h",
|
||||
],
|
||||
hdrs = ["c_api_internal.h"],
|
||||
visibility = [
|
||||
"//learning/deepmind/courier:__subpackages__",
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":c_api",
|
||||
":context_interface",
|
||||
":operation_interface",
|
||||
":tensor_handle_interface",
|
||||
"//tensorflow/c:c_api",
|
||||
":tfe_cancellation_manager_internal",
|
||||
":tfe_context_internal",
|
||||
":tfe_executor_internal",
|
||||
":tfe_monitoring_internal",
|
||||
":tfe_op_attrs_internal",
|
||||
":tfe_op_internal",
|
||||
":tfe_tensor_debug_info_internal",
|
||||
":tfe_tensorhandle_internal",
|
||||
"//tensorflow/c:c_api_internal",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/common_runtime/eager:attr_builder",
|
||||
"//tensorflow/core/common_runtime/eager:eager_executor",
|
||||
],
|
||||
)
|
||||
|
||||
@ -177,13 +185,110 @@ cc_library(
|
||||
":operation_interface",
|
||||
":tensor_handle_interface",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfe_context_internal",
|
||||
hdrs = ["tfe_context_internal.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":context_interface",
|
||||
"//tensorflow/c:conversion_macros",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfe_cancellation_manager_internal",
|
||||
hdrs = ["tfe_cancellation_manager_internal.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfe_executor_internal",
|
||||
hdrs = ["tfe_executor_internal.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core/common_runtime/eager:eager_executor",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfe_monitoring_internal",
|
||||
hdrs = ["tfe_monitoring_internal.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfe_op_attrs_internal",
|
||||
hdrs = ["tfe_op_attrs_internal.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:conversion_macros",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/common_runtime/eager:attr_builder",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfe_op_internal",
|
||||
hdrs = ["tfe_op_internal.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":operation_interface",
|
||||
"//tensorflow/c:conversion_macros",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfe_tensor_debug_info_internal",
|
||||
hdrs = ["tfe_tensor_debug_info_internal.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfe_tensorhandle_internal",
|
||||
hdrs = ["tfe_tensorhandle_internal.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":tensor_handle_interface",
|
||||
"//tensorflow/c:conversion_macros",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "c_api_test_util",
|
||||
testonly = 1,
|
||||
@ -213,7 +318,8 @@ tf_cuda_cc_test(
|
||||
],
|
||||
extra_copts = tfe_xla_copts(),
|
||||
tags = [
|
||||
"guitar",
|
||||
"noguitar", # TODO(b/155445984): flaky
|
||||
#"guitar",
|
||||
"multi_gpu",
|
||||
],
|
||||
deps = [
|
||||
@ -221,6 +327,8 @@ tf_cuda_cc_test(
|
||||
":c_api_experimental",
|
||||
":c_api_internal",
|
||||
":c_api_test_util",
|
||||
":tfe_op_internal",
|
||||
":tfe_tensorhandle_internal",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
@ -239,12 +347,16 @@ tf_cuda_cc_test(
|
||||
srcs = [
|
||||
"c_api_remote_test.cc",
|
||||
],
|
||||
# TODO(b/136478427): Figure out how to correctly shut the server down
|
||||
args = ["--heap_check=local"],
|
||||
extra_copts = tfe_xla_copts(),
|
||||
tags = ["noasan"], # leaks gRPC server instances
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
":c_api_internal",
|
||||
":c_api_test_util",
|
||||
":tfe_tensorhandle_internal",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
@ -256,11 +368,42 @@ tf_cuda_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "c_api_cluster_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"c_api_cluster_test.cc",
|
||||
],
|
||||
# TODO(b/136478427): Figure out how to correctly shut the server down
|
||||
args = ["--heap_check=local"],
|
||||
extra_copts = tfe_xla_copts(),
|
||||
tags = ["noasan"], # leaks gRPC server instances
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
":c_api_internal",
|
||||
":c_api_test_util",
|
||||
":tfe_tensorhandle_internal",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/common_runtime/eager:eager_operation",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
"//tensorflow/core/platform:env",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "c_api_experimental",
|
||||
srcs = [
|
||||
"c_api_experimental.cc",
|
||||
"c_api_unified_experimental.cc",
|
||||
"c_api_unified_experimental_eager.cc",
|
||||
"c_api_unified_experimental_graph.cc",
|
||||
"c_api_unified_experimental_internal.h",
|
||||
],
|
||||
hdrs = [
|
||||
"c_api_experimental.h",
|
||||
@ -275,6 +418,9 @@ tf_cuda_library(
|
||||
"//conditions:default": [
|
||||
":c_api",
|
||||
":c_api_internal",
|
||||
":tfe_context_internal",
|
||||
":tfe_op_internal",
|
||||
":tfe_tensorhandle_internal",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_api_internal",
|
||||
"//tensorflow/core:core_cpu",
|
||||
@ -357,6 +503,7 @@ tf_cuda_cc_test(
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
":c_api_test_util",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/cc/profiler",
|
||||
"//tensorflow/core:lib",
|
||||
@ -438,8 +585,9 @@ cc_library(
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
":c_api_internal",
|
||||
":tfe_tensorhandle_internal",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/c:tf_status_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
|
@ -26,7 +26,6 @@ limitations under the License.
|
||||
// clang-format on
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/container/fixed_array.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
@ -34,6 +33,9 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
#include "tensorflow/c/eager/c_api_tfrt.h"
|
||||
@ -298,7 +300,7 @@ tensorflow::Status CreateRemoteContexts(
|
||||
|
||||
std::vector<bool> filtered_device_mask;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->FilterDevicesForRemoteWorkers(
|
||||
remote_worker, base_request.cluster_device_attributes(),
|
||||
&filtered_device_mask);
|
||||
@ -383,7 +385,7 @@ tensorflow::Status UpdateRemoteContexts(
|
||||
|
||||
std::vector<bool> filtered_device_mask;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->FilterDevicesForRemoteWorkers(
|
||||
remote_worker, base_request.cluster_device_attributes(),
|
||||
&filtered_device_mask);
|
||||
@ -464,7 +466,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
// New server created for new server_def. Unused if updating server_def.
|
||||
std::unique_ptr<tensorflow::ServerInterface> new_server;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
tensorflow::GrpcServer* grpc_server;
|
||||
if (reset_context) {
|
||||
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
|
||||
@ -498,6 +500,17 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
grpc_server->master_env()->worker_cache->GetEagerClientCache(
|
||||
&remote_eager_workers));
|
||||
|
||||
// For cluster update, use a status group to aggregate statuses from
|
||||
// * adding and removing remote devices
|
||||
// * creating remote contexts on newly added workers
|
||||
// * updating remote contexts on existing workers
|
||||
// * updating the master context
|
||||
// Note that we should not return immediately on errors in the middle of these
|
||||
// updates to prevent cluster from having inconsistent context views.
|
||||
//
|
||||
// Unused if `reset_context` is True.
|
||||
tensorflow::StatusGroup sg;
|
||||
|
||||
// When updating an existing context, populate the following lists with:
|
||||
// * added_workers: set(remote_workers) - set(curr_remote_workers)
|
||||
// * removed_workers: set(curr_remote_workers) - set(remote_workers)
|
||||
@ -533,7 +546,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
DifferentiateWorkerLists(&curr_remote_workers, &remote_workers,
|
||||
&added_workers, &removed_workers,
|
||||
&existing_workers);
|
||||
LOG_AND_RETURN_IF_ERROR(GetReplacedFromExistingWorkers(
|
||||
sg.Update(GetReplacedFromExistingWorkers(
|
||||
&existing_workers, context_id, context->GetContextViewId(), server_def,
|
||||
remote_eager_workers.get(), &replaced_workers));
|
||||
if (VLOG_IS_ON(1)) {
|
||||
@ -557,11 +570,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
existing_workers.end());
|
||||
}
|
||||
}
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr));
|
||||
LOG_AND_RETURN_IF_ERROR(AddRemoteDevicesToMgr(
|
||||
added_workers, grpc_server->master_env()->worker_cache,
|
||||
remote_device_mgr));
|
||||
sg.Update(RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr));
|
||||
sg.Update(AddRemoteDevicesToMgr(added_workers,
|
||||
grpc_server->master_env()->worker_cache,
|
||||
remote_device_mgr));
|
||||
}
|
||||
|
||||
std::vector<tensorflow::DeviceAttributes> cluster_device_attributes;
|
||||
@ -582,7 +594,6 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
}
|
||||
|
||||
// Initialize remote eager workers.
|
||||
// TODO(b/138847548) Create remote eager contexts in async mode by default.
|
||||
if (reset_context) {
|
||||
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
|
||||
ctx, remote_workers, context_id, context_view_id, keep_alive_secs,
|
||||
@ -594,7 +605,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
// existing workers to also have the updated context_view_id, so
|
||||
// we must set their context_view_id to the existing master's
|
||||
// context_view_id + 1.
|
||||
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
|
||||
sg.Update(CreateRemoteContexts(
|
||||
ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs,
|
||||
server_def, remote_eager_workers.get(), context->Executor().Async(),
|
||||
context->LazyCopyFunctionRemoteInputs(), base_request));
|
||||
@ -604,20 +615,19 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
VLOG(1) << "Updating cluster with existing worker " << w;
|
||||
}
|
||||
}
|
||||
LOG_AND_RETURN_IF_ERROR(UpdateRemoteContexts(
|
||||
ctx, existing_workers, added_workers, removed_workers, context_id,
|
||||
context_view_id + 1, server_def, remote_eager_workers.get(),
|
||||
base_request));
|
||||
sg.Update(UpdateRemoteContexts(ctx, existing_workers, added_workers,
|
||||
removed_workers, context_id,
|
||||
context_view_id + 1, server_def,
|
||||
remote_eager_workers.get(), base_request));
|
||||
}
|
||||
}
|
||||
|
||||
tensorflow::RemoteRendezvous* r =
|
||||
grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
|
||||
auto session_name = tensorflow::strings::StrCat("eager_", context_id);
|
||||
auto* device_mgr = grpc_server->worker_env()->device_mgr;
|
||||
std::shared_ptr<tensorflow::WorkerSession> worker_session;
|
||||
|
||||
if (reset_context) {
|
||||
tensorflow::RemoteRendezvous* r =
|
||||
grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
|
||||
auto* device_mgr = grpc_server->worker_env()->device_mgr;
|
||||
std::shared_ptr<tensorflow::WorkerSession> worker_session;
|
||||
TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession(
|
||||
session_name, server_def, base_request.cluster_device_attributes(),
|
||||
true));
|
||||
@ -644,13 +654,13 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
// GrpcServer cannot be destroyed after it is started.
|
||||
LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
|
||||
} else {
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
grpc_server->worker_env()->session_mgr->UpdateSession(
|
||||
session_name, server_def, base_request.cluster_device_attributes(),
|
||||
true));
|
||||
LOG_AND_RETURN_IF_ERROR(context->UpdateRemoteMaster(
|
||||
grpc_server->worker_env(), std::move(remote_eager_workers),
|
||||
added_workers, removed_workers, context_id, r));
|
||||
sg.Update(grpc_server->worker_env()->session_mgr->UpdateSession(
|
||||
session_name, server_def, base_request.cluster_device_attributes(),
|
||||
/*isolate_session_state=*/true));
|
||||
sg.Update(context->UpdateRemoteMaster(context_id,
|
||||
std::move(remote_eager_workers),
|
||||
added_workers, removed_workers));
|
||||
LOG_AND_RETURN_IF_ERROR(sg.as_summary_status());
|
||||
}
|
||||
#undef LOG_AND_RETURN_IF_ERROR
|
||||
|
||||
@ -684,8 +694,13 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
|
||||
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
||||
if (opts->use_tfrt) {
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
status->status = tensorflow::Status::OK();
|
||||
return new TFE_Context{new tfrt::ContextInterface()};
|
||||
tfrt::SmallVector<std::string, 4> op_handler_chains;
|
||||
tfrt::SmallVector<tensorflow::DeviceAttributes, 4> device_attributes;
|
||||
status->status = tfrt::ListOpHandlerChains(
|
||||
opts->session_options.options, &op_handler_chains, &device_attributes);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
return tensorflow::wrap(
|
||||
new tfrt::ContextInterface(op_handler_chains, device_attributes));
|
||||
#else
|
||||
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
|
||||
return nullptr;
|
||||
@ -702,14 +717,14 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
||||
tensorflow::Rendezvous* r =
|
||||
new tensorflow::IntraProcessRendezvous(device_mgr.get());
|
||||
|
||||
return new TFE_Context{new tensorflow::EagerContext(
|
||||
return tensorflow::wrap(new tensorflow::EagerContext(
|
||||
opts->session_options.options,
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
||||
opts->device_placement_policy),
|
||||
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
|
||||
opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
|
||||
/*device_mgr_owned*/ true, r,
|
||||
tensorflow::GetDefaultCustomKernelCreator())};
|
||||
tensorflow::GetDefaultCustomKernelCreator()));
|
||||
}
|
||||
|
||||
TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
|
||||
@ -720,14 +735,14 @@ TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
|
||||
tensorflow::Rendezvous* r =
|
||||
new tensorflow::IntraProcessRendezvous(device_mgr);
|
||||
|
||||
return new TFE_Context{new tensorflow::EagerContext(
|
||||
return tensorflow::wrap(new tensorflow::EagerContext(
|
||||
opts->session_options.options,
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
||||
opts->device_placement_policy),
|
||||
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
|
||||
opts->async, opts->lazy_remote_inputs_copy, device_mgr,
|
||||
/*device_mgr_owned*/ false, r,
|
||||
tensorflow::GetDefaultCustomKernelCreator())};
|
||||
tensorflow::GetDefaultCustomKernelCreator()));
|
||||
}
|
||||
|
||||
void TFE_DeleteContext(TFE_Context* ctx) {
|
||||
@ -735,23 +750,18 @@ void TFE_DeleteContext(TFE_Context* ctx) {
|
||||
return;
|
||||
}
|
||||
|
||||
// context->RefCountIsOne() should be true here.
|
||||
// TODO(iga): Remove EagerContext refcounting.
|
||||
ctx->context->Release();
|
||||
|
||||
delete ctx;
|
||||
// ctx->RefCountIsOne() should be true here.
|
||||
tensorflow::unwrap(ctx)->Release();
|
||||
}
|
||||
|
||||
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
|
||||
TF_DeviceList* l = new TF_DeviceList;
|
||||
ctx->context->ListDevices(&l->response);
|
||||
tensorflow::unwrap(ctx)->ListDevices(&l->response);
|
||||
return l;
|
||||
}
|
||||
|
||||
void TFE_ContextClearCaches(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
context->ClearCachesAndThreadExecutors();
|
||||
tensorflow::unwrap(ctx)->ClearCachesAndThreadExecutors();
|
||||
}
|
||||
|
||||
// Set server_def on the context, possibly updating it.
|
||||
@ -773,7 +783,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
|
||||
if (server_def.has_cluster_device_filters()) {
|
||||
const auto& cdf = server_def.cluster_device_filters();
|
||||
for (const auto& jdf : cdf.jobs()) {
|
||||
const string& remote_prefix = "/job:" + jdf.name() + "/task:";
|
||||
const string remote_prefix = "/job:" + jdf.name() + "/task:";
|
||||
for (const auto& tdf : jdf.tasks()) {
|
||||
const int32_t task_index = tdf.first;
|
||||
std::vector<string> device_filters(tdf.second.device_filters_size());
|
||||
@ -782,7 +792,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
|
||||
}
|
||||
const string remote_worker = remote_prefix + std::to_string(task_index);
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status =
|
||||
context->SetRemoteDeviceFilters(remote_worker, device_filters);
|
||||
}
|
||||
@ -804,7 +814,7 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
|
||||
#else // !defined(IS_MOBILE_PLATFORM)
|
||||
tensorflow::ServerDef server_def;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
if (!server_def.ParseFromArray(proto, proto_len)) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Invalid tensorflow.ServerDef protocol buffer");
|
||||
@ -834,7 +844,7 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
|
||||
return false;
|
||||
#else // !defined(IS_MOBILE_PLATFORM)
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
tensorflow::GrpcServer* grpc_server =
|
||||
static_cast<tensorflow::GrpcServer*>(context->GetServer());
|
||||
|
||||
@ -890,7 +900,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
|
||||
status->status = tensorflow::Status::OK();
|
||||
#else // !defined(IS_MOBILE_PLATFORM)
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->SyncExecutors();
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
@ -898,7 +908,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
|
||||
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
||||
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetThreadLocalDevicePlacementPolicy(
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
|
||||
}
|
||||
@ -909,7 +919,7 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
||||
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
|
||||
TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
return static_cast<TFE_ContextDevicePlacementPolicy>(
|
||||
context->GetDevicePlacementPolicy());
|
||||
}
|
||||
@ -919,8 +929,7 @@ TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
|
||||
status->status = tensorflow::TF_TensorToTensor(t, &tensor);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
|
||||
return new TFE_TensorHandle{
|
||||
tensorflow::TensorHandle::CreateLocalHandle(tensor)};
|
||||
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
|
||||
}
|
||||
|
||||
void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
|
||||
@ -928,84 +937,84 @@ void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
|
||||
|
||||
tensorflow::profiler::TraceMe activity(
|
||||
"TFE_DeleteTensorHandle", tensorflow::profiler::TraceMeLevel::kInfo);
|
||||
if (h->handle) {
|
||||
h->handle->Release();
|
||||
if (h) {
|
||||
tensorflow::unwrap(h)->Release();
|
||||
}
|
||||
delete h;
|
||||
}
|
||||
|
||||
TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
|
||||
return static_cast<TF_DataType>(h->handle->DataType());
|
||||
return static_cast<TF_DataType>(tensorflow::unwrap(h)->DataType());
|
||||
}
|
||||
|
||||
int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return -1;
|
||||
}
|
||||
|
||||
int num_dims = -1;
|
||||
status->status = h->handle->NumDims(&num_dims);
|
||||
status->status = tensorflow::unwrap(h)->NumDims(&num_dims);
|
||||
return num_dims;
|
||||
}
|
||||
|
||||
int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return -1;
|
||||
}
|
||||
|
||||
int64 num_elements = -1;
|
||||
status->status = h->handle->NumElements(&num_elements);
|
||||
status->status = tensorflow::unwrap(h)->NumElements(&num_elements);
|
||||
return num_elements;
|
||||
}
|
||||
|
||||
int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
|
||||
TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return -1;
|
||||
}
|
||||
|
||||
int64 dim = -1;
|
||||
status->status = h->handle->Dim(dim_index, &dim);
|
||||
status->status = tensorflow::unwrap(h)->Dim(dim_index, &dim);
|
||||
return dim;
|
||||
}
|
||||
|
||||
const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return nullptr;
|
||||
}
|
||||
return h->handle->DeviceName(&status->status);
|
||||
return tensorflow::unwrap(h)->DeviceName(&status->status);
|
||||
}
|
||||
|
||||
const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
|
||||
TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return nullptr;
|
||||
}
|
||||
return h->handle->BackingDeviceName(&status->status);
|
||||
return tensorflow::unwrap(h)->BackingDeviceName(&status->status);
|
||||
}
|
||||
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
|
||||
TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return new TFE_TensorHandle{h->handle->Copy()};
|
||||
return tensorflow::wrap(tensorflow::unwrap(h)->Copy());
|
||||
}
|
||||
|
||||
TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
tensorflow::AbstractTensorInterface* t = h->handle->Resolve(&status->status);
|
||||
tensorflow::AbstractTensorInterface* t =
|
||||
tensorflow::unwrap(h)->Resolve(&status->status);
|
||||
if (t == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
@ -1014,22 +1023,22 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
|
||||
}
|
||||
|
||||
void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return nullptr;
|
||||
}
|
||||
tensorflow::TensorHandle* handle =
|
||||
tensorflow::TensorHandleFromInterface(h->handle);
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
|
||||
if (VariantDeviceIsCustom(handle->device())) {
|
||||
const tensorflow::Tensor* t;
|
||||
status->status = handle->Tensor(&t);
|
||||
return t->data();
|
||||
}
|
||||
|
||||
if (handle->IsRemote()) {
|
||||
if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"TFE_TensorHandleDevicePointer may not be called on a remote tensor "
|
||||
"handle.");
|
||||
"TFE_TensorHandleDevicePointer may not be called on a ",
|
||||
handle->TypeString(), " tensor handle.");
|
||||
return nullptr;
|
||||
}
|
||||
tensorflow::Device* device(absl::get<tensorflow::Device*>(handle->device()));
|
||||
@ -1055,7 +1064,7 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
||||
void* deallocator_arg, TF_Status* status) {
|
||||
tensorflow::Device* device = nullptr;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->FindDeviceFromName(device_name, &device);
|
||||
tensorflow::CustomDevice* custom_device = nullptr;
|
||||
if (!status->status.ok()) {
|
||||
@ -1081,11 +1090,11 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
||||
tensorflow::TensorShape(dimvec), buf);
|
||||
buf->Unref();
|
||||
if (custom_device == nullptr) {
|
||||
return new TFE_TensorHandle{tensorflow::TensorHandle::CreateLocalHandle(
|
||||
std::move(t), device, device, context)};
|
||||
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(
|
||||
std::move(t), device, device, context));
|
||||
} else {
|
||||
return new TFE_TensorHandle{tensorflow::TensorHandle::CreateLocalHandle(
|
||||
std::move(t), custom_device, context)};
|
||||
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(
|
||||
std::move(t), custom_device, context));
|
||||
}
|
||||
}
|
||||
|
||||
@ -1094,16 +1103,16 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
||||
// bytes of the memory pointed to by the device pointer returned above.
|
||||
size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
|
||||
TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return 0;
|
||||
}
|
||||
tensorflow::TensorHandle* handle =
|
||||
tensorflow::TensorHandleFromInterface(h->handle);
|
||||
if (handle->IsRemote()) {
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
|
||||
if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"TFE_TensorHandleDeviceMemorySize may not be called on a remote tensor "
|
||||
"handle.");
|
||||
"TFE_TensorHandleDeviceMemorySize may not be called on a ",
|
||||
handle->TypeString(), " tensor handle.");
|
||||
return 0;
|
||||
}
|
||||
const tensorflow::Tensor* tensor;
|
||||
@ -1116,12 +1125,14 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
|
||||
|
||||
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op> new_op(new TFE_Op{ctx->context->CreateOperation()});
|
||||
status->status = new_op->operation->Reset(op_or_function_name, nullptr);
|
||||
tensorflow::AbstractOperationInterface* new_op =
|
||||
tensorflow::unwrap(ctx)->CreateOperation();
|
||||
status->status = new_op->Reset(op_or_function_name, nullptr);
|
||||
if (!status->status.ok()) {
|
||||
new_op.reset();
|
||||
new_op->Release();
|
||||
new_op = nullptr;
|
||||
}
|
||||
return new_op.release();
|
||||
return tensorflow::wrap(new_op);
|
||||
}
|
||||
|
||||
void TFE_DeleteOp(TFE_Op* op) {
|
||||
@ -1129,24 +1140,20 @@ void TFE_DeleteOp(TFE_Op* op) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (op->operation) {
|
||||
op->operation->Release();
|
||||
}
|
||||
|
||||
delete op;
|
||||
tensorflow::unwrap(op)->Release();
|
||||
}
|
||||
|
||||
void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
|
||||
status->status = op->operation->SetDeviceName(device_name);
|
||||
status->status = tensorflow::unwrap(op)->SetDeviceName(device_name);
|
||||
}
|
||||
|
||||
const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
|
||||
return op->operation->DeviceName().c_str();
|
||||
return tensorflow::unwrap(op)->DeviceName().c_str();
|
||||
}
|
||||
|
||||
void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
|
||||
#ifdef TENSORFLOW_EAGER_USE_XLA
|
||||
tensorflow::Status s = op->operation->SetUseXla(enable);
|
||||
tensorflow::Status s = tensorflow::unwrap(op)->SetUseXla(enable);
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << "Could not enable XLA compilation for op: " << s;
|
||||
}
|
||||
@ -1157,18 +1164,13 @@ void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
|
||||
}
|
||||
|
||||
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
|
||||
status->status = op->operation->AddInput(input->handle);
|
||||
status->status = tensorflow::unwrap(op)->AddInput(tensorflow::unwrap(input));
|
||||
}
|
||||
|
||||
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
|
||||
TF_Status* status) {
|
||||
absl::FixedArray<tensorflow::AbstractTensorHandleInterface*> handles(
|
||||
num_inputs);
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
handles[i] = inputs[i]->handle;
|
||||
}
|
||||
status->status =
|
||||
op->operation->AddInputList({handles.data(), handles.size()});
|
||||
status->status = tensorflow::unwrap(op)->AddInputList(
|
||||
{tensorflow::unwrap(inputs), static_cast<size_t>(num_inputs)});
|
||||
}
|
||||
|
||||
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
|
||||
@ -1176,8 +1178,8 @@ TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
|
||||
TF_AttrType ret = TF_ATTR_INT;
|
||||
const tensorflow::AttrTypeMap* attr_types_;
|
||||
bool is_function;
|
||||
status->status = tensorflow::AttrTypeMapForOp(op->operation->Name().c_str(),
|
||||
&attr_types_, &is_function);
|
||||
status->status = tensorflow::AttrTypeMapForOp(
|
||||
tensorflow::unwrap(op)->Name().c_str(), &attr_types_, &is_function);
|
||||
if (!status->status.ok()) {
|
||||
return ret;
|
||||
}
|
||||
@ -1203,7 +1205,7 @@ TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
|
||||
|
||||
void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
|
||||
size_t length) {
|
||||
auto s = op->operation->SetAttrString(
|
||||
auto s = tensorflow::unwrap(op)->SetAttrString(
|
||||
attr_name, static_cast<const char*>(value), length);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
@ -1211,29 +1213,30 @@ void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
|
||||
auto s = op->operation->SetAttrInt(attr_name, value);
|
||||
auto s = tensorflow::unwrap(op)->SetAttrInt(attr_name, value);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
|
||||
auto s = op->operation->SetAttrFloat(attr_name, value);
|
||||
auto s = tensorflow::unwrap(op)->SetAttrFloat(attr_name, value);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
|
||||
auto s = op->operation->SetAttrBool(attr_name, (value == 0) ? false : true);
|
||||
auto s = tensorflow::unwrap(op)->SetAttrBool(attr_name,
|
||||
(value == 0) ? false : true);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
|
||||
auto s = op->operation->SetAttrType(attr_name,
|
||||
static_cast<tensorflow::DataType>(value));
|
||||
auto s = tensorflow::unwrap(op)->SetAttrType(
|
||||
attr_name, static_cast<tensorflow::DataType>(value));
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
@ -1241,12 +1244,14 @@ void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
|
||||
|
||||
void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
|
||||
const int num_dims, TF_Status* out_status) {
|
||||
out_status->status = op->operation->SetAttrShape(attr_name, dims, num_dims);
|
||||
out_status->status =
|
||||
tensorflow::unwrap(op)->SetAttrShape(attr_name, dims, num_dims);
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
|
||||
const TFE_Op* value) {
|
||||
auto s = op->operation->SetAttrFunction(attr_name, value->operation);
|
||||
auto s = tensorflow::unwrap(op)->SetAttrFunction(
|
||||
attr_name, tensorflow::unwrap(const_cast<TFE_Op*>(value)));
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
@ -1254,7 +1259,7 @@ void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
|
||||
|
||||
void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
|
||||
const char* data, size_t length) {
|
||||
auto s = op->operation->SetAttrFunctionName(attr_name, data, length);
|
||||
auto s = tensorflow::unwrap(op)->SetAttrFunctionName(attr_name, data, length);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
@ -1265,14 +1270,14 @@ void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
|
||||
tensorflow::Tensor t;
|
||||
status->status = TF_TensorToTensor(tensor, &t);
|
||||
tensorflow::TensorInterface interface(t);
|
||||
status->status = op->operation->SetAttrTensor(attr_name, &interface);
|
||||
status->status = tensorflow::unwrap(op)->SetAttrTensor(attr_name, &interface);
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
|
||||
const void* const* values, const size_t* lengths,
|
||||
int num_values) {
|
||||
auto s =
|
||||
op->operation->SetAttrStringList(attr_name, values, lengths, num_values);
|
||||
auto s = tensorflow::unwrap(op)->SetAttrStringList(attr_name, values, lengths,
|
||||
num_values);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
@ -1280,7 +1285,8 @@ void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
|
||||
|
||||
void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
|
||||
const float* values, int num_values) {
|
||||
auto s = op->operation->SetAttrFloatList(attr_name, values, num_values);
|
||||
auto s =
|
||||
tensorflow::unwrap(op)->SetAttrFloatList(attr_name, values, num_values);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
@ -1288,7 +1294,8 @@ void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
|
||||
|
||||
void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
|
||||
const int64_t* values, int num_values) {
|
||||
auto s = op->operation->SetAttrIntList(attr_name, values, num_values);
|
||||
auto s =
|
||||
tensorflow::unwrap(op)->SetAttrIntList(attr_name, values, num_values);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
@ -1296,7 +1303,7 @@ void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
|
||||
|
||||
void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
|
||||
const TF_DataType* values, int num_values) {
|
||||
auto s = op->operation->SetAttrTypeList(
|
||||
auto s = tensorflow::unwrap(op)->SetAttrTypeList(
|
||||
attr_name, reinterpret_cast<const tensorflow::DataType*>(values),
|
||||
num_values);
|
||||
if (!s.ok()) {
|
||||
@ -1306,7 +1313,8 @@ void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
|
||||
|
||||
void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
|
||||
const unsigned char* values, int num_values) {
|
||||
auto s = op->operation->SetAttrBoolList(attr_name, values, num_values);
|
||||
auto s =
|
||||
tensorflow::unwrap(op)->SetAttrBoolList(attr_name, values, num_values);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
@ -1315,19 +1323,14 @@ void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
|
||||
void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
|
||||
const int64_t** dims, const int* num_dims,
|
||||
int num_values, TF_Status* out_status) {
|
||||
out_status->status =
|
||||
op->operation->SetAttrShapeList(attr_name, dims, num_dims, num_values);
|
||||
out_status->status = tensorflow::unwrap(op)->SetAttrShapeList(
|
||||
attr_name, dims, num_dims, num_values);
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
|
||||
const TFE_Op** value, int num_values) {
|
||||
absl::FixedArray<const tensorflow::AbstractOperationInterface*> values(
|
||||
num_values);
|
||||
for (int i = 0; i < num_values; ++i) {
|
||||
values[i] = value[i]->operation;
|
||||
}
|
||||
auto s = op->operation->SetAttrFunctionList(attr_name,
|
||||
{values.data(), values.size()});
|
||||
auto s = tensorflow::unwrap(op)->SetAttrFunctionList(
|
||||
attr_name, {tensorflow::unwrap(value), static_cast<size_t>(num_values)});
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
@ -1342,12 +1345,13 @@ void TFE_OpSetAttrValueProto(const TFE_Op* op, const char* attr_name,
|
||||
tensorflow::errors::InvalidArgument("Unparseable AttrValue proto");
|
||||
return;
|
||||
}
|
||||
if (op == nullptr || op->operation == nullptr) {
|
||||
if (op == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Got a null or uninitialized `op` argument");
|
||||
return;
|
||||
}
|
||||
tensorflow::EagerOperation* operation = OperationFromInterface(op->operation);
|
||||
tensorflow::EagerOperation* operation =
|
||||
OperationFromInterface(tensorflow::unwrap(const_cast<TFE_Op*>(op)));
|
||||
operation->MutableAttrs()->Set(attr_name, attr_value);
|
||||
}
|
||||
|
||||
@ -1355,7 +1359,7 @@ TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
|
||||
const char* input_name,
|
||||
TF_Status* status) {
|
||||
int ret = -1;
|
||||
status->status = op->operation->InputLength(input_name, &ret);
|
||||
status->status = tensorflow::unwrap(op)->InputLength(input_name, &ret);
|
||||
return ret;
|
||||
}
|
||||
|
||||
@ -1363,71 +1367,29 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
|
||||
const char* output_name,
|
||||
TF_Status* status) {
|
||||
int ret = -1;
|
||||
status->status = op->operation->OutputLength(output_name, &ret);
|
||||
status->status = tensorflow::unwrap(op)->OutputLength(output_name, &ret);
|
||||
return ret;
|
||||
}
|
||||
|
||||
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
|
||||
TF_Status* status) {
|
||||
absl::FixedArray<tensorflow::AbstractTensorHandleInterface*> handles(
|
||||
*num_retvals);
|
||||
status->status = op->operation->Execute(absl::MakeSpan(handles), num_retvals);
|
||||
if (!status->status.ok()) {
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < *num_retvals; ++i) {
|
||||
retvals[i] = new TFE_TensorHandle{handles[i]};
|
||||
}
|
||||
status->status = tensorflow::unwrap(op)->Execute(
|
||||
absl::MakeSpan(tensorflow::unwrap(retvals), *num_retvals), num_retvals);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
|
||||
TFE_Context* ctx,
|
||||
const char* device_name,
|
||||
TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
tensorflow::TensorHandle* handle = nullptr;
|
||||
tensorflow::Device* device;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
status->status = context->FindDeviceFromName(device_name, &device);
|
||||
if (!status->status.ok()) {
|
||||
tensorflow::CustomDevice* dev;
|
||||
status->status = context->FindCustomDeviceFromName(device_name, &dev);
|
||||
if (status->status.ok()) {
|
||||
status->status = dev->CopyTensorToDevice(
|
||||
tensorflow::TensorHandleFromInterface(h->handle), &handle);
|
||||
if (status->status.ok()) {
|
||||
return new TFE_TensorHandle{handle};
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
// Handle tensor handles currently in custom devices
|
||||
const char* handle_device_name = h->handle->DeviceName(&status->status);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
tensorflow::CustomDevice* dev;
|
||||
status->status = context->FindCustomDeviceFromName(handle_device_name, &dev);
|
||||
auto* result = tensorflow::unwrap(ctx)->CopyTensorHandleToDevice(
|
||||
tensorflow::unwrap(h), device_name, &status->status);
|
||||
if (status->status.ok()) {
|
||||
status->status = dev->CopyTensorFromDevice(
|
||||
tensorflow::TensorHandleFromInterface(h->handle), device_name, &handle);
|
||||
if (status->status.ok()) {
|
||||
return new TFE_TensorHandle{handle};
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Handle regular case.
|
||||
status->status = tensorflow::EagerCopyToDevice(
|
||||
tensorflow::TensorHandleFromInterface(h->handle), context,
|
||||
&context->Executor(), device, false, &handle);
|
||||
if (status->status.ok()) {
|
||||
return new TFE_TensorHandle{handle};
|
||||
return tensorflow::wrap(result);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
@ -1442,39 +1404,39 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
|
||||
return;
|
||||
}
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->AddFunctionDef(function_def);
|
||||
}
|
||||
|
||||
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->AddFunctionDef(function->fdef);
|
||||
}
|
||||
|
||||
void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->RemoveFunction(name);
|
||||
}
|
||||
|
||||
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
return context->FindFunctionDef(name) != nullptr;
|
||||
}
|
||||
|
||||
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetShouldStoreGraphs(true);
|
||||
}
|
||||
|
||||
void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetShouldStoreGraphs(false);
|
||||
}
|
||||
|
||||
@ -1482,13 +1444,13 @@ void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
|
||||
|
||||
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
|
||||
TF_Status* status) {
|
||||
return new TFE_TensorHandle{tensorflow::TensorHandle::CreateLocalHandle(t)};
|
||||
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(t));
|
||||
}
|
||||
|
||||
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->Executor().WaitForAllPendingNodes();
|
||||
if (!status->status.ok()) return;
|
||||
tensorflow::mutex_lock ml(*context->MetadataMu());
|
||||
@ -1510,26 +1472,23 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
|
||||
} // namespace
|
||||
|
||||
void TFE_ContextStartStep(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
context->StartStep();
|
||||
tensorflow::unwrap(ctx)->StartStep();
|
||||
}
|
||||
|
||||
void TFE_ContextEndStep(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
context->EndStep();
|
||||
tensorflow::unwrap(ctx)->EndStep();
|
||||
}
|
||||
|
||||
void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs) {
|
||||
tensorflow::EagerOperation* operation = OperationFromInterface(op->operation);
|
||||
*attrs = TFE_OpAttrs(&operation->Attrs(), operation->Name().c_str());
|
||||
const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op) {
|
||||
return tensorflow::wrap(
|
||||
&OperationFromInterface(tensorflow::unwrap(op))->Attrs());
|
||||
}
|
||||
|
||||
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
|
||||
tensorflow::AttrValueMap m;
|
||||
attrs->attributes->FillAttrValueMap(&m);
|
||||
tensorflow::EagerOperation* operation = OperationFromInterface(op->operation);
|
||||
tensorflow::unwrap(attrs)->FillAttrValueMap(&m);
|
||||
tensorflow::EagerOperation* operation =
|
||||
OperationFromInterface(tensorflow::unwrap(op));
|
||||
tensorflow::AttrBuilder* destination = operation->MutableAttrs();
|
||||
for (const auto& attribute : m) {
|
||||
destination->Set(attribute.first, attribute.second);
|
||||
@ -1539,8 +1498,8 @@ void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
|
||||
void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,
|
||||
TF_Status* status) {
|
||||
tensorflow::NameAttrList name_and_attrs;
|
||||
attrs->attributes->FillAttrValueMap(name_and_attrs.mutable_attr());
|
||||
name_and_attrs.set_name(attrs->name);
|
||||
tensorflow::unwrap(attrs)->FillAttrValueMap(name_and_attrs.mutable_attr());
|
||||
name_and_attrs.set_name(tensorflow::unwrap(attrs)->op_name());
|
||||
status->status = MessageToBuffer(name_and_attrs, buf);
|
||||
}
|
||||
|
||||
@ -1617,33 +1576,34 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
|
||||
const string& name() override { return name_; }
|
||||
|
||||
tensorflow::Status CopyTensorToDevice(
|
||||
tensorflow::TensorHandle* tensor,
|
||||
tensorflow::TensorHandle* handle,
|
||||
tensorflow::TensorHandle** result) override {
|
||||
tensor->Ref();
|
||||
TFE_TensorHandle tensor_handle{tensor};
|
||||
handle->Ref();
|
||||
TF_Status status;
|
||||
TFE_TensorHandle* result_handle =
|
||||
device_.copy_tensor_to_device(context_, &tensor_handle, &status, info_);
|
||||
tensor_handle.handle->Release();
|
||||
TFE_TensorHandle* result_handle = device_.copy_tensor_to_device(
|
||||
context_, tensorflow::wrap(handle), &status, info_);
|
||||
handle->Release();
|
||||
if (!status.status.ok()) return status.status;
|
||||
*result = tensorflow::TensorHandleFromInterface(result_handle->handle);
|
||||
*result = tensorflow::TensorHandleFromInterface(
|
||||
tensorflow::unwrap(result_handle));
|
||||
(*result)->Ref();
|
||||
TFE_DeleteTensorHandle(result_handle);
|
||||
return status.status;
|
||||
}
|
||||
|
||||
tensorflow::Status CopyTensorFromDevice(
|
||||
tensorflow::TensorHandle* tensor,
|
||||
tensorflow::TensorHandle* handle,
|
||||
const tensorflow::string& target_device_name,
|
||||
tensorflow::TensorHandle** result) override {
|
||||
TF_Status status;
|
||||
tensor->Ref();
|
||||
TFE_TensorHandle tensor_handle{tensor};
|
||||
handle->Ref();
|
||||
TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
|
||||
context_, &tensor_handle, target_device_name.c_str(), &status, info_);
|
||||
tensor_handle.handle->Release();
|
||||
context_, tensorflow::wrap(handle), target_device_name.c_str(), &status,
|
||||
info_);
|
||||
handle->Release();
|
||||
if (!status.status.ok()) return status.status;
|
||||
*result = tensorflow::TensorHandleFromInterface(result_handle->handle);
|
||||
*result = tensorflow::TensorHandleFromInterface(
|
||||
tensorflow::unwrap(result_handle));
|
||||
(*result)->Ref();
|
||||
TFE_DeleteTensorHandle(result_handle);
|
||||
return status.status;
|
||||
@ -1656,16 +1616,17 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
|
||||
inputs.reserve(op->Inputs().size());
|
||||
for (int i = 0; i < op->Inputs().size(); ++i) {
|
||||
op->Inputs()[i]->Ref();
|
||||
inputs.push_back(new TFE_TensorHandle{op->Inputs()[i]});
|
||||
inputs.push_back(tensorflow::wrap(op->Inputs()[i]));
|
||||
}
|
||||
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
|
||||
TF_Status status;
|
||||
TFE_OpAttrs attributes(&op->Attrs(), op->Name().c_str());
|
||||
device_.execute(context_, inputs.size(), inputs.data(), op->Name().c_str(),
|
||||
&attributes, num_retvals, outputs.data(), &status, info_);
|
||||
wrap(&op->Attrs()), num_retvals, outputs.data(), &status,
|
||||
info_);
|
||||
if (status.status.ok()) {
|
||||
for (int i = 0; i < *num_retvals; ++i) {
|
||||
retvals[i] = tensorflow::TensorHandleFromInterface(outputs[i]->handle);
|
||||
retvals[i] = tensorflow::TensorHandleFromInterface(
|
||||
tensorflow::unwrap(outputs[i]));
|
||||
retvals[i]->Ref();
|
||||
TFE_DeleteTensorHandle(outputs[i]);
|
||||
}
|
||||
@ -1693,7 +1654,7 @@ void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
|
||||
auto custom_device =
|
||||
std::make_unique<CustomDeviceAPI>(ctx, device, device_info, device_name);
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status =
|
||||
context->RegisterCustomDevice(device_name, std::move(custom_device));
|
||||
}
|
||||
|
313
tensorflow/c/eager/c_api_cluster_test.cc
Normal file
313
tensorflow/c/eager/c_api_cluster_test.cc
Normal file
@ -0,0 +1,313 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/protobuf/cluster.pb.h"
|
||||
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
||||
|
||||
namespace {
|
||||
|
||||
using ::tensorflow::string;
|
||||
|
||||
tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
|
||||
tensorflow::ServerDef server_def;
|
||||
server_def.set_protocol("grpc");
|
||||
server_def.set_job_name(job_name);
|
||||
server_def.set_task_index(0);
|
||||
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
|
||||
tensorflow::JobDef* job_def = cluster_def->add_job();
|
||||
job_def->set_name(job_name);
|
||||
for (int i = 0; i < num_tasks; i++) {
|
||||
int port = tensorflow::testing::PickUnusedPortOrDie();
|
||||
job_def->mutable_tasks()->insert(
|
||||
{i, tensorflow::strings::StrCat("localhost:", port)});
|
||||
}
|
||||
return server_def;
|
||||
}
|
||||
|
||||
tensorflow::ServerDef GetServerDef(int num_tasks) {
|
||||
return GetServerDef("localhost", num_tasks);
|
||||
}
|
||||
|
||||
void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
|
||||
const std::vector<float>& expected_values) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(handle, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
std::unique_ptr<float[]> actual_values(new float[expected_values.size()]);
|
||||
EXPECT_EQ(sizeof(float) * expected_values.size(), TF_TensorByteSize(t));
|
||||
memcpy(actual_values.get(), TF_TensorData(t), TF_TensorByteSize(t));
|
||||
TF_DeleteTensor(t);
|
||||
|
||||
for (int i = 0; i < expected_values.size(); i++) {
|
||||
EXPECT_EQ(expected_values[i], actual_values[i])
|
||||
<< "Mismatch in expected values at (zero-based) index " << i;
|
||||
}
|
||||
}
|
||||
|
||||
void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
|
||||
const char* remote_device_name,
|
||||
const char* local_device_name) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
|
||||
|
||||
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h0_task0);
|
||||
TFE_OpSetDevice(matmul, remote_device_name, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* retvals[1];
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
auto* retval_task0 =
|
||||
TFE_TensorHandleCopyToDevice(retvals[0], ctx, local_device_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
CheckTFE_TensorHandleHasFloats(retval_task0, {7, 10, 15, 22});
|
||||
|
||||
TFE_DeleteTensorHandle(retval_task0);
|
||||
TFE_DeleteTensorHandle(h0_task0);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
|
||||
TFE_DeleteOp(matmul);
|
||||
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteChangeServerDef(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
const char remote_device_name[] =
|
||||
"/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
const char local_device_name[] =
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
|
||||
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
|
||||
// Update the server def with a new set of names (worker instead of
|
||||
// localhost).
|
||||
tensorflow::ServerDef updated_server_def = GetServerDef("worker", 2);
|
||||
serialized = updated_server_def.SerializeAsString();
|
||||
|
||||
updated_server_def.set_task_index(1);
|
||||
tensorflow::Status s = tensorflow::GrpcServer::Create(
|
||||
updated_server_def, tensorflow::Env::Default(), &worker_server);
|
||||
ASSERT_TRUE(s.ok()) << s.error_message();
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// Create a new tensor_handle.
|
||||
TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle(ctx);
|
||||
|
||||
// Check that copying it to the old remote device (named localhost) fails.
|
||||
TFE_TensorHandleCopyToDevice(h0_task0_new, ctx, remote_device_name, status);
|
||||
EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// Copying and executing on the new remote device works.
|
||||
const char new_remote_device_name[] =
|
||||
"/job:worker/replica:0/task:1/device:CPU:0";
|
||||
const char new_local_device_name[] =
|
||||
"/job:worker/replica:0/task:0/device:CPU:0";
|
||||
|
||||
auto* h0_task1_new = TFE_TensorHandleCopyToDevice(
|
||||
h0_task0_new, ctx, new_remote_device_name, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_DeleteTensorHandle(h0_task0_new);
|
||||
TFE_DeleteTensorHandle(h0_task1_new);
|
||||
|
||||
CheckRemoteMatMulExecutesOK(ctx, new_remote_device_name,
|
||||
new_local_device_name);
|
||||
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteChangeServerDef) {
|
||||
TestRemoteExecuteChangeServerDef(false);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteChangeServerDefAsync) {
|
||||
TestRemoteExecuteChangeServerDef(true);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteUpdateServerDef(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
const char local_device_name[] =
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
const char remote_device_name[] =
|
||||
"/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
|
||||
|
||||
TFE_ContextUpdateServerDef(ctx, 0, serialized.data(), serialized.size(),
|
||||
status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteUpdateServerDef) {
|
||||
TestRemoteExecuteUpdateServerDef(false);
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteUpdateServerDefAsync) {
|
||||
TestRemoteExecuteUpdateServerDef(true);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
|
||||
// Fail fast on GetStatus requests so we can get errors instead of timeout
|
||||
// when updating cluster with non-exsitent worker
|
||||
tensorflow::setenv("GRPC_FAIL_FAST", "TRUE", /*overwrite=*/1);
|
||||
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
const char local_device_name[] =
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
const char remote_device_name[] =
|
||||
"/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
|
||||
|
||||
// Adding a non-existent remote worker to cluster def. This should cause the
|
||||
// UpdateServerDef call to fail.
|
||||
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
|
||||
tensorflow::JobDef* job_def = cluster_def->mutable_job(0);
|
||||
int port = tensorflow::testing::PickUnusedPortOrDie();
|
||||
job_def->mutable_tasks()->insert(
|
||||
{2, tensorflow::strings::StrCat("localhost:", port)});
|
||||
string serialized_update = server_def.SerializeAsString();
|
||||
TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
|
||||
serialized_update.size(), status);
|
||||
EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// Even after the prevoiusly failed cluster update, another update and op
|
||||
// execution should work fine as long as the provided server_def is valid.
|
||||
TFE_ContextUpdateServerDef(ctx, 0, serialized.data(), serialized.size(),
|
||||
status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
tensorflow::unsetenv("GRPC_FAIL_FAST");
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteUpdateServerDefWithFailures) {
|
||||
TestRemoteExecuteUpdateServerDefWithFailures(false);
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteUpdateServerDefWithFailuresAsync) {
|
||||
TestRemoteExecuteUpdateServerDefWithFailures(true);
|
||||
}
|
||||
|
||||
} // namespace
|
@ -17,8 +17,11 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensor_debug_info_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/tf_status_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#ifdef TENSORFLOW_EAGER_USE_XLA
|
||||
#include "tensorflow/compiler/jit/xla_device.h"
|
||||
#endif // TENSORFLOW_EAGER_USE_XLA
|
||||
@ -54,7 +57,8 @@ extern "C" {
|
||||
|
||||
TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
|
||||
TFE_TensorHandle* h, TF_Status* status) {
|
||||
tensorflow::TensorHandle* handle = TensorHandleFromInterface(h->handle);
|
||||
tensorflow::TensorHandle* handle =
|
||||
TensorHandleFromInterface(tensorflow::unwrap(h));
|
||||
const tensorflow::Tensor* tensor;
|
||||
status->status = handle->Tensor(&tensor);
|
||||
if (!status->status.ok()) {
|
||||
|
@ -19,6 +19,9 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
@ -34,9 +37,10 @@ using tensorflow::string;
|
||||
void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
|
||||
const char* raw_device_name, TF_Status* status) {
|
||||
if (op_to_reset) {
|
||||
op_to_reset->operation->Clear();
|
||||
status->status =
|
||||
op_to_reset->operation->Reset(op_or_function_name, raw_device_name);
|
||||
tensorflow::AbstractOperationInterface* op =
|
||||
tensorflow::unwrap(op_to_reset);
|
||||
op->Clear();
|
||||
status->status = op->Reset(op_or_function_name, raw_device_name);
|
||||
} else {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
"op_to_reset should not be nullptr");
|
||||
@ -45,13 +49,13 @@ void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
|
||||
|
||||
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetShouldStoreGraphs(true);
|
||||
}
|
||||
|
||||
void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetShouldStoreGraphs(false);
|
||||
}
|
||||
|
||||
@ -483,7 +487,7 @@ void TFE_ContextOptionsSetMirroringPolicy(TFE_ContextOptions* options,
|
||||
void TFE_ContextSetThreadLocalMirroringPolicy(
|
||||
TFE_Context* ctx, TFE_ContextMirroringPolicy policy) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetThreadLocalMirroringPolicy(
|
||||
static_cast<tensorflow::ContextMirroringPolicy>(policy));
|
||||
}
|
||||
@ -494,7 +498,7 @@ void TFE_ContextSetThreadLocalMirroringPolicy(
|
||||
extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
|
||||
TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
return static_cast<TFE_ContextMirroringPolicy>(context->GetMirroringPolicy());
|
||||
}
|
||||
|
||||
@ -530,7 +534,7 @@ void TFE_OpSetCancellationManager(TFE_Op* op,
|
||||
TFE_CancellationManager* cancellation_manager,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerOperation* operation =
|
||||
tensorflow::OperationFromInterface(op->operation);
|
||||
tensorflow::OperationFromInterface(tensorflow::unwrap(op));
|
||||
operation->SetCancellationManager(
|
||||
&cancellation_manager->cancellation_manager);
|
||||
status->status = tensorflow::Status::OK();
|
||||
@ -557,19 +561,19 @@ void TFE_ExecutorClearError(TFE_Executor* executor) {
|
||||
|
||||
void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetExecutorForThread(executor->executor());
|
||||
}
|
||||
|
||||
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
return new TFE_Executor(&context->Executor());
|
||||
}
|
||||
|
||||
void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
|
||||
context->HostCPU()->parsed_name());
|
||||
auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
|
||||
@ -585,7 +589,7 @@ void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
|
||||
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
|
||||
TF_Buffer* buf, TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
auto* function_def = context->FindFunctionDef(function_name);
|
||||
if (function_def == nullptr) {
|
||||
status->status = tensorflow::errors::NotFound(
|
||||
@ -611,13 +615,14 @@ TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_DataType dtype,
|
||||
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
|
||||
}
|
||||
|
||||
if (ctx == nullptr || ctx->context == nullptr) {
|
||||
if (ctx == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid Context");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
tensorflow::AbstractTensorInterface* t = ctx->context->CreateTensor(
|
||||
static_cast<tensorflow::DataType>(dtype), dimvec);
|
||||
tensorflow::AbstractTensorInterface* t =
|
||||
tensorflow::unwrap(ctx)->CreateTensor(
|
||||
static_cast<tensorflow::DataType>(dtype), dimvec);
|
||||
|
||||
if (t == nullptr) {
|
||||
status->status =
|
||||
@ -630,5 +635,6 @@ TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_DataType dtype,
|
||||
|
||||
TFE_TensorHandle* TFE_NewTensorHandleFromTensor(TFE_Context* ctx, TF_Tensor* t,
|
||||
TF_Status* status) {
|
||||
return new TFE_TensorHandle{ctx->context->CreateLocalHandle(t->tensor)};
|
||||
return tensorflow::wrap(
|
||||
tensorflow::unwrap(ctx)->CreateLocalHandle(t->tensor));
|
||||
}
|
||||
|
@ -431,11 +431,9 @@ TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
|
||||
// A reference to an op's name -> attribute mapping
|
||||
typedef struct TFE_OpAttrs TFE_OpAttrs;
|
||||
|
||||
// Fetch a struct with a reference to information about attributes of `op`.
|
||||
//
|
||||
// The `attrs` struct does not own any memory, and `op` must outlive it.
|
||||
TF_CAPI_EXPORT extern void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs);
|
||||
|
||||
// Fetch a reference to `op`'s attributes. The returned reference is only valid
|
||||
// while `op` is alive.
|
||||
const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op);
|
||||
// Add attributes in `attrs` to `op`.
|
||||
//
|
||||
// Does not overwrite or update existing attributes, but adds new ones.
|
||||
|
@ -15,39 +15,17 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/context_interface.h"
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||
#include "tensorflow/core/framework/cancellation.h"
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||
#include "tensorflow/core/lib/monitoring/gauge.h"
|
||||
#include "tensorflow/core/lib/monitoring/sampler.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
#include "tensorflow/c/eager/tfe_cancellation_manager_internal.h" // IWYU pragma: export
|
||||
#include "tensorflow/c/eager/tfe_executor_internal.h" // IWYU pragma: export
|
||||
#include "tensorflow/c/eager/tfe_monitoring_internal.h" // IWYU pragma: export
|
||||
#include "tensorflow/c/eager/tfe_op_attrs_internal.h" // IWYU pragma: export
|
||||
#include "tensorflow/c/eager/tfe_tensor_debug_info_internal.h" // IWYU pragma: export
|
||||
|
||||
// TODO(b/154564140): Move this to its own header. This requires splitting
|
||||
// c_api_experimental.h
|
||||
struct TFE_ContextOptions {
|
||||
TF_SessionOptions session_options;
|
||||
// true if async execution is enabled.
|
||||
@ -61,199 +39,4 @@ struct TFE_ContextOptions {
|
||||
bool use_tfrt = false;
|
||||
};
|
||||
|
||||
// Wraps a pointer to a context implementation.
|
||||
//
|
||||
// WARNING: Since the underlying object could be ref-counted a user of this
|
||||
// interface cannot destruct the underlying context object. Instead, call
|
||||
// TFE_DeleteContext who calls Release() on the context pointer and deletes
|
||||
// the TFE_Context structure.
|
||||
struct TFE_Context {
|
||||
tensorflow::AbstractContextInterface* context;
|
||||
};
|
||||
|
||||
// Wraps a pointer to a tensor handle implementation.
|
||||
//
|
||||
// WARNING: Since the underlying object could be ref-counted a user of this
|
||||
// interface cannot destruct the underlying handle object. Instead, call
|
||||
// TFE_DeleteTensorHandle who calls Release() on the handle pointer and deletes
|
||||
// the TFE_TensorHandle structure.
|
||||
struct TFE_TensorHandle {
|
||||
tensorflow::AbstractTensorHandleInterface* handle;
|
||||
};
|
||||
|
||||
struct TFE_TensorDebugInfo {
|
||||
explicit TFE_TensorDebugInfo(const std::vector<tensorflow::int64>& dims)
|
||||
: dev_dims(dims) {}
|
||||
|
||||
// Fully-padded, minor-to-major.
|
||||
std::vector<tensorflow::int64> dev_dims;
|
||||
};
|
||||
|
||||
// Wraps a pointer to an operation implementation.
|
||||
//
|
||||
// WARNING: Since the underlying object could be ref-counted a user of this
|
||||
// interface cannot destruct the underlying operation object. Instead, call
|
||||
// TFE_DeleteOp who calls Release() on the operation pointer and deletes
|
||||
// the TFE_Op structure.
|
||||
struct TFE_Op {
|
||||
tensorflow::AbstractOperationInterface* operation;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringCounterCell {
|
||||
tensorflow::monitoring::CounterCell cell;
|
||||
};
|
||||
|
||||
template <int NumLabels>
|
||||
struct TFE_MonitoringCounter {
|
||||
template <typename... LabelDesc>
|
||||
TFE_MonitoringCounter(const char* name, const char* description,
|
||||
LabelDesc&&... label) {
|
||||
counter = absl::WrapUnique(tensorflow::monitoring::Counter<NumLabels>::New(
|
||||
name, description, label...));
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::monitoring::Counter<NumLabels>> counter;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringCounter0 : TFE_MonitoringCounter<0> {
|
||||
using TFE_MonitoringCounter::TFE_MonitoringCounter;
|
||||
};
|
||||
struct TFE_MonitoringCounter1 : TFE_MonitoringCounter<1> {
|
||||
using TFE_MonitoringCounter::TFE_MonitoringCounter;
|
||||
};
|
||||
struct TFE_MonitoringCounter2 : TFE_MonitoringCounter<2> {
|
||||
using TFE_MonitoringCounter::TFE_MonitoringCounter;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringIntGaugeCell {
|
||||
tensorflow::monitoring::GaugeCell<tensorflow::int64> cell;
|
||||
};
|
||||
struct TFE_MonitoringStringGaugeCell {
|
||||
tensorflow::monitoring::GaugeCell<tensorflow::string> cell;
|
||||
};
|
||||
struct TFE_MonitoringBoolGaugeCell {
|
||||
tensorflow::monitoring::GaugeCell<bool> cell;
|
||||
};
|
||||
|
||||
template <typename ValueType, int NumLabels>
|
||||
struct TFE_MonitoringGauge {
|
||||
template <typename... LabelDesc>
|
||||
TFE_MonitoringGauge(const char* name, const char* description,
|
||||
LabelDesc&&... label) {
|
||||
gauge = absl::WrapUnique(
|
||||
tensorflow::monitoring::Gauge<ValueType, NumLabels>::New(
|
||||
name, description, label...));
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::monitoring::Gauge<ValueType, NumLabels>> gauge;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringIntGauge0 : TFE_MonitoringGauge<tensorflow::int64, 0> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringIntGauge1 : TFE_MonitoringGauge<tensorflow::int64, 1> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringIntGauge2 : TFE_MonitoringGauge<tensorflow::int64, 2> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringStringGauge0 : TFE_MonitoringGauge<tensorflow::string, 0> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringStringGauge1 : TFE_MonitoringGauge<tensorflow::string, 1> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringStringGauge2 : TFE_MonitoringGauge<tensorflow::string, 2> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringBoolGauge0 : TFE_MonitoringGauge<bool, 0> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringBoolGauge1 : TFE_MonitoringGauge<bool, 1> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringBoolGauge2 : TFE_MonitoringGauge<bool, 2> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringBuckets {
|
||||
explicit TFE_MonitoringBuckets(
|
||||
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
|
||||
fn) {
|
||||
create_buckets = fn;
|
||||
}
|
||||
|
||||
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
|
||||
create_buckets;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringSamplerCell {
|
||||
tensorflow::monitoring::SamplerCell cell;
|
||||
};
|
||||
|
||||
template <int NumLabels>
|
||||
struct TFE_MonitoringSampler {
|
||||
template <typename... LabelDesc>
|
||||
TFE_MonitoringSampler(
|
||||
const char* name,
|
||||
std::unique_ptr<tensorflow::monitoring::Buckets> buckets,
|
||||
const char* description, LabelDesc&&... label) {
|
||||
sampler = absl::WrapUnique(tensorflow::monitoring::Sampler<NumLabels>::New(
|
||||
{name, description, label...}, std::move(buckets)));
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::monitoring::Sampler<NumLabels>> sampler;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringSampler0 : TFE_MonitoringSampler<0> {
|
||||
using TFE_MonitoringSampler::TFE_MonitoringSampler;
|
||||
};
|
||||
struct TFE_MonitoringSampler1 : TFE_MonitoringSampler<1> {
|
||||
using TFE_MonitoringSampler::TFE_MonitoringSampler;
|
||||
};
|
||||
struct TFE_MonitoringSampler2 : TFE_MonitoringSampler<2> {
|
||||
using TFE_MonitoringSampler::TFE_MonitoringSampler;
|
||||
};
|
||||
|
||||
namespace tensorflow {
|
||||
// Set an AttrValue on the op. Doesn't handle the list types.
|
||||
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
|
||||
const tensorflow::AttrValue& default_value,
|
||||
const char* attr_name, TF_Status* status);
|
||||
} // namespace tensorflow
|
||||
|
||||
struct TFE_CancellationManager {
|
||||
tensorflow::CancellationManager cancellation_manager;
|
||||
};
|
||||
|
||||
struct TFE_Executor {
|
||||
explicit TFE_Executor(bool async)
|
||||
: owned_executor(new tensorflow::EagerExecutor(async)) {}
|
||||
|
||||
explicit TFE_Executor(tensorflow::EagerExecutor* executor)
|
||||
: owned_executor(nullptr), unowned_executor(executor) {}
|
||||
|
||||
tensorflow::EagerExecutor* executor() {
|
||||
return owned_executor == nullptr ? unowned_executor : owned_executor.get();
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::EagerExecutor> owned_executor;
|
||||
tensorflow::EagerExecutor* unowned_executor;
|
||||
};
|
||||
|
||||
// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways
|
||||
// that sometimes do not require serialization.
|
||||
struct TFE_OpAttrs {
|
||||
explicit TFE_OpAttrs() : name(nullptr), attributes(nullptr) {}
|
||||
|
||||
explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value,
|
||||
const char* op_name)
|
||||
: name(op_name), attributes(value) {}
|
||||
|
||||
const char* name;
|
||||
const tensorflow::AttrBuilder* attributes;
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
@ -167,7 +168,11 @@ string MatMulFunction() {
|
||||
return def.SerializeAsString();
|
||||
}
|
||||
|
||||
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) {
|
||||
// If heavy_load_on_streaming_rpc is true, send some rpc reqeusts before the one
|
||||
// which creates a remote remote input, to simulate a scenario that the remote
|
||||
// input is not ready when we start running an op or a function.
|
||||
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
|
||||
bool heavy_load_on_streaming_rpc) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(3);
|
||||
|
||||
// This server def has the task index set to 0.
|
||||
@ -192,48 +197,64 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) {
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx);
|
||||
std::vector<TFE_TensorHandle*> handles_task0;
|
||||
if (heavy_load_on_streaming_rpc) {
|
||||
// Send 50 tensor copy requests to simulate that there have been some RPC
|
||||
// requests been enqueued.
|
||||
for (int i = 0; i < 50; ++i) {
|
||||
handles_task0.push_back(TestMatrixTensorHandle(ctx));
|
||||
}
|
||||
}
|
||||
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
|
||||
|
||||
std::vector<TFE_TensorHandle*> handles_task2;
|
||||
for (auto* h_task0 : handles_task0) {
|
||||
handles_task2.push_back(
|
||||
TFE_TensorHandleCopyToDevice(h_task0, ctx, task2_name, status));
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
}
|
||||
|
||||
auto* h1_task2 =
|
||||
TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TFE_Op* matmul = nullptr;
|
||||
if (func) {
|
||||
string function_def = MatMulFunction();
|
||||
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
|
||||
status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
matmul = TFE_NewOp(ctx, "MatMulFunction", status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_OpAddInput(matmul, h0_task0, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_OpAddInput(matmul, h1_task2, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
} else {
|
||||
// Handles are on task0 (local), and task2, but op is on task1.
|
||||
matmul = MatMulOp(ctx, h0_task0, h1_task2);
|
||||
}
|
||||
if (remote) {
|
||||
TFE_OpSetDevice(matmul, task1_name, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
} else if (!async) {
|
||||
// Set the local device to CPU to easily validate mirroring
|
||||
string cpu_device_name;
|
||||
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
|
||||
TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
auto remote_arg =
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
|
||||
// The input handles should never change since they have been mirrored.
|
||||
ASSERT_FALSE(remote_arg->HasLocalMirror(nullptr));
|
||||
}
|
||||
@ -241,21 +262,22 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) {
|
||||
TFE_TensorHandle* retvals[1];
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
// TODO(gjn): Add support for waiting on async local mirrors
|
||||
if (!remote && !async) {
|
||||
auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle);
|
||||
auto remote_arg =
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
|
||||
// The input handles should never change since they have been mirrored.
|
||||
ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr));
|
||||
}
|
||||
|
||||
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
|
||||
retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_DeleteTensorHandle(retval_task0);
|
||||
float product[4] = {0};
|
||||
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
|
||||
@ -270,12 +292,18 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) {
|
||||
TFE_DeleteTensorHandle(h1_task0);
|
||||
TFE_DeleteTensorHandle(h1_task2);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
for (auto* h : handles_task0) {
|
||||
TFE_DeleteTensorHandle(h);
|
||||
}
|
||||
for (auto* h : handles_task2) {
|
||||
TFE_DeleteTensorHandle(h);
|
||||
}
|
||||
|
||||
TFE_DeleteOp(matmul);
|
||||
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
if (func) {
|
||||
TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
|
||||
@ -290,22 +318,37 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) {
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteSilentCopies) {
|
||||
TestRemoteExecuteSilentCopies(false, true, false);
|
||||
TestRemoteExecuteSilentCopies(/*async=*/false, /*remote=*/true,
|
||||
/*func=*/false,
|
||||
/*heavy_load_on_streaming_rpc=*/false);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
|
||||
TestRemoteExecuteSilentCopies(true, true, false);
|
||||
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/true, /*func=*/false,
|
||||
/*heavy_load_on_streaming_rpc=*/false);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) {
|
||||
TestRemoteExecuteSilentCopies(true, true, true);
|
||||
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/true, /*func=*/true,
|
||||
/*heavy_load_on_streaming_rpc=*/false);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesLocal) {
|
||||
TestRemoteExecuteSilentCopies(false, false, false);
|
||||
TestRemoteExecuteSilentCopies(/*async=*/false, /*remote=*/false,
|
||||
/*func=*/false,
|
||||
/*heavy_load_on_streaming_rpc=*/false);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) {
|
||||
TestRemoteExecuteSilentCopies(true, false, false);
|
||||
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false,
|
||||
/*func=*/false,
|
||||
/*heavy_load_on_streaming_rpc=*/false);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) {
|
||||
TestRemoteExecuteSilentCopies(true, false, true);
|
||||
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, /*func=*/true,
|
||||
/*heavy_load_on_streaming_rpc=*/false);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) {
|
||||
// A remote input may be not ready when we start running a function. Test that
|
||||
// the function execution should wait until the remote input is ready.
|
||||
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, /*func=*/true,
|
||||
/*heavy_load_on_streaming_rpc=*/true);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
|
||||
@ -378,150 +421,4 @@ TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPC) {
|
||||
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPCAsync) {
|
||||
TestRemoteExecuteDeleteContextWithOutstandingRPC(true);
|
||||
}
|
||||
|
||||
void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
|
||||
const std::vector<float>& expected_values) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(handle, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
std::unique_ptr<float[]> actual_values(new float[expected_values.size()]);
|
||||
EXPECT_EQ(sizeof(float) * expected_values.size(), TF_TensorByteSize(t));
|
||||
memcpy(actual_values.get(), TF_TensorData(t), TF_TensorByteSize(t));
|
||||
TF_DeleteTensor(t);
|
||||
|
||||
for (int i = 0; i < expected_values.size(); i++) {
|
||||
EXPECT_EQ(expected_values[i], actual_values[i])
|
||||
<< "Mismatch in expected values at (zero-based) index " << i;
|
||||
}
|
||||
}
|
||||
|
||||
void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
|
||||
const char* remote_device_name,
|
||||
const char* local_device_name) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
|
||||
|
||||
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h0_task0);
|
||||
TFE_OpSetDevice(matmul, remote_device_name, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* retvals[1];
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
auto* retval_task0 =
|
||||
TFE_TensorHandleCopyToDevice(retvals[0], ctx, local_device_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
CheckTFE_TensorHandleHasFloats(retval_task0, {7, 10, 15, 22});
|
||||
|
||||
TFE_DeleteTensorHandle(retval_task0);
|
||||
TFE_DeleteTensorHandle(h0_task0);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
|
||||
TFE_DeleteOp(matmul);
|
||||
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteChangeServerDef(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
const char remote_device_name[] =
|
||||
"/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
const char local_device_name[] =
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
|
||||
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
|
||||
// Update the server def with a new set of names (worker instead of
|
||||
// localhost).
|
||||
tensorflow::ServerDef updated_server_def = GetServerDef("worker", 2);
|
||||
serialized = updated_server_def.SerializeAsString();
|
||||
|
||||
updated_server_def.set_task_index(1);
|
||||
tensorflow::Status s = tensorflow::GrpcServer::Create(
|
||||
updated_server_def, tensorflow::Env::Default(), &worker_server);
|
||||
ASSERT_TRUE(s.ok()) << s.error_message();
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// Create a new tensor_handle.
|
||||
TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle(ctx);
|
||||
|
||||
// Check that copying it to the old remote device (named localhost) fails.
|
||||
TFE_TensorHandleCopyToDevice(h0_task0_new, ctx, remote_device_name, status);
|
||||
EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// Copying and executing on the new remote device works.
|
||||
const char new_remote_device_name[] =
|
||||
"/job:worker/replica:0/task:1/device:CPU:0";
|
||||
const char new_local_device_name[] =
|
||||
"/job:worker/replica:0/task:0/device:CPU:0";
|
||||
|
||||
auto* h0_task1_new = TFE_TensorHandleCopyToDevice(
|
||||
h0_task0_new, ctx, new_remote_device_name, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_DeleteTensorHandle(h0_task0_new);
|
||||
TFE_DeleteTensorHandle(h0_task1_new);
|
||||
|
||||
CheckRemoteMatMulExecutesOK(ctx, new_remote_device_name,
|
||||
new_local_device_name);
|
||||
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteChangeServerDef) {
|
||||
TestRemoteExecuteChangeServerDef(false);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteChangeServerDefAsync) {
|
||||
TestRemoteExecuteChangeServerDef(true);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -27,6 +27,8 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
@ -416,8 +418,10 @@ void TensorHandleSilentCopy(bool async,
|
||||
hcpu, ctx, gpu_device_name.c_str(), status.get());
|
||||
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
|
||||
|
||||
auto cpu_arg = tensorflow::TensorHandleFromInterface(hcpu->handle);
|
||||
auto gpu_arg = tensorflow::TensorHandleFromInterface(hgpu->handle);
|
||||
auto cpu_arg =
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(hcpu));
|
||||
auto gpu_arg =
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(hgpu));
|
||||
auto gpu_device = absl::get<tensorflow::Device*>(gpu_arg->device());
|
||||
ASSERT_FALSE(cpu_arg->HasLocalMirror(gpu_device));
|
||||
|
||||
@ -1346,7 +1350,7 @@ TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
|
||||
tensorflow::AttrValueMap ExtractAttrs(TFE_Op* op) {
|
||||
tensorflow::AttrValueMap attr_values;
|
||||
tensorflow::EagerOperation* operation =
|
||||
tensorflow::OperationFromInterface(op->operation);
|
||||
tensorflow::OperationFromInterface(tensorflow::unwrap(op));
|
||||
operation->Attrs().FillAttrValueMap(&attr_values);
|
||||
return attr_values;
|
||||
}
|
||||
@ -1482,10 +1486,10 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) {
|
||||
TFE_TensorHandle* inputs[] = {input1, input2};
|
||||
TFE_OpAddInput(concatOp, dim, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
CHECK(concatOp->operation->OpDef());
|
||||
CHECK(tensorflow::unwrap(concatOp)->OpDef());
|
||||
TFE_OpAddInput(concatOp, inputs[0], status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_FALSE(concatOp->operation->OpDef())
|
||||
EXPECT_FALSE(tensorflow::unwrap(concatOp)->OpDef())
|
||||
<< "Inference context is still present";
|
||||
TFE_OpAddInput(concatOp, inputs[1], status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
@ -1577,7 +1581,7 @@ TEST(CAPI, TestTFE_OpGetInputAndOutputLengthsFailForUnknownArguments) {
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
TEST(CAPI, TestTFE_OpGetAttrs) {
|
||||
TEST(CAPI, TestTFE_OpAddAttrs) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
@ -1587,12 +1591,11 @@ TEST(CAPI, TestTFE_OpGetAttrs) {
|
||||
TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
|
||||
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
|
||||
TFE_OpAttrs attributes;
|
||||
TFE_OpGetAttrs(var_op, &attributes);
|
||||
const TFE_OpAttrs* attributes = TFE_OpGetAttrs(var_op);
|
||||
|
||||
TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||
TFE_OpSetAttrType(copy_op, "dtype", TF_FLOAT);
|
||||
TFE_OpAddAttrs(copy_op, &attributes);
|
||||
TFE_OpAddAttrs(copy_op, attributes);
|
||||
unsigned char is_list = 0;
|
||||
ASSERT_EQ(TF_ATTR_TYPE,
|
||||
TFE_OpGetAttrType(copy_op, "dtype", &is_list, status));
|
||||
@ -1603,7 +1606,7 @@ TEST(CAPI, TestTFE_OpGetAttrs) {
|
||||
|
||||
tensorflow::AttrValueMap attr_values;
|
||||
tensorflow::EagerOperation* op =
|
||||
tensorflow::OperationFromInterface(copy_op->operation);
|
||||
tensorflow::OperationFromInterface(tensorflow::unwrap(copy_op));
|
||||
op->Attrs().FillAttrValueMap(&attr_values);
|
||||
EXPECT_EQ(tensorflow::DT_FLOAT, attr_values.find("dtype")->second.type());
|
||||
|
||||
@ -1624,11 +1627,10 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
|
||||
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
|
||||
TFE_OpAttrs attributes;
|
||||
TFE_OpGetAttrs(var_op, &attributes);
|
||||
const TFE_OpAttrs* attributes = TFE_OpGetAttrs(var_op);
|
||||
|
||||
TF_Buffer* serialized_attr_values = TF_NewBuffer();
|
||||
TFE_OpAttrsSerialize(&attributes, serialized_attr_values, status);
|
||||
TFE_OpAttrsSerialize(attributes, serialized_attr_values, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
tensorflow::NameAttrList name_and_attrs;
|
||||
ASSERT_TRUE(name_and_attrs.ParseFromArray(serialized_attr_values->data,
|
||||
@ -1651,7 +1653,7 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) {
|
||||
|
||||
tensorflow::AttrValueMap attr_values;
|
||||
tensorflow::EagerOperation* op =
|
||||
tensorflow::OperationFromInterface(var_op_2->operation);
|
||||
tensorflow::OperationFromInterface(tensorflow::unwrap(var_op_2));
|
||||
op->Attrs().FillAttrValueMap(&attr_values);
|
||||
EXPECT_EQ(tensorflow::DT_INT64, attr_values.find("dtype")->second.type());
|
||||
|
||||
|
@ -15,486 +15,78 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
|
||||
#include "absl/types/variant.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||
#include "tensorflow/core/lib/monitoring/gauge.h"
|
||||
#include "tensorflow/core/lib/monitoring/sampler.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
using tensorflow::string;
|
||||
using tensorflow::internal::OutputList;
|
||||
using tensorflow::internal::unwrap;
|
||||
|
||||
// =============================================================================
|
||||
// Unified Execution APIs for Eager and tracing backends.
|
||||
// Public C API entry points
|
||||
//
|
||||
// These are only the generic entry points for the C API. This file does not
|
||||
// have any visibility into the graph/eager implementation and is only providing
|
||||
// C bindings to the abstract classes defined in the
|
||||
// c_api_unified_experimental_internal.h header.
|
||||
//
|
||||
// =============================================================================
|
||||
|
||||
typedef void (*ExecuteOperation)(TF_AbstractOp* op, int num_inputs,
|
||||
TF_AbstractTensor* const* inputs,
|
||||
TF_OutputList* o, TF_ExecutionContext* ctx,
|
||||
TF_Status* s);
|
||||
|
||||
struct TF_ExecutionContext {
|
||||
// Needed to implement our own version of RTTI since dynamic_cast is not
|
||||
// supported in mobile builds.
|
||||
enum ExecutionContextKind { GraphContext, EagerContext };
|
||||
explicit TF_ExecutionContext(ExecutionContextKind kind) : k(kind) {}
|
||||
ExecutionContextKind getKind() const { return k; }
|
||||
|
||||
virtual void ExecuteOperation(TF_AbstractOp* op, int num_inputs,
|
||||
TF_AbstractTensor* const* inputs,
|
||||
TF_OutputList* o, TF_Status* s) = 0;
|
||||
virtual TF_AbstractOp* CreateOperation() = 0;
|
||||
virtual void RegisterFunction(TF_AbstractFunction* func, TF_Status* s) = 0;
|
||||
virtual ~TF_ExecutionContext() {}
|
||||
|
||||
private:
|
||||
const ExecutionContextKind k;
|
||||
};
|
||||
|
||||
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete c; }
|
||||
|
||||
template <typename T, typename S>
|
||||
T* dynamic_cast_helper(S source) {
|
||||
if (source->getKind() != T::kKind) {
|
||||
return nullptr;
|
||||
}
|
||||
return tensorflow::down_cast<T*>(source);
|
||||
}
|
||||
|
||||
class TF_GraphContext;
|
||||
class TF_EagerContext;
|
||||
|
||||
struct TF_GraphTensor {
|
||||
TF_Output output;
|
||||
TF_GraphContext* ctx;
|
||||
};
|
||||
|
||||
struct TF_AbstractTensor {
|
||||
absl::variant<TFE_TensorHandle*, TF_GraphTensor*> t;
|
||||
|
||||
~TF_AbstractTensor() {
|
||||
if (absl::holds_alternative<TFE_TensorHandle*>(t)) {
|
||||
TFE_DeleteTensorHandle(absl::get<TFE_TensorHandle*>(t));
|
||||
} else if (absl::holds_alternative<TF_GraphTensor*>(t)) {
|
||||
delete absl::get<TF_GraphTensor*>(t);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct TF_AbstractOp {
|
||||
// Needed to implement our own version of RTTI since dynamic_cast is not
|
||||
// supported in mobile builds.
|
||||
enum AbstractOpKind { GraphOp, EagerOp };
|
||||
explicit TF_AbstractOp(AbstractOpKind kind) : k(kind) {}
|
||||
AbstractOpKind getKind() const { return k; }
|
||||
virtual void SetOpType(const char* const op_type, TF_Status* s) = 0;
|
||||
virtual void SetOpName(const char* const op_name, TF_Status* s) = 0;
|
||||
virtual void SetAttrType(const char* const attr_name, TF_DataType value,
|
||||
TF_Status* s) = 0;
|
||||
virtual ~TF_AbstractOp() {}
|
||||
|
||||
private:
|
||||
const AbstractOpKind k;
|
||||
};
|
||||
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete unwrap(c); }
|
||||
|
||||
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
|
||||
return c->CreateOperation();
|
||||
return wrap(unwrap(c)->CreateOperation());
|
||||
}
|
||||
|
||||
void TF_DeleteAbstractOp(TF_AbstractOp* op) { delete op; }
|
||||
void TF_DeleteAbstractOp(TF_AbstractOp* op) { delete unwrap(op); }
|
||||
|
||||
class TF_GraphOp : public TF_AbstractOp {
|
||||
public:
|
||||
explicit TF_GraphOp(TF_Graph* g) : TF_AbstractOp(kKind), g_(g) {}
|
||||
void SetOpType(const char* const op_type, TF_Status* s) override {
|
||||
if (op_) {
|
||||
TF_SetStatus(
|
||||
s, TF_FAILED_PRECONDITION,
|
||||
absl::StrCat("SetOpType called on already built op.").c_str());
|
||||
return;
|
||||
}
|
||||
if (op_name_ != nullptr) {
|
||||
op_.reset(TF_NewOperation(g_, op_type, op_name_));
|
||||
op_name_ = nullptr;
|
||||
} else {
|
||||
op_type_ = op_type;
|
||||
}
|
||||
}
|
||||
void SetOpName(const char* const op_name, TF_Status* s) override {
|
||||
if (op_) {
|
||||
TF_SetStatus(
|
||||
s, TF_FAILED_PRECONDITION,
|
||||
absl::StrCat("SetOpName called on already built op.").c_str());
|
||||
return;
|
||||
}
|
||||
if (op_type_ != nullptr) {
|
||||
op_.reset(TF_NewOperation(g_, op_type_, op_name));
|
||||
op_type_ = nullptr;
|
||||
} else {
|
||||
op_name_ = op_name;
|
||||
}
|
||||
}
|
||||
void SetAttrType(const char* const attr_name, TF_DataType value,
|
||||
TF_Status* s) override {
|
||||
if (!op_) {
|
||||
TF_SetStatus(
|
||||
s, TF_FAILED_PRECONDITION,
|
||||
"op_type and op_name must be specified before specifying attrs.");
|
||||
return;
|
||||
}
|
||||
TF_SetAttrType(op_.get(), attr_name, value);
|
||||
}
|
||||
~TF_GraphOp() override {}
|
||||
void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { delete unwrap(t); }
|
||||
|
||||
static constexpr AbstractOpKind kKind = GraphOp;
|
||||
|
||||
private:
|
||||
friend class TF_GraphContext; // For access to op_.
|
||||
TF_Graph* g_;
|
||||
std::unique_ptr<TF_OperationDescription> op_;
|
||||
// Hold `op_type` and `op_name` till both are available since we need both
|
||||
// to build a graph operation.
|
||||
const char* op_type_ = nullptr;
|
||||
const char* op_name_ = nullptr;
|
||||
};
|
||||
|
||||
class TF_EagerOp : public TF_AbstractOp {
|
||||
public:
|
||||
explicit TF_EagerOp(TFE_Context* ctx) : TF_AbstractOp(kKind), ctx_(ctx) {}
|
||||
void SetOpType(const char* const op_type, TF_Status* s) override {
|
||||
op_ = TFE_NewOp(ctx_, op_type, s);
|
||||
}
|
||||
void SetOpName(const char* const op_name, TF_Status* s) override {
|
||||
// Name is ignored in eager mode.
|
||||
}
|
||||
void SetAttrType(const char* const attr_name, TF_DataType value,
|
||||
TF_Status* s) override {
|
||||
if (op_ == nullptr) {
|
||||
TF_SetStatus(s, TF_FAILED_PRECONDITION,
|
||||
"op_type must be specified before specifying attrs.");
|
||||
return;
|
||||
}
|
||||
TFE_OpSetAttrType(op_, attr_name, value);
|
||||
}
|
||||
|
||||
~TF_EagerOp() override { TFE_DeleteOp(op_); }
|
||||
static constexpr AbstractOpKind kKind = EagerOp;
|
||||
|
||||
private:
|
||||
friend class TF_EagerContext; // For access to op_.
|
||||
TFE_Op* op_ = nullptr;
|
||||
TFE_Context* ctx_;
|
||||
};
|
||||
|
||||
bool IsEagerTensor(const TF_AbstractTensor* const t) {
|
||||
return absl::holds_alternative<TFE_TensorHandle*>(t->t);
|
||||
}
|
||||
|
||||
struct TF_OutputList {
|
||||
std::vector<TF_AbstractTensor*> outputs;
|
||||
int expected_num_outputs = -1;
|
||||
};
|
||||
|
||||
struct TF_AbstractFunction {
|
||||
TF_Function* func = nullptr;
|
||||
|
||||
~TF_AbstractFunction() { TF_DeleteFunction(func); }
|
||||
};
|
||||
|
||||
class TF_EagerContext : public TF_ExecutionContext {
|
||||
public:
|
||||
TF_EagerContext() : TF_ExecutionContext(kKind) {}
|
||||
|
||||
void Build(TFE_ContextOptions* options, TF_Status* status) {
|
||||
eager_ctx_ = TFE_NewContext(options, status);
|
||||
}
|
||||
|
||||
TF_AbstractOp* CreateOperation() override {
|
||||
// TODO(srbs): Should the lifetime of this op be tied to the context.
|
||||
return new TF_EagerOp(eager_ctx_);
|
||||
}
|
||||
|
||||
void ExecuteOperation(TF_AbstractOp* op, int num_inputs,
|
||||
TF_AbstractTensor* const* inputs, TF_OutputList* o,
|
||||
TF_Status* s) override {
|
||||
auto* eager_op = dynamic_cast_helper<TF_EagerOp>(op);
|
||||
if (eager_op == nullptr) {
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
||||
"Unable to cast TF_AbstractOp to TF_EagerOp.");
|
||||
return;
|
||||
}
|
||||
auto* tfe_op = eager_op->op_;
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
if (!IsEagerTensor(inputs[i])) {
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor.");
|
||||
return;
|
||||
}
|
||||
TFE_OpAddInput(tfe_op, absl::get<TFE_TensorHandle*>(inputs[i]->t), s);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
}
|
||||
if (o->expected_num_outputs == -1) {
|
||||
string msg =
|
||||
"The number of outputs must be provided in eager mode. Use "
|
||||
"TF_OutputListSetNumOutputs.";
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
|
||||
return;
|
||||
}
|
||||
tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals;
|
||||
int num_retvals = o->expected_num_outputs;
|
||||
retvals.resize(num_retvals);
|
||||
TFE_Execute(tfe_op, retvals.data(), &num_retvals, s);
|
||||
if (TF_GetCode(s) != TF_OK) {
|
||||
return;
|
||||
}
|
||||
o->outputs.clear();
|
||||
o->outputs.reserve(num_retvals);
|
||||
for (int i = 0; i < num_retvals; ++i) {
|
||||
auto* t = new TF_AbstractTensor();
|
||||
t->t = retvals[i];
|
||||
o->outputs.push_back(t);
|
||||
}
|
||||
}
|
||||
|
||||
void RegisterFunction(TF_AbstractFunction* func, TF_Status* s) override {
|
||||
TFE_ContextAddFunction(eager_ctx_, func->func, s);
|
||||
}
|
||||
|
||||
~TF_EagerContext() override { TFE_DeleteContext(eager_ctx_); }
|
||||
|
||||
static constexpr ExecutionContextKind kKind = EagerContext;
|
||||
|
||||
private:
|
||||
friend TFE_Context* TF_ExecutionContextGetTFEContext(
|
||||
TF_ExecutionContext* ctx);
|
||||
TFE_Context* eager_ctx_;
|
||||
};
|
||||
|
||||
void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { delete t; }
|
||||
|
||||
TF_GraphContext* GetGraphContext(TF_AbstractTensor const* t) {
|
||||
return absl::get<TF_GraphTensor*>(t->t)->ctx;
|
||||
}
|
||||
|
||||
class TF_GraphContext : public TF_ExecutionContext {
|
||||
public:
|
||||
TF_GraphContext()
|
||||
: TF_ExecutionContext(kKind), graph_(new TF_Graph(), TF_DeleteGraph) {}
|
||||
|
||||
TF_AbstractOp* CreateOperation() override {
|
||||
// TODO(srbs): Should the lifetime of this op be tied to the context.
|
||||
return new TF_GraphOp(graph_.get());
|
||||
}
|
||||
|
||||
void ExecuteOperation(TF_AbstractOp* op, int num_inputs,
|
||||
TF_AbstractTensor* const* inputs, TF_OutputList* o,
|
||||
TF_Status* s) override {
|
||||
auto* graph_op = dynamic_cast_helper<TF_GraphOp>(op);
|
||||
if (graph_op == nullptr) {
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
||||
"Unable to cast TF_AbstractOp to TF_GraphOp.");
|
||||
return;
|
||||
}
|
||||
auto* tf_opdesc = graph_op->op_.release();
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
auto* input = inputs[i];
|
||||
if (IsEagerTensor(input)) {
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
||||
"Capturing eager tensors is not supported yet.");
|
||||
return;
|
||||
} else {
|
||||
if (GetGraphContext(input) != this) {
|
||||
TF_SetStatus(
|
||||
s, TF_INVALID_ARGUMENT,
|
||||
"Capturing tensors from other graphs is not supported yet.");
|
||||
return;
|
||||
}
|
||||
TF_AddInput(tf_opdesc, absl::get<TF_GraphTensor*>(input->t)->output);
|
||||
}
|
||||
}
|
||||
auto* operation = TF_FinishOperation(tf_opdesc, s);
|
||||
// TF_FinishOperation deletes `tf_opdesc` so clear its reference.
|
||||
graph_op->op_ = nullptr;
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
int num_outputs = TF_OperationNumOutputs(operation);
|
||||
o->outputs.clear();
|
||||
o->outputs.reserve(num_outputs);
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
auto* t = new TF_AbstractTensor;
|
||||
TF_GraphTensor* graph_t = new TF_GraphTensor;
|
||||
graph_t->ctx = this;
|
||||
graph_t->output = {operation, i};
|
||||
t->t = graph_t;
|
||||
o->outputs.push_back(t);
|
||||
}
|
||||
}
|
||||
|
||||
TF_Function* ToFunction(const char* fn_name, int num_inputs,
|
||||
const TF_AbstractTensor* inputs, int num_outputs,
|
||||
const TF_AbstractTensor* outputs,
|
||||
TF_Status* status) const {
|
||||
std::vector<TF_Output> graph_inputs;
|
||||
graph_inputs.resize(num_inputs);
|
||||
std::vector<TF_Output> graph_outputs;
|
||||
graph_outputs.resize(num_outputs);
|
||||
for (int i = 0; i < num_inputs; i++) {
|
||||
graph_inputs[i] = absl::get<TF_GraphTensor*>(inputs[i].t)->output;
|
||||
}
|
||||
for (int i = 0; i < num_outputs; i++) {
|
||||
graph_outputs[i] = absl::get<TF_GraphTensor*>(outputs[i].t)->output;
|
||||
}
|
||||
|
||||
return TF_GraphToFunction(graph_.get(), fn_name, 0, -1, nullptr,
|
||||
graph_inputs.size(), graph_inputs.data(),
|
||||
graph_outputs.size(), graph_outputs.data(),
|
||||
nullptr, nullptr, fn_name, status);
|
||||
}
|
||||
|
||||
void RegisterFunction(TF_AbstractFunction* func, TF_Status* s) override {
|
||||
TF_SetStatus(s, TF_UNIMPLEMENTED,
|
||||
"Registering graph functions has not been implemented yet.");
|
||||
}
|
||||
|
||||
~TF_GraphContext() override {}
|
||||
|
||||
static constexpr ExecutionContextKind kKind = GraphContext;
|
||||
|
||||
private:
|
||||
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
|
||||
};
|
||||
|
||||
struct TF_GraphContextOptions {};
|
||||
struct TF_EagerContextOptions {
|
||||
explicit TF_EagerContextOptions(TFE_ContextOptions* options)
|
||||
: options(options) {}
|
||||
TFE_ContextOptions* options; // Not owned.
|
||||
};
|
||||
|
||||
struct TF_ExecutionContextOptions {
|
||||
absl::variant<TF_GraphContextOptions*, TF_EagerContextOptions*> options;
|
||||
~TF_ExecutionContextOptions() {
|
||||
if (absl::holds_alternative<TF_GraphContextOptions*>(options)) {
|
||||
delete absl::get<TF_GraphContextOptions*>(options);
|
||||
} else if (absl::holds_alternative<TF_EagerContextOptions*>(options)) {
|
||||
delete absl::get<TF_EagerContextOptions*>(options);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
TF_ExecutionContextOptions* TF_NewGraphContextOptions() {
|
||||
auto* options = new TF_ExecutionContextOptions();
|
||||
options->options = new TF_GraphContextOptions();
|
||||
return options;
|
||||
}
|
||||
|
||||
void TF_DeleteExecutionContextOptions(TF_ExecutionContextOptions* options) {
|
||||
delete options;
|
||||
}
|
||||
|
||||
TF_ExecutionContextOptions* TF_NewEagerContextOptions(
|
||||
TFE_ContextOptions* tfe_options) {
|
||||
auto* options = new TF_ExecutionContextOptions();
|
||||
options->options = new TF_EagerContextOptions(tfe_options);
|
||||
return options;
|
||||
}
|
||||
|
||||
TF_ExecutionContext* TF_NewExecutionContext(TF_ExecutionContextOptions* options,
|
||||
TF_Status* s) {
|
||||
if (absl::holds_alternative<TF_EagerContextOptions*>(options->options)) {
|
||||
auto* ctx = new TF_EagerContext();
|
||||
ctx->Build(absl::get<TF_EagerContextOptions*>(options->options)->options,
|
||||
s);
|
||||
return ctx;
|
||||
} else {
|
||||
return new TF_GraphContext();
|
||||
}
|
||||
}
|
||||
|
||||
TF_OutputList* TF_NewOutputList() { return new TF_OutputList; }
|
||||
void TF_DeleteOutputList(TF_OutputList* o) { delete o; }
|
||||
TF_OutputList* TF_NewOutputList() { return wrap(new OutputList); }
|
||||
void TF_DeleteOutputList(TF_OutputList* o) { delete unwrap(o); }
|
||||
void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs,
|
||||
TF_Status* s) {
|
||||
o->expected_num_outputs = num_outputs;
|
||||
unwrap(o)->expected_num_outputs = num_outputs;
|
||||
}
|
||||
int TF_OutputListNumOutputs(TF_OutputList* o) {
|
||||
return unwrap(o)->outputs.size();
|
||||
}
|
||||
int TF_OutputListNumOutputs(TF_OutputList* o) { return o->outputs.size(); }
|
||||
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) {
|
||||
return o->outputs[i];
|
||||
return wrap(unwrap(o)->outputs[i]);
|
||||
}
|
||||
|
||||
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
|
||||
TF_Status* s) {
|
||||
op->SetOpType(op_type, s);
|
||||
unwrap(op)->SetOpType(op_type, s);
|
||||
}
|
||||
|
||||
void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
|
||||
TF_Status* s) {
|
||||
op->SetOpName(op_name, s);
|
||||
unwrap(op)->SetOpName(op_name, s);
|
||||
}
|
||||
|
||||
void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
|
||||
TF_DataType value, TF_Status* s) {
|
||||
op->SetAttrType(attr_name, value, s);
|
||||
unwrap(op)->SetAttrType(attr_name, value, s);
|
||||
}
|
||||
|
||||
void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
|
||||
TF_AbstractTensor* const* inputs, TF_OutputList* o,
|
||||
TF_ExecutionContext* ctx, TF_Status* s) {
|
||||
ctx->ExecuteOperation(op, num_inputs, inputs, o, s);
|
||||
unwrap(ctx)->ExecuteOperation(unwrap(op), num_inputs, &unwrap(*inputs),
|
||||
unwrap(o), s);
|
||||
}
|
||||
|
||||
TF_AbstractFunction* TF_ExecutionContextToFunction(
|
||||
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
|
||||
const TF_AbstractTensor* inputs, int num_outputs,
|
||||
const TF_AbstractTensor* outputs, TF_Status* status) {
|
||||
auto* graph_ctx = dynamic_cast_helper<const TF_GraphContext>(fn_body);
|
||||
if (graph_ctx == nullptr) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
"fn_body is not a TF_GraphContext.");
|
||||
return nullptr;
|
||||
}
|
||||
TF_AbstractFunction* func = new TF_AbstractFunction;
|
||||
func->func = graph_ctx->ToFunction(fn_name, num_inputs, inputs, num_outputs,
|
||||
outputs, status);
|
||||
return func;
|
||||
void TF_DeleteAbstractFunction(TF_AbstractFunction* func) {
|
||||
delete unwrap(func);
|
||||
}
|
||||
|
||||
void TF_DeleteAbstractFunction(TF_AbstractFunction* func) { delete func; }
|
||||
|
||||
void TF_ExecutionContextRegisterFunction(TF_ExecutionContext* ctx,
|
||||
TF_AbstractFunction* func,
|
||||
TF_Status* s) {
|
||||
ctx->RegisterFunction(func, s);
|
||||
}
|
||||
|
||||
// Temporary APIs till we figure out how to create scalar valued Eager
|
||||
// tensors and how to get value out of eager abstract tensors.
|
||||
TF_AbstractTensor* TF_NewAbstractTensor() {
|
||||
TF_AbstractTensor* t = new TF_AbstractTensor;
|
||||
return t;
|
||||
}
|
||||
|
||||
void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t,
|
||||
TF_Status* s) {
|
||||
at->t = t;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
|
||||
TF_Status* s) {
|
||||
if (!absl::holds_alternative<TFE_TensorHandle*>(at->t)) {
|
||||
string msg = absl::StrCat("Not an eager tensor handle.",
|
||||
reinterpret_cast<uintptr_t>(at));
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
return absl::get<TFE_TensorHandle*>(at->t);
|
||||
}
|
||||
|
||||
TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext* ctx) {
|
||||
return dynamic_cast_helper<TF_EagerContext>(ctx)->eager_ctx_;
|
||||
unwrap(ctx)->RegisterFunction(unwrap(func), s);
|
||||
}
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
@ -34,26 +35,34 @@ extern "C" {
|
||||
// E.g. it could know whether we're in eager mode or in graph mode, keeps track
|
||||
// of gradient tapes, etc.
|
||||
typedef struct TF_ExecutionContext TF_ExecutionContext;
|
||||
|
||||
// A TF_AbstractTensor is an input to an operation. E.g. it could be a union
|
||||
// type of eager and graph tensors.
|
||||
// type of eager and graph tensors. It is also the result of executing an
|
||||
// operation.
|
||||
typedef struct TF_AbstractTensor TF_AbstractTensor;
|
||||
|
||||
// A TF_AbstractOp is the metadata we need to execute an operation. E.g. this
|
||||
// could contain the op type and other attributes.
|
||||
typedef struct TF_AbstractOp TF_AbstractOp;
|
||||
|
||||
// `TF_ExecutionContextOptions` define what type of `TF_ExecutionContext` is
|
||||
// created. It can be used to pass context specific params.
|
||||
typedef struct TF_ExecutionContextOptions TF_ExecutionContextOptions;
|
||||
void TF_DeleteExecutionContextOptions(TF_ExecutionContextOptions*);
|
||||
// Stores a function representation that can be used for execution or for
|
||||
// setting functional attributes of other composite ops e.g. control flow.
|
||||
typedef struct TF_AbstractFunction TF_AbstractFunction;
|
||||
|
||||
// Creates a context for tracing the execution of operations into a function.
|
||||
TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s);
|
||||
|
||||
// Creates a context for eager execution of operations.
|
||||
TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions*,
|
||||
TF_Status* s);
|
||||
|
||||
TF_ExecutionContext* TF_NewExecutionContext(TF_ExecutionContextOptions*,
|
||||
TF_Status* s);
|
||||
void TF_DeleteExecutionContext(TF_ExecutionContext*);
|
||||
|
||||
// Create an operation suitable to use with the provided context. The operation
|
||||
// requires its type (e.g. "AddV2") to be set independently.
|
||||
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* ctx);
|
||||
void TF_DeleteAbstractOp(TF_AbstractOp*);
|
||||
|
||||
void TF_DeleteAbstractTensor(TF_AbstractTensor*);
|
||||
// TODO(srbs): Add APIs for specifying attrs etc.
|
||||
// `op_type` must outlive `op`.
|
||||
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
|
||||
@ -65,9 +74,16 @@ void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
|
||||
void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
|
||||
TF_DataType value, TF_Status* s);
|
||||
|
||||
// TF_OutputList just lets us not specify the number of outputs of an operation
|
||||
void TF_DeleteAbstractTensor(TF_AbstractTensor*);
|
||||
|
||||
// TF_OutputList holds the list of TF_AbstractTensor that results from executing
|
||||
// an operation.
|
||||
// It just lets us not specify the number of outputs of an operation
|
||||
// beforehand. This forces a memory allocation in the runtime, which is bad, but
|
||||
// it allows for generic code.
|
||||
// TODO(aminim): the description above isn't clear with respect to
|
||||
// TF_OutputListNumOutputs and the current eager implementation which requires
|
||||
// the number of outputs to be set by the client.
|
||||
typedef struct TF_OutputList TF_OutputList;
|
||||
TF_OutputList* TF_NewOutputList();
|
||||
void TF_DeleteOutputList(TF_OutputList* o);
|
||||
@ -75,38 +91,38 @@ void TF_OutputListSetNumOutputs(TF_OutputList* o, int, TF_Status*);
|
||||
int TF_OutputListNumOutputs(TF_OutputList* o);
|
||||
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i);
|
||||
|
||||
// Stores a function representation that can be used for execution or for
|
||||
// setting functional attributes of other composite ops e.g. control flow.
|
||||
typedef struct TF_AbstractFunction TF_AbstractFunction;
|
||||
TF_AbstractFunction* TF_ExecutionContextToFunction(
|
||||
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
|
||||
const TF_AbstractTensor* inputs, int num_outputs,
|
||||
const TF_AbstractTensor* outputs, TF_Status* status);
|
||||
void TF_DeleteAbstractFunction(TF_AbstractFunction*);
|
||||
void TF_ExecutionContextRegisterFunction(TF_ExecutionContext*,
|
||||
TF_AbstractFunction*, TF_Status*);
|
||||
|
||||
// TF_ExecuteOperation will, if in eager mode, execute, if in graph mode, maybe
|
||||
// capture some inputs and then add a node in the graph, and after
|
||||
// execution/node creation it'll go and record things that happened in any tape
|
||||
// which happens to be active.
|
||||
// capture some inputs and then add a node in the graph. The output tensors are
|
||||
// returned through the provided TF_OutputList.
|
||||
// Any active tape will observe the effects of this execution.
|
||||
void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
|
||||
TF_AbstractTensor* const* inputs, TF_OutputList* o,
|
||||
TF_ExecutionContext* ctx, TF_Status* s);
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// APIs specific to Eager and graph modes
|
||||
// -----------------------------------------------------------------------------
|
||||
// Creates a new TF_AbstractFunction from the current tracing states in the
|
||||
// context. The returned TF_GraphToFunction must be deleted by the client.
|
||||
// TODO(aminim): clarify the contract on the state of the context after this
|
||||
// call.
|
||||
TF_AbstractFunction* TF_ExecutionContextToFunction(
|
||||
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
|
||||
const TF_AbstractTensor* inputs, int num_outputs,
|
||||
const TF_AbstractTensor* outputs, TF_Status* status);
|
||||
|
||||
TF_ExecutionContextOptions* TF_NewGraphContextOptions();
|
||||
TF_ExecutionContextOptions* TF_NewEagerContextOptions(TFE_ContextOptions*);
|
||||
void TF_DeleteAbstractFunction(TF_AbstractFunction*);
|
||||
|
||||
// Register the function with the given context. This is particularly useful for
|
||||
// making a function available to an eager context.
|
||||
void TF_ExecutionContextRegisterFunction(TF_ExecutionContext*,
|
||||
TF_AbstractFunction*, TF_Status*);
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// APIs specific to Eager modes
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
// Temporary APIs till we figure out how to create scalar valued Eager
|
||||
// tensors and how to get value out of eager abstract tensors.
|
||||
TF_AbstractTensor* TF_NewAbstractTensor();
|
||||
void TF_AbstractTensorSetEagerTensor(
|
||||
TF_AbstractTensor* at, TFE_TensorHandle* t,
|
||||
TF_Status* s); // `at` takes ownership of `t`.
|
||||
TF_AbstractTensor* TF_CreateAbstractTensorFromEagerTensor(TFE_TensorHandle* t,
|
||||
TF_Status* s);
|
||||
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
|
||||
TF_Status* s);
|
||||
TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext*);
|
||||
|
183
tensorflow/c/eager/c_api_unified_experimental_eager.cc
Normal file
183
tensorflow/c/eager/c_api_unified_experimental_eager.cc
Normal file
@ -0,0 +1,183 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
using tensorflow::string;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
// Simple wrapper over a TFE_TensorHandle
|
||||
struct EagerTensor : public AbstractTensor {
|
||||
TFE_TensorHandle* t = nullptr;
|
||||
EagerTensor() : AbstractTensor(kKind) {}
|
||||
explicit EagerTensor(TFE_TensorHandle* t) : AbstractTensor(kKind), t(t) {}
|
||||
~EagerTensor() override { TFE_DeleteTensorHandle(t); }
|
||||
static constexpr AbstractTensorKind kKind = kEagerTensor;
|
||||
};
|
||||
|
||||
// Simple wrapper over a TFE_Op
|
||||
class EagerOp : public AbstractOp {
|
||||
public:
|
||||
explicit EagerOp(TFE_Context* ctx) : AbstractOp(kKind), ctx_(ctx) {}
|
||||
void SetOpType(const char* const op_type, TF_Status* s) override {
|
||||
op_ = TFE_NewOp(ctx_, op_type, s);
|
||||
}
|
||||
void SetOpName(const char* const op_name, TF_Status* s) override {
|
||||
// Name is ignored in eager mode.
|
||||
}
|
||||
void SetAttrType(const char* const attr_name, TF_DataType value,
|
||||
TF_Status* s) override {
|
||||
if (op_ == nullptr) {
|
||||
TF_SetStatus(s, TF_FAILED_PRECONDITION,
|
||||
"op_type must be specified before specifying attrs.");
|
||||
return;
|
||||
}
|
||||
TFE_OpSetAttrType(op_, attr_name, value);
|
||||
}
|
||||
|
||||
~EagerOp() override { TFE_DeleteOp(op_); }
|
||||
static constexpr AbstractOpKind kKind = kEagerOp;
|
||||
|
||||
private:
|
||||
friend class EagerContext; // For access to op_.
|
||||
TFE_Op* op_ = nullptr;
|
||||
TFE_Context* ctx_;
|
||||
};
|
||||
|
||||
// Wraps a TFE_Context and dispatch EagerOp with EagerTensor inputs.
|
||||
class EagerContext : public ExecutionContext {
|
||||
public:
|
||||
EagerContext() : ExecutionContext(kKind) {}
|
||||
|
||||
void Build(TFE_ContextOptions* options, TF_Status* status) {
|
||||
eager_ctx_ = TFE_NewContext(options, status);
|
||||
}
|
||||
|
||||
AbstractOp* CreateOperation() override {
|
||||
// TODO(srbs): Should the lifetime of this op be tied to the context.
|
||||
return new EagerOp(eager_ctx_);
|
||||
}
|
||||
|
||||
void ExecuteOperation(AbstractOp* op, int num_inputs,
|
||||
AbstractTensor* const* inputs, OutputList* o,
|
||||
TF_Status* s) override {
|
||||
auto* eager_op = dyncast<EagerOp>(op);
|
||||
if (eager_op == nullptr) {
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
||||
"Unable to cast AbstractOp to TF_EagerOp.");
|
||||
return;
|
||||
}
|
||||
auto* tfe_op = eager_op->op_;
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
auto* eager_tensor = dyncast<const EagerTensor>(inputs[i]);
|
||||
if (!eager_tensor) {
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor.");
|
||||
return;
|
||||
}
|
||||
TFE_OpAddInput(tfe_op, eager_tensor->t, s);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
}
|
||||
if (o->expected_num_outputs == -1) {
|
||||
string msg =
|
||||
"The number of outputs must be provided in eager mode. Use "
|
||||
"TF_OutputListSetNumOutputs.";
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
|
||||
return;
|
||||
}
|
||||
tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals;
|
||||
int num_retvals = o->expected_num_outputs;
|
||||
retvals.resize(num_retvals);
|
||||
TFE_Execute(tfe_op, retvals.data(), &num_retvals, s);
|
||||
if (TF_GetCode(s) != TF_OK) {
|
||||
return;
|
||||
}
|
||||
o->outputs.clear();
|
||||
o->outputs.reserve(num_retvals);
|
||||
for (int i = 0; i < num_retvals; ++i) {
|
||||
o->outputs.push_back(new EagerTensor(retvals[i]));
|
||||
}
|
||||
}
|
||||
|
||||
void RegisterFunction(AbstractFunction* afunc, TF_Status* s) override {
|
||||
auto* func = afunc->GetTfFunction(s);
|
||||
if (!func) {
|
||||
return;
|
||||
}
|
||||
TFE_ContextAddFunction(eager_ctx_, func, s);
|
||||
}
|
||||
|
||||
~EagerContext() override { TFE_DeleteContext(eager_ctx_); }
|
||||
|
||||
static constexpr ExecutionContextKind kKind = kEagerContext;
|
||||
|
||||
private:
|
||||
friend TFE_Context* ::TF_ExecutionContextGetTFEContext(
|
||||
TF_ExecutionContext* ctx);
|
||||
TFE_Context* eager_ctx_;
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
||||
|
||||
// =============================================================================
|
||||
// Public C API entry points
|
||||
// These are only the entry points specific to the Eager API.
|
||||
// =============================================================================
|
||||
|
||||
using tensorflow::internal::dyncast;
|
||||
using tensorflow::internal::unwrap;
|
||||
|
||||
TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions* options,
|
||||
TF_Status* s) {
|
||||
auto* ctx = new tensorflow::internal::EagerContext();
|
||||
ctx->Build(options, s);
|
||||
return wrap(ctx);
|
||||
}
|
||||
|
||||
TF_AbstractTensor* TF_CreateAbstractTensorFromEagerTensor(TFE_TensorHandle* t,
|
||||
TF_Status* s) {
|
||||
return wrap(new tensorflow::internal::EagerTensor(t));
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
|
||||
TF_Status* s) {
|
||||
auto* eager_tensor = dyncast<tensorflow::internal::EagerTensor>(unwrap(at));
|
||||
if (!eager_tensor) {
|
||||
string msg = tensorflow::strings::StrCat("Not an eager tensor handle.",
|
||||
reinterpret_cast<uintptr_t>(at));
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
return eager_tensor->t;
|
||||
}
|
||||
|
||||
TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext* ctx) {
|
||||
auto* eager_ctx = dyncast<tensorflow::internal::EagerContext>(unwrap(ctx));
|
||||
if (!eager_ctx) return nullptr;
|
||||
return eager_ctx->eager_ctx_;
|
||||
}
|
248
tensorflow/c/eager/c_api_unified_experimental_graph.cc
Normal file
248
tensorflow/c/eager/c_api_unified_experimental_graph.cc
Normal file
@ -0,0 +1,248 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
using tensorflow::string;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
class GraphContext;
|
||||
|
||||
// GraphTensor wraps a `TF_Output`, i.e. a pointer to TF_Operation and the index
|
||||
// into the list of outputs for the operation.
|
||||
struct GraphTensor : public AbstractTensor {
|
||||
TF_Output output{};
|
||||
GraphContext* ctx = nullptr;
|
||||
GraphTensor() : AbstractTensor(kKind) {}
|
||||
GraphTensor(TF_Output output, GraphContext* ctx)
|
||||
: AbstractTensor(kKind), output(output), ctx(ctx) {}
|
||||
static constexpr AbstractTensorKind kKind = kGraphTensor;
|
||||
};
|
||||
|
||||
// GraphOp wraps and populate a TF_OperationDescription.
|
||||
class GraphOp : public AbstractOp {
|
||||
public:
|
||||
explicit GraphOp(TF_Graph* g) : AbstractOp(kKind), g_(g) {}
|
||||
void SetOpType(const char* const op_type, TF_Status* s) override {
|
||||
if (op_) {
|
||||
TF_SetStatus(
|
||||
s, TF_FAILED_PRECONDITION,
|
||||
strings::StrCat("SetOpType called on already built op.").c_str());
|
||||
return;
|
||||
}
|
||||
if (op_name_ != nullptr) {
|
||||
op_.reset(TF_NewOperation(g_, op_type, op_name_));
|
||||
op_name_ = nullptr;
|
||||
} else {
|
||||
op_type_ = op_type;
|
||||
}
|
||||
}
|
||||
void SetOpName(const char* const op_name, TF_Status* s) override {
|
||||
if (op_) {
|
||||
TF_SetStatus(
|
||||
s, TF_FAILED_PRECONDITION,
|
||||
strings::StrCat("SetOpName called on already built op.").c_str());
|
||||
return;
|
||||
}
|
||||
if (op_type_ != nullptr) {
|
||||
op_.reset(TF_NewOperation(g_, op_type_, op_name));
|
||||
op_type_ = nullptr;
|
||||
} else {
|
||||
op_name_ = op_name;
|
||||
}
|
||||
}
|
||||
void SetAttrType(const char* const attr_name, TF_DataType value,
|
||||
TF_Status* s) override {
|
||||
if (!op_) {
|
||||
TF_SetStatus(
|
||||
s, TF_FAILED_PRECONDITION,
|
||||
"op_type and op_name must be specified before specifying attrs.");
|
||||
return;
|
||||
}
|
||||
TF_SetAttrType(op_.get(), attr_name, value);
|
||||
}
|
||||
~GraphOp() override {}
|
||||
|
||||
static constexpr AbstractOpKind kKind = kGraphOp;
|
||||
|
||||
private:
|
||||
friend class GraphContext; // For access to op_.
|
||||
TF_Graph* g_;
|
||||
std::unique_ptr<TF_OperationDescription> op_;
|
||||
// Hold `op_type` and `op_name` till both are available since we need both
|
||||
// to build a graph operation.
|
||||
const char* op_type_ = nullptr;
|
||||
const char* op_name_ = nullptr;
|
||||
};
|
||||
|
||||
// GraphFunction is a thin wrapper over a TF_Function.
|
||||
struct GraphFunction : public AbstractFunction {
|
||||
TF_Function* func = nullptr;
|
||||
GraphFunction() : AbstractFunction(kKind) {}
|
||||
explicit GraphFunction(TF_Function* func)
|
||||
: AbstractFunction(kKind), func(func) {}
|
||||
~GraphFunction() override {
|
||||
if (func) TF_DeleteFunction(func);
|
||||
}
|
||||
|
||||
TF_Function* GetTfFunction(TF_Status* s) override { return func; }
|
||||
|
||||
static constexpr AbstractFunctionKind kKind = kGraphFunc;
|
||||
};
|
||||
|
||||
// GraphContext wraps a TF_Graph and manages the "execution" of operation, i.e.
|
||||
// adding them to the graph.
|
||||
class GraphContext : public ExecutionContext {
|
||||
public:
|
||||
GraphContext()
|
||||
: ExecutionContext(kKind), graph_(new TF_Graph(), TF_DeleteGraph) {}
|
||||
|
||||
AbstractOp* CreateOperation() override {
|
||||
// TODO(srbs): Should the lifetime of this op be tied to the context.
|
||||
return new GraphOp(graph_.get());
|
||||
}
|
||||
|
||||
void ExecuteOperation(AbstractOp* op, int num_inputs,
|
||||
AbstractTensor* const* inputs, OutputList* o,
|
||||
TF_Status* s) override {
|
||||
auto* graph_op = dyncast<GraphOp>(op);
|
||||
if (graph_op == nullptr) {
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
||||
"Unable to cast AbstractOp to TF_GraphOp.");
|
||||
return;
|
||||
}
|
||||
auto* tf_opdesc = graph_op->op_.release();
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
auto* graph_tensor = dyncast<GraphTensor>(inputs[i]);
|
||||
if (!graph_tensor) {
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
||||
"Capturing eager tensors is not supported yet.");
|
||||
return;
|
||||
} else {
|
||||
if (graph_tensor->ctx != this) {
|
||||
TF_SetStatus(
|
||||
s, TF_INVALID_ARGUMENT,
|
||||
"Capturing tensors from other graphs is not supported yet.");
|
||||
return;
|
||||
}
|
||||
TF_AddInput(tf_opdesc, graph_tensor->output);
|
||||
}
|
||||
}
|
||||
auto* operation = TF_FinishOperation(tf_opdesc, s);
|
||||
// TF_FinishOperation deletes `tf_opdesc` so clear its reference.
|
||||
graph_op->op_ = nullptr;
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
int num_outputs = TF_OperationNumOutputs(operation);
|
||||
o->outputs.clear();
|
||||
o->outputs.reserve(num_outputs);
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
o->outputs.push_back(new GraphTensor({operation, i}, this));
|
||||
}
|
||||
}
|
||||
|
||||
TF_Function* ToFunction(const char* fn_name, int num_inputs,
|
||||
const GraphTensor* inputs, int num_outputs,
|
||||
const GraphTensor* outputs, TF_Status* status) const {
|
||||
std::vector<TF_Output> graph_inputs;
|
||||
graph_inputs.resize(num_inputs);
|
||||
std::vector<TF_Output> graph_outputs;
|
||||
graph_outputs.resize(num_outputs);
|
||||
for (int i = 0; i < num_inputs; i++) {
|
||||
graph_inputs[i] = inputs[i].output;
|
||||
}
|
||||
for (int i = 0; i < num_outputs; i++) {
|
||||
graph_outputs[i] = outputs[i].output;
|
||||
}
|
||||
|
||||
return TF_GraphToFunction(graph_.get(), fn_name, 0, -1, nullptr,
|
||||
graph_inputs.size(), graph_inputs.data(),
|
||||
graph_outputs.size(), graph_outputs.data(),
|
||||
nullptr, nullptr, fn_name, status);
|
||||
}
|
||||
|
||||
void RegisterFunction(AbstractFunction* func, TF_Status* s) override {
|
||||
TF_SetStatus(s, TF_UNIMPLEMENTED,
|
||||
"Registering graph functions has not been implemented yet.");
|
||||
}
|
||||
|
||||
~GraphContext() override {}
|
||||
|
||||
static constexpr ExecutionContextKind kKind = kGraphContext;
|
||||
|
||||
private:
|
||||
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
|
||||
};
|
||||
|
||||
// Helper that converts the graph currently held in the context into a function.
|
||||
static AbstractFunction* ExecutionContextToFunction(
|
||||
const ExecutionContext* fn_body, const char* fn_name, int num_inputs,
|
||||
const AbstractTensor* inputs, int num_outputs,
|
||||
const AbstractTensor* outputs, TF_Status* status) {
|
||||
auto* graph_ctx = dyncast<const GraphContext>(fn_body);
|
||||
if (graph_ctx == nullptr) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
"fn_body is not a TF_GraphContext.");
|
||||
return nullptr;
|
||||
}
|
||||
auto* graph_inputs = dyncast<const GraphTensor>(inputs);
|
||||
if (!graph_inputs) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT, "inputs aren't GraphTensors.");
|
||||
return nullptr;
|
||||
}
|
||||
auto* graph_outputs = dyncast<const GraphTensor>(outputs);
|
||||
if (!graph_outputs) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT, "outputs aren't GraphTensors.");
|
||||
return nullptr;
|
||||
}
|
||||
GraphFunction* func = new GraphFunction;
|
||||
func->func = graph_ctx->ToFunction(fn_name, num_inputs, graph_inputs,
|
||||
num_outputs, graph_outputs, status);
|
||||
return func;
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
||||
|
||||
// =============================================================================
|
||||
// Public C API entry points
|
||||
// These are only the entry points specific to the Graph API.
|
||||
// =============================================================================
|
||||
|
||||
using tensorflow::internal::unwrap;
|
||||
|
||||
TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s) {
|
||||
return wrap(new tensorflow::internal::GraphContext());
|
||||
}
|
||||
|
||||
TF_AbstractFunction* TF_ExecutionContextToFunction(
|
||||
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
|
||||
const TF_AbstractTensor* inputs, int num_outputs,
|
||||
const TF_AbstractTensor* outputs, TF_Status* status) {
|
||||
return wrap(ExecutionContextToFunction(unwrap(fn_body), fn_name, num_inputs,
|
||||
unwrap(inputs), num_outputs,
|
||||
unwrap(outputs), status));
|
||||
}
|
184
tensorflow/c/eager/c_api_unified_experimental_internal.h
Normal file
184
tensorflow/c/eager/c_api_unified_experimental_internal.h
Normal file
@ -0,0 +1,184 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
// =============================================================================
|
||||
// Implementation detail for the unified execution APIs for Eager and tracing
|
||||
// backends (graph/MLIR).
|
||||
//
|
||||
// This defines a set of abstract classes that are intended to provide the
|
||||
// functionality of the opaque C types exposed in the public APIs defined in the
|
||||
// `c_api_unified_experimental.h` header.
|
||||
// =============================================================================
|
||||
|
||||
// We can't depend on C++ rtti, but we still want to be able to have a safe
|
||||
// dynamic_cast to provide diagnostics to the user when the API is misused.
|
||||
// Instead we model RTTI by listing all the possible subclasses for each
|
||||
// abstract base. Each subclass initializes the base class with the right
|
||||
// `kind`, which allows an equivalent to `std::dynamic_cast` provided by this
|
||||
// utility.
|
||||
template <typename T, typename S>
|
||||
T* dyncast(S source) {
|
||||
if (source->getKind() != T::kKind) {
|
||||
return nullptr;
|
||||
}
|
||||
return tensorflow::down_cast<T*>(source);
|
||||
}
|
||||
|
||||
// Represents either an EagerTensor or a GraphTensor.
|
||||
// This base class does not expose any public methods other than to distinguish
|
||||
// which subclass it actually is. The user is responsible to use the right
|
||||
// type of AbstractTensor in their context (do not pass an EagerTensor to a
|
||||
// GraphContext and vice-versa).
|
||||
class AbstractTensor {
|
||||
protected:
|
||||
enum AbstractTensorKind { kGraphTensor, kEagerTensor, kMLIRTensor };
|
||||
explicit AbstractTensor(AbstractTensorKind kind) : kind_(kind) {}
|
||||
|
||||
public:
|
||||
// Returns which subclass is this instance of.
|
||||
AbstractTensorKind getKind() const { return kind_; }
|
||||
virtual ~AbstractTensor() = default;
|
||||
|
||||
private:
|
||||
const AbstractTensorKind kind_;
|
||||
};
|
||||
|
||||
// Represents the results of the execution of an operation.
|
||||
struct OutputList {
|
||||
std::vector<AbstractTensor*> outputs;
|
||||
int expected_num_outputs = -1;
|
||||
};
|
||||
|
||||
// Holds the result of tracing a function.
|
||||
class AbstractFunction {
|
||||
protected:
|
||||
enum AbstractFunctionKind { kGraphFunc };
|
||||
explicit AbstractFunction(AbstractFunctionKind kind) : kind_(kind) {}
|
||||
|
||||
public:
|
||||
// Returns which subclass is this instance of.
|
||||
AbstractFunctionKind getKind() const { return kind_; }
|
||||
virtual ~AbstractFunction() = default;
|
||||
|
||||
// Temporary API till we figure the right abstraction for AbstractFunction.
|
||||
// At the moment both Eager and Graph needs access to a "TF_Function" object.
|
||||
virtual TF_Function* GetTfFunction(TF_Status* s) = 0;
|
||||
|
||||
private:
|
||||
const AbstractFunctionKind kind_;
|
||||
};
|
||||
|
||||
// An abstract operation describes an operation by its type, name, and
|
||||
// attributes. It can be "executed" by the context with some input tensors.
|
||||
// It is allowed to reusing the same abstract operation for multiple execution
|
||||
// on a given context, with the same or different input tensors.
|
||||
class AbstractOp {
|
||||
protected:
|
||||
enum AbstractOpKind { kGraphOp, kEagerOp };
|
||||
explicit AbstractOp(AbstractOpKind kind) : kind_(kind) {}
|
||||
|
||||
public:
|
||||
// Returns which subclass is this instance of.
|
||||
AbstractOpKind getKind() const { return kind_; }
|
||||
virtual ~AbstractOp() = default;
|
||||
|
||||
// Sets the type of the operation (for example `AddV2`).
|
||||
virtual void SetOpType(const char* op_type, TF_Status* s) = 0;
|
||||
|
||||
// Sets the name of the operation: this is an optional identifier that is
|
||||
// not intended to carry semantics and preserved/propagated without
|
||||
// guarantees.
|
||||
virtual void SetOpName(const char* op_name, TF_Status* s) = 0;
|
||||
|
||||
// Add a `TypeAttribute` on the operation.
|
||||
virtual void SetAttrType(const char* attr_name, TF_DataType value,
|
||||
TF_Status* s) = 0;
|
||||
|
||||
private:
|
||||
const AbstractOpKind kind_;
|
||||
};
|
||||
|
||||
// This holds the context for the execution: dispatching operations either to an
|
||||
// eager implementation or to a graph implementation.
|
||||
struct ExecutionContext {
|
||||
protected:
|
||||
enum ExecutionContextKind { kGraphContext, kEagerContext };
|
||||
explicit ExecutionContext(ExecutionContextKind kind) : k(kind) {}
|
||||
|
||||
public:
|
||||
// Returns which subclass is this instance of.
|
||||
ExecutionContextKind getKind() const { return k; }
|
||||
virtual ~ExecutionContext() = default;
|
||||
|
||||
// Executes the operation on the provided inputs and populate the OutputList
|
||||
// with the results. The input tensors must match the current context.
|
||||
// The effect of "executing" an operation depends on the context: in an Eager
|
||||
// context it will dispatch it to the runtime for execution, while in a
|
||||
// tracing context it will add the operation to the current function.
|
||||
virtual void ExecuteOperation(AbstractOp* op, int num_inputs,
|
||||
AbstractTensor* const* inputs, OutputList* o,
|
||||
TF_Status* s) = 0;
|
||||
|
||||
// Creates an empty AbstractOperation suitable to use with this context.
|
||||
virtual AbstractOp* CreateOperation() = 0;
|
||||
|
||||
// Registers a functions with this context, after this the function is
|
||||
// available to be called/referenced by its name in this context.
|
||||
virtual void RegisterFunction(AbstractFunction* func, TF_Status* s) = 0;
|
||||
|
||||
private:
|
||||
const ExecutionContextKind k;
|
||||
};
|
||||
|
||||
// Create utilities to wrap/unwrap: this convert from the C opaque types to the
|
||||
// C++ implementation, and back.
|
||||
#define MAKE_WRAP_UNWRAP(C_TYPEDEF, CPP_CLASS) \
|
||||
static inline CPP_CLASS* const& unwrap(C_TYPEDEF* const& o) { \
|
||||
return reinterpret_cast<CPP_CLASS* const&>(o); \
|
||||
} \
|
||||
static inline const CPP_CLASS* const& unwrap(const C_TYPEDEF* const& o) { \
|
||||
return reinterpret_cast<const CPP_CLASS* const&>(o); \
|
||||
} \
|
||||
static inline C_TYPEDEF* const& wrap(CPP_CLASS* const& o) { \
|
||||
return reinterpret_cast<C_TYPEDEF* const&>(o); \
|
||||
} \
|
||||
static inline const C_TYPEDEF* const& wrap(const CPP_CLASS* const& o) { \
|
||||
return reinterpret_cast<const C_TYPEDEF* const&>(o); \
|
||||
}
|
||||
|
||||
MAKE_WRAP_UNWRAP(TF_ExecutionContext, ExecutionContext)
|
||||
MAKE_WRAP_UNWRAP(TF_AbstractFunction, AbstractFunction)
|
||||
MAKE_WRAP_UNWRAP(TF_AbstractTensor, AbstractTensor)
|
||||
MAKE_WRAP_UNWRAP(TF_AbstractOp, AbstractOp)
|
||||
MAKE_WRAP_UNWRAP(TF_OutputList, OutputList)
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_
|
@ -15,17 +15,14 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
|
||||
#include <string.h>
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/cc/profiler/profiler.h"
|
||||
#include "tensorflow/core/lib/monitoring/collection_registry.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/str_util.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
|
||||
using tensorflow::string;
|
||||
|
||||
@ -36,8 +33,7 @@ TEST(UnifedCAPI, TestBasicEager) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TF_ExecutionContextOptions* options = TF_NewEagerContextOptions(opts);
|
||||
TF_ExecutionContext* ctx = TF_NewExecutionContext(options, status.get());
|
||||
TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
@ -46,8 +42,8 @@ TEST(UnifedCAPI, TestBasicEager) {
|
||||
// Build an abstract input tensor.
|
||||
TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx);
|
||||
TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f);
|
||||
TF_AbstractTensor* at = TF_NewAbstractTensor();
|
||||
TF_AbstractTensorSetEagerTensor(at, t, status.get());
|
||||
TF_AbstractTensor* at =
|
||||
TF_CreateAbstractTensorFromEagerTensor(t, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build an abstract operation.
|
||||
@ -83,15 +79,12 @@ TEST(UnifedCAPI, TestBasicEager) {
|
||||
TF_DeleteAbstractTensor(result);
|
||||
TF_DeleteOutputList(o);
|
||||
TF_DeleteExecutionContext(ctx);
|
||||
TF_DeleteExecutionContextOptions(options);
|
||||
}
|
||||
|
||||
TEST(UnifedCAPI, TestBasicGraph) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContextOptions* options = TF_NewGraphContextOptions();
|
||||
TF_ExecutionContext* graph_ctx =
|
||||
TF_NewExecutionContext(options, status.get());
|
||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Add a placeholder to the graph.
|
||||
@ -143,10 +136,8 @@ TEST(UnifedCAPI, TestBasicGraph) {
|
||||
|
||||
// Build eager context.
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TF_ExecutionContextOptions* eager_ctx_options =
|
||||
TF_NewEagerContextOptions(opts);
|
||||
TF_ExecutionContext* eager_execution_ctx =
|
||||
TF_NewExecutionContext(eager_ctx_options, status.get());
|
||||
TF_NewEagerExecutionContext(opts, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
@ -158,11 +149,11 @@ TEST(UnifedCAPI, TestBasicGraph) {
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build an abstract input tensor.
|
||||
TF_AbstractTensor* input_t = TF_NewAbstractTensor();
|
||||
TFE_Context* eager_ctx =
|
||||
TF_ExecutionContextGetTFEContext(eager_execution_ctx);
|
||||
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f);
|
||||
TF_AbstractTensorSetEagerTensor(input_t, input_eager, status.get());
|
||||
TF_AbstractTensor* input_t =
|
||||
TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
TF_OutputListSetNumOutputs(add_outputs, 1, status.get());
|
||||
@ -191,16 +182,13 @@ TEST(UnifedCAPI, TestBasicGraph) {
|
||||
|
||||
TF_DeleteExecutionContext(graph_ctx);
|
||||
TF_DeleteExecutionContext(eager_execution_ctx);
|
||||
TF_DeleteExecutionContextOptions(eager_ctx_options);
|
||||
TF_DeleteExecutionContextOptions(options);
|
||||
}
|
||||
|
||||
TEST(UnifedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TF_ExecutionContextOptions* options = TF_NewEagerContextOptions(opts);
|
||||
TF_ExecutionContext* ctx = TF_NewExecutionContext(options, status.get());
|
||||
TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
@ -210,15 +198,12 @@ TEST(UnifedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
|
||||
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
|
||||
|
||||
TF_DeleteExecutionContext(ctx);
|
||||
TF_DeleteExecutionContextOptions(options);
|
||||
}
|
||||
|
||||
TEST(UnifedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContextOptions* options = TF_NewGraphContextOptions();
|
||||
TF_ExecutionContext* graph_ctx =
|
||||
TF_NewExecutionContext(options, status.get());
|
||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Add a placeholder to the graph.
|
||||
@ -234,15 +219,12 @@ TEST(UnifedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
|
||||
|
||||
TF_DeleteAbstractOp(placeholder_op);
|
||||
TF_DeleteExecutionContext(graph_ctx);
|
||||
TF_DeleteExecutionContextOptions(options);
|
||||
}
|
||||
|
||||
TEST(UnifedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContextOptions* options = TF_NewGraphContextOptions();
|
||||
TF_ExecutionContext* graph_ctx =
|
||||
TF_NewExecutionContext(options, status.get());
|
||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Add a placeholder to the graph.
|
||||
@ -258,7 +240,6 @@ TEST(UnifedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
|
||||
|
||||
TF_DeleteAbstractOp(placeholder_op);
|
||||
TF_DeleteExecutionContext(graph_ctx);
|
||||
TF_DeleteExecutionContextOptions(options);
|
||||
}
|
||||
|
||||
TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) {
|
||||
@ -266,8 +247,7 @@ TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TF_ExecutionContextOptions* options = TF_NewEagerContextOptions(opts);
|
||||
TF_ExecutionContext* ctx = TF_NewExecutionContext(options, status.get());
|
||||
TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
@ -281,8 +261,8 @@ TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) {
|
||||
// Build an abstract input tensor.
|
||||
TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx);
|
||||
TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f);
|
||||
TF_AbstractTensor* at = TF_NewAbstractTensor();
|
||||
TF_AbstractTensorSetEagerTensor(at, t, status.get());
|
||||
TF_AbstractTensor* at =
|
||||
TF_CreateAbstractTensorFromEagerTensor(t, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build inputs and outputs.
|
||||
@ -292,9 +272,7 @@ TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) {
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build a Graph context.
|
||||
TF_ExecutionContextOptions* graph_options = TF_NewGraphContextOptions();
|
||||
TF_ExecutionContext* graph_ctx =
|
||||
TF_NewExecutionContext(graph_options, status.get());
|
||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Execute eager op using graph context.
|
||||
@ -307,17 +285,13 @@ TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) {
|
||||
|
||||
TF_DeleteOutputList(o);
|
||||
TF_DeleteExecutionContext(ctx);
|
||||
TF_DeleteExecutionContextOptions(options);
|
||||
TF_DeleteExecutionContext(graph_ctx);
|
||||
TF_DeleteExecutionContextOptions(graph_options);
|
||||
}
|
||||
|
||||
TEST(UnifedCAPI, TestExecutingGraphOpInEagerModeRaises) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContextOptions* options = TF_NewGraphContextOptions();
|
||||
TF_ExecutionContext* graph_ctx =
|
||||
TF_NewExecutionContext(options, status.get());
|
||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Add a placeholder to the graph.
|
||||
@ -355,10 +329,8 @@ TEST(UnifedCAPI, TestExecutingGraphOpInEagerModeRaises) {
|
||||
|
||||
// Build eager context.
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TF_ExecutionContextOptions* eager_ctx_options =
|
||||
TF_NewEagerContextOptions(opts);
|
||||
TF_ExecutionContext* eager_execution_ctx =
|
||||
TF_NewExecutionContext(eager_ctx_options, status.get());
|
||||
TF_NewEagerExecutionContext(opts, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
@ -374,8 +346,6 @@ TEST(UnifedCAPI, TestExecutingGraphOpInEagerModeRaises) {
|
||||
TF_DeleteOutputList(placeholder_outputs);
|
||||
TF_DeleteExecutionContext(graph_ctx);
|
||||
TF_DeleteExecutionContext(eager_execution_ctx);
|
||||
TF_DeleteExecutionContextOptions(eager_ctx_options);
|
||||
TF_DeleteExecutionContextOptions(options);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -17,9 +17,11 @@ limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/framework/numeric_types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
@ -60,13 +62,31 @@ class AbstractContextInterface {
|
||||
// Create a handle to wrap and manage a Tensor
|
||||
virtual AbstractTensorHandleInterface* CreateLocalHandle(
|
||||
AbstractTensorInterface* t) = 0;
|
||||
// Copy the handle to another device.
|
||||
virtual AbstractTensorHandleInterface* CopyTensorHandleToDevice(
|
||||
AbstractTensorHandleInterface* handle, const char* device_name,
|
||||
Status* status) = 0;
|
||||
|
||||
// Create an operation to perform op execution
|
||||
virtual AbstractOperationInterface* CreateOperation() = 0;
|
||||
|
||||
// Load a SavedModelAPI object from the given directory and tags
|
||||
virtual std::unique_ptr<SavedModelAPI> LoadSavedModelAPI(
|
||||
const std::string& directory,
|
||||
const absl::optional<std::unordered_set<std::string>>& tags,
|
||||
tensorflow::Status* status) = 0;
|
||||
|
||||
// List attributes of available devices
|
||||
virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
|
||||
|
||||
virtual void ClearCachesAndThreadExecutors() = 0;
|
||||
|
||||
// Initialize the step resource container for a training step. This is used
|
||||
// in current TF runtime. For tfrt, it is used by fallback op handler.
|
||||
virtual void StartStep() = 0;
|
||||
// Destroy the step resource container for a training step.
|
||||
virtual void EndStep() = 0;
|
||||
|
||||
protected:
|
||||
virtual ~AbstractContextInterface() {}
|
||||
};
|
||||
|
@ -16,8 +16,10 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/dlpack.h"
|
||||
|
||||
#include "include/dlpack/dlpack.h" // from @dlpack
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/tf_status_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_reference.h"
|
||||
@ -41,15 +43,15 @@ struct TfDlManagedTensorCtx {
|
||||
|
||||
// Gets tensor from eager tensor handle.
|
||||
const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return nullptr;
|
||||
}
|
||||
tensorflow::TensorHandle* handle =
|
||||
tensorflow::TensorHandleFromInterface(h->handle);
|
||||
if (handle->IsRemote()) {
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
|
||||
if (handle->Type() != TensorHandle::LOCAL) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"DLPack doesn't support remote tensor");
|
||||
"DLPack doesn't support ", handle->TypeString(), " tensor");
|
||||
return nullptr;
|
||||
}
|
||||
const tensorflow::Tensor* tensor;
|
||||
@ -107,7 +109,7 @@ DLDataType GetDlDataType(TF_DataType data_type, TF_Status* status) {
|
||||
// Gets DLPack's DLContext from eager tensor handle.
|
||||
DLContext GetDlContext(TFE_TensorHandle* h, TF_Status* status) {
|
||||
DLContext ctx;
|
||||
const char* device_name = h->handle->DeviceName(&status->status);
|
||||
const char* device_name = tensorflow::unwrap(h)->DeviceName(&status->status);
|
||||
DeviceNameUtils::ParsedName parsed_name;
|
||||
tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name);
|
||||
std::string device_type = parsed_name.type;
|
||||
|
@ -7,10 +7,26 @@ package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
# Currently pybind extension shared objects must use only C API headers since
|
||||
# the C API has static initializers duplicated in the Python bindings. So we
|
||||
# need a second rule that omits .cc files, in
|
||||
# tensorflow/python:_pywrap_parallel_device.
|
||||
filegroup(
|
||||
name = "headers",
|
||||
srcs = ["parallel_device.h"],
|
||||
visibility = ["//tensorflow/python:__pkg__"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "sources",
|
||||
srcs = ["parallel_device.cc"],
|
||||
visibility = ["//tensorflow/python:__pkg__"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "parallel_device",
|
||||
srcs = ["parallel_device.cc"],
|
||||
hdrs = ["parallel_device.h"],
|
||||
srcs = [":sources"],
|
||||
hdrs = [":headers"],
|
||||
deps = [
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
|
@ -574,23 +574,21 @@ void DeleteParallelDevice(void* device_info) {
|
||||
|
||||
} // namespace
|
||||
|
||||
void RegisterParallelDevice(TFE_Context* context, const char* device_name,
|
||||
const char** underlying_devices,
|
||||
int num_underlying_devices, TF_Status* status) {
|
||||
TFE_CustomDevice custom_device;
|
||||
custom_device.copy_tensor_to_device = &CopyToParallelDevice;
|
||||
custom_device.copy_tensor_from_device = &CopyTensorFromParallelDevice;
|
||||
custom_device.delete_device = &DeleteParallelDevice;
|
||||
custom_device.execute = &ParallelDeviceExecute;
|
||||
void AllocateParallelDevice(const char* device_name,
|
||||
const char* const* underlying_devices,
|
||||
int num_underlying_devices,
|
||||
TFE_CustomDevice* device, void** device_info) {
|
||||
device->copy_tensor_to_device = &CopyToParallelDevice;
|
||||
device->copy_tensor_from_device = &CopyTensorFromParallelDevice;
|
||||
device->delete_device = &DeleteParallelDevice;
|
||||
device->execute = &ParallelDeviceExecute;
|
||||
std::vector<std::string> underlying_devices_vector;
|
||||
underlying_devices_vector.reserve(num_underlying_devices);
|
||||
for (int device_index = 0; device_index < num_underlying_devices;
|
||||
++device_index) {
|
||||
underlying_devices_vector.push_back(underlying_devices[device_index]);
|
||||
}
|
||||
ParallelDevice* d =
|
||||
new ParallelDevice(device_name, underlying_devices_vector);
|
||||
TFE_RegisterCustomDevice(context, custom_device, device_name, d, status);
|
||||
*device_info = new ParallelDevice(device_name, underlying_devices_vector);
|
||||
}
|
||||
|
||||
} // namespace eager
|
||||
|
@ -16,12 +16,14 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
|
||||
#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace eager {
|
||||
|
||||
// Register a parallel device named `device_name` which forwards operations to
|
||||
// Allocate a parallel device named `device_name` which forwards operations to
|
||||
// `underlying_devices`, maintaining "parallel tensors" with components placed
|
||||
// on each underlying device.
|
||||
//
|
||||
@ -50,11 +52,12 @@ namespace eager {
|
||||
// TPUReplicatedOutput(input=x, num_replicas=2)` un-packs the parallel tensor
|
||||
// into its components.
|
||||
//
|
||||
// `context` owns the parallel device. `underlying_devices` must stay valid
|
||||
// while the parallel device is in use.
|
||||
void RegisterParallelDevice(TFE_Context* context, const char* device_name,
|
||||
const char** underlying_devices,
|
||||
int num_underlying_devices, TF_Status* status);
|
||||
// The filled `device` struct and the allocated `device_info` struct may be
|
||||
// passed to TFE_RegisterCustomDevice. The `device_name` arguments must match.
|
||||
void AllocateParallelDevice(const char* device_name,
|
||||
const char* const* underlying_devices,
|
||||
int num_underlying_devices,
|
||||
TFE_CustomDevice* device, void** device_info);
|
||||
|
||||
} // namespace eager
|
||||
} // namespace tensorflow
|
||||
|
@ -288,6 +288,19 @@ void AssertScalarFloatEq(TFE_TensorHandle* handle, float expected_value) {
|
||||
*static_cast<float*>(TF_TensorData(value_zero.get())));
|
||||
}
|
||||
|
||||
template <std::size_t num_devices>
|
||||
void RegisterParallelDevice(
|
||||
TFE_Context* context, const char* device_name,
|
||||
const std::array<const char*, num_devices>& underlying_devices,
|
||||
TF_Status* status) {
|
||||
TFE_CustomDevice device;
|
||||
void* device_info;
|
||||
tensorflow::eager::AllocateParallelDevice(
|
||||
device_name, underlying_devices.data(), underlying_devices.size(),
|
||||
&device, &device_info);
|
||||
TFE_RegisterCustomDevice(context, device, device_name, device_info, status);
|
||||
}
|
||||
|
||||
// Create and modify a variable placed on a parallel device which composes
|
||||
// `first_device` and `second_device`.
|
||||
void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
|
||||
@ -297,9 +310,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
std::array<const char*, 2> underlying_devices{first_device, second_device};
|
||||
tensorflow::eager::RegisterParallelDevice(
|
||||
context, device_name, underlying_devices.data(),
|
||||
underlying_devices.size(), status.get());
|
||||
RegisterParallelDevice(context, device_name, underlying_devices,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Create a variable handle (uninitialized to start) placed on the parallel
|
||||
@ -456,16 +468,14 @@ TEST(PARALLEL_DEVICE, TestExplicitCopies) {
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
std::vector<const char*> underlying_devices;
|
||||
const char* first_device_name =
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
underlying_devices.push_back(first_device_name);
|
||||
const char* second_device_name =
|
||||
"/job:localhost/replica:0/task:0/device:CPU:1";
|
||||
underlying_devices.push_back(second_device_name);
|
||||
tensorflow::eager::RegisterParallelDevice(
|
||||
context.get(), device_name, underlying_devices.data(),
|
||||
underlying_devices.size(), status.get());
|
||||
std::array<const char*, 2> underlying_devices{first_device_name,
|
||||
second_device_name};
|
||||
RegisterParallelDevice(context.get(), device_name, underlying_devices,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
TensorHandlePtr cpu_value(FloatTensorHandle(3., status.get()));
|
||||
@ -524,12 +534,11 @@ TEST(PARALLEL_DEVICE, TestDifferentShapes) {
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
std::vector<const char*> underlying_devices;
|
||||
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
|
||||
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
|
||||
tensorflow::eager::RegisterParallelDevice(
|
||||
context.get(), device_name, underlying_devices.data(),
|
||||
underlying_devices.size(), status.get());
|
||||
std::array<const char*, 2> underlying_devices{
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
"/job:localhost/replica:0/task:0/device:CPU:1"};
|
||||
RegisterParallelDevice(context.get(), device_name, underlying_devices,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Create two vectors with different lengths
|
||||
@ -570,24 +579,22 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
|
||||
// Create a parallel device with two CPUs
|
||||
const char* first_device_name =
|
||||
"/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
std::vector<const char*> first_underlying_devices{
|
||||
std::array<const char*, 2> first_underlying_devices{
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
"/job:localhost/replica:0/task:0/device:CPU:1"};
|
||||
tensorflow::eager::RegisterParallelDevice(
|
||||
context.get(), first_device_name, first_underlying_devices.data(),
|
||||
first_underlying_devices.size(), status.get());
|
||||
RegisterParallelDevice(context.get(), first_device_name,
|
||||
first_underlying_devices, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Create a second parallel device with the first parallel device and one
|
||||
// additional CPU.
|
||||
const char* second_device_name =
|
||||
"/job:localhost/replica:0/task:0/device:CUSTOM:1";
|
||||
std::vector<const char*> second_underlying_devices{
|
||||
std::array<const char*, 2> second_underlying_devices{
|
||||
"/job:localhost/replica:0/task:0/device:CUSTOM:0",
|
||||
"/job:localhost/replica:0/task:0/device:CPU:2"};
|
||||
tensorflow::eager::RegisterParallelDevice(
|
||||
context.get(), second_device_name, second_underlying_devices.data(),
|
||||
second_underlying_devices.size(), status.get());
|
||||
RegisterParallelDevice(context.get(), second_device_name,
|
||||
second_underlying_devices, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Create a tensor on the first parallel device
|
||||
@ -656,11 +663,10 @@ TEST(PARALLEL_DEVICE, TestInvalidPacking) {
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
std::vector<const char*> underlying_devices;
|
||||
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
|
||||
tensorflow::eager::RegisterParallelDevice(
|
||||
context.get(), device_name, underlying_devices.data(),
|
||||
underlying_devices.size(), status.get());
|
||||
std::array<const char*, 1> underlying_devices{
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0"};
|
||||
RegisterParallelDevice(context.get(), device_name, underlying_devices,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
|
||||
@ -775,12 +781,11 @@ TEST(PARALLEL_DEVICE, TestCollective) {
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
std::vector<const char*> underlying_devices;
|
||||
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
|
||||
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
|
||||
tensorflow::eager::RegisterParallelDevice(
|
||||
context.get(), device_name, underlying_devices.data(),
|
||||
underlying_devices.size(), status.get());
|
||||
std::array<const char*, 2> underlying_devices{
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
"/job:localhost/replica:0/task:0/device:CPU:1"};
|
||||
RegisterParallelDevice(context.get(), device_name, underlying_devices,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Create a tensor on the parallel device
|
||||
@ -867,12 +872,11 @@ TEST(PARALLEL_DEVICE, TestFunction) {
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
std::vector<const char*> underlying_devices;
|
||||
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
|
||||
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
|
||||
tensorflow::eager::RegisterParallelDevice(
|
||||
context.get(), device_name, underlying_devices.data(),
|
||||
underlying_devices.size(), status.get());
|
||||
std::array<const char*, 2> underlying_devices{
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
"/job:localhost/replica:0/task:0/device:CPU:1"};
|
||||
RegisterParallelDevice(context.get(), device_name, underlying_devices,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
const char* function_name = "test_reduce_mul";
|
||||
|
24
tensorflow/c/eager/tfe_cancellation_manager_internal.h
Normal file
24
tensorflow/c/eager/tfe_cancellation_manager_internal.h
Normal file
@ -0,0 +1,24 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/core/framework/cancellation.h"
|
||||
|
||||
struct TFE_CancellationManager {
|
||||
tensorflow::CancellationManager cancellation_manager;
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_
|
35
tensorflow/c/eager/tfe_context_internal.h
Normal file
35
tensorflow/c/eager/tfe_context_internal.h
Normal file
@ -0,0 +1,35 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/eager/context_interface.h"
|
||||
|
||||
// Wraps a pointer to a context implementation.
|
||||
//
|
||||
// WARNING: Since the underlying object could be ref-counted a user of this
|
||||
// interface cannot destruct the underlying context object. Instead, call
|
||||
// TFE_DeleteContext who calls Release() on the context pointer and deletes
|
||||
// the TFE_Context structure.
|
||||
typedef struct TFE_Context TFE_Context;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractContextInterface, TFE_Context);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_
|
37
tensorflow/c/eager/tfe_executor_internal.h
Normal file
37
tensorflow/c/eager/tfe_executor_internal.h
Normal file
@ -0,0 +1,37 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_TFE_EXECUTOR_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_TFE_EXECUTOR_INTERNAL_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
|
||||
|
||||
struct TFE_Executor {
|
||||
explicit TFE_Executor(bool async)
|
||||
: owned_executor(new tensorflow::EagerExecutor(async)) {}
|
||||
|
||||
explicit TFE_Executor(tensorflow::EagerExecutor* executor)
|
||||
: owned_executor(nullptr), unowned_executor(executor) {}
|
||||
|
||||
tensorflow::EagerExecutor* executor() {
|
||||
return owned_executor == nullptr ? unowned_executor : owned_executor.get();
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::EagerExecutor> owned_executor;
|
||||
tensorflow::EagerExecutor* unowned_executor;
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TFE_EXECUTOR_INTERNAL_H_
|
146
tensorflow/c/eager/tfe_monitoring_internal.h
Normal file
146
tensorflow/c/eager/tfe_monitoring_internal.h
Normal file
@ -0,0 +1,146 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_TFE_MONITORING_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_TFE_MONITORING_INTERNAL_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||
#include "tensorflow/core/lib/monitoring/gauge.h"
|
||||
#include "tensorflow/core/lib/monitoring/sampler.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
struct TFE_MonitoringCounterCell {
|
||||
tensorflow::monitoring::CounterCell cell;
|
||||
};
|
||||
|
||||
template <int NumLabels>
|
||||
struct TFE_MonitoringCounter {
|
||||
template <typename... LabelDesc>
|
||||
TFE_MonitoringCounter(const char* name, const char* description,
|
||||
LabelDesc&&... label) {
|
||||
counter = absl::WrapUnique(tensorflow::monitoring::Counter<NumLabels>::New(
|
||||
name, description, label...));
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::monitoring::Counter<NumLabels>> counter;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringCounter0 : TFE_MonitoringCounter<0> {
|
||||
using TFE_MonitoringCounter::TFE_MonitoringCounter;
|
||||
};
|
||||
struct TFE_MonitoringCounter1 : TFE_MonitoringCounter<1> {
|
||||
using TFE_MonitoringCounter::TFE_MonitoringCounter;
|
||||
};
|
||||
struct TFE_MonitoringCounter2 : TFE_MonitoringCounter<2> {
|
||||
using TFE_MonitoringCounter::TFE_MonitoringCounter;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringIntGaugeCell {
|
||||
tensorflow::monitoring::GaugeCell<tensorflow::int64> cell;
|
||||
};
|
||||
struct TFE_MonitoringStringGaugeCell {
|
||||
tensorflow::monitoring::GaugeCell<tensorflow::string> cell;
|
||||
};
|
||||
struct TFE_MonitoringBoolGaugeCell {
|
||||
tensorflow::monitoring::GaugeCell<bool> cell;
|
||||
};
|
||||
|
||||
template <typename ValueType, int NumLabels>
|
||||
struct TFE_MonitoringGauge {
|
||||
template <typename... LabelDesc>
|
||||
TFE_MonitoringGauge(const char* name, const char* description,
|
||||
LabelDesc&&... label) {
|
||||
gauge = absl::WrapUnique(
|
||||
tensorflow::monitoring::Gauge<ValueType, NumLabels>::New(
|
||||
name, description, label...));
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::monitoring::Gauge<ValueType, NumLabels>> gauge;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringIntGauge0 : TFE_MonitoringGauge<tensorflow::int64, 0> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringIntGauge1 : TFE_MonitoringGauge<tensorflow::int64, 1> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringIntGauge2 : TFE_MonitoringGauge<tensorflow::int64, 2> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringStringGauge0 : TFE_MonitoringGauge<tensorflow::string, 0> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringStringGauge1 : TFE_MonitoringGauge<tensorflow::string, 1> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringStringGauge2 : TFE_MonitoringGauge<tensorflow::string, 2> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringBoolGauge0 : TFE_MonitoringGauge<bool, 0> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringBoolGauge1 : TFE_MonitoringGauge<bool, 1> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringBoolGauge2 : TFE_MonitoringGauge<bool, 2> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringBuckets {
|
||||
explicit TFE_MonitoringBuckets(
|
||||
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
|
||||
fn) {
|
||||
create_buckets = fn;
|
||||
}
|
||||
|
||||
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
|
||||
create_buckets;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringSamplerCell {
|
||||
tensorflow::monitoring::SamplerCell cell;
|
||||
};
|
||||
|
||||
template <int NumLabels>
|
||||
struct TFE_MonitoringSampler {
|
||||
template <typename... LabelDesc>
|
||||
TFE_MonitoringSampler(
|
||||
const char* name,
|
||||
std::unique_ptr<tensorflow::monitoring::Buckets> buckets,
|
||||
const char* description, LabelDesc&&... label) {
|
||||
sampler = absl::WrapUnique(tensorflow::monitoring::Sampler<NumLabels>::New(
|
||||
{name, description, label...}, std::move(buckets)));
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::monitoring::Sampler<NumLabels>> sampler;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringSampler0 : TFE_MonitoringSampler<0> {
|
||||
using TFE_MonitoringSampler::TFE_MonitoringSampler;
|
||||
};
|
||||
struct TFE_MonitoringSampler1 : TFE_MonitoringSampler<1> {
|
||||
using TFE_MonitoringSampler::TFE_MonitoringSampler;
|
||||
};
|
||||
struct TFE_MonitoringSampler2 : TFE_MonitoringSampler<2> {
|
||||
using TFE_MonitoringSampler::TFE_MonitoringSampler;
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TFE_MONITORING_INTERNAL_H_
|
39
tensorflow/c/eager/tfe_op_attrs_internal.h
Normal file
39
tensorflow/c/eager/tfe_op_attrs_internal.h
Normal file
@ -0,0 +1,39 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
|
||||
// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways
|
||||
// that sometimes do not require serialization.
|
||||
typedef struct TFE_OpAttrs TFE_OpAttrs;
|
||||
|
||||
typedef struct TFE_Context TFE_Context;
|
||||
typedef struct TFE_Op TFE_Op;
|
||||
|
||||
namespace tensorflow {
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AttrBuilder, TFE_OpAttrs);
|
||||
|
||||
// Set an AttrValue on the op. Doesn't handle the list types.
|
||||
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
|
||||
const tensorflow::AttrValue& default_value,
|
||||
const char* attr_name, TF_Status* status);
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_
|
36
tensorflow/c/eager/tfe_op_internal.h
Normal file
36
tensorflow/c/eager/tfe_op_internal.h
Normal file
@ -0,0 +1,36 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
|
||||
// Wraps a pointer to an operation implementation.
|
||||
//
|
||||
// WARNING: Since the underlying object could be ref-counted a user of this
|
||||
// interface cannot destruct the underlying operation object. Instead, call
|
||||
// TFE_DeleteOp who calls Release() on the operation pointer and deletes
|
||||
// the TFE_Op structure.
|
||||
typedef struct TFE_Op TFE_Op;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface, TFE_Op);
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface*, TFE_Op*);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_
|
30
tensorflow/c/eager/tfe_tensor_debug_info_internal.h
Normal file
30
tensorflow/c/eager/tfe_tensor_debug_info_internal.h
Normal file
@ -0,0 +1,30 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_TFE_TENSOR_DEBUG_INFO_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_TFE_TENSOR_DEBUG_INFO_INTERNAL_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
struct TFE_TensorDebugInfo {
|
||||
explicit TFE_TensorDebugInfo(const std::vector<tensorflow::int64>& dims)
|
||||
: dev_dims(dims) {}
|
||||
|
||||
// Fully-padded, minor-to-major.
|
||||
std::vector<tensorflow::int64> dev_dims;
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TFE_TENSOR_DEBUG_INFO_INTERNAL_H_
|
38
tensorflow/c/eager/tfe_tensorhandle_internal.h
Normal file
38
tensorflow/c/eager/tfe_tensorhandle_internal.h
Normal file
@ -0,0 +1,38 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
|
||||
// Wraps a pointer to a tensor handle implementation.
|
||||
//
|
||||
// WARNING: Since the underlying object could be ref-counted a user of this
|
||||
// interface cannot destruct the underlying handle object. Instead, call
|
||||
// TFE_DeleteTensorHandle who calls Release() on the handle pointer and deletes
|
||||
// the TFE_TensorHandle structure.
|
||||
typedef struct TFE_TensorHandle TFE_TensorHandle;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface,
|
||||
TFE_TensorHandle);
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface*,
|
||||
TFE_TensorHandle*);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_
|
@ -85,17 +85,36 @@ class ModularFileSystemTest : public ::testing::TestWithParam<std::string> {
|
||||
const std::string test_name = tensorflow::str_util::StringReplace(
|
||||
::testing::UnitTest::GetInstance()->current_test_info()->name(), "/",
|
||||
"_", /*replace_all=*/true);
|
||||
root_dir_ = tensorflow::io::JoinPath(
|
||||
::testing::TempDir(),
|
||||
tensorflow::strings::StrCat("tf_fs_", rng_val_, "_", test_name));
|
||||
if (!cloud_path_.empty()) {
|
||||
// We have to join path for non-local filesystem manually to make sure
|
||||
// that this test will run on Windows since `tensorflow::io::JoinPath`
|
||||
// behaves differently on Windows. `tmp_dir` should be something like
|
||||
// `path/to/tmp/dir/`. After joining path, we will have
|
||||
// /path/to/tmp/dir/tf_fs_rng_name/`
|
||||
root_dir_ = tensorflow::strings::StrCat(
|
||||
"/", tmp_dir_,
|
||||
tensorflow::strings::StrCat("tf_fs_", rng_val_, "_", test_name), "/");
|
||||
} else {
|
||||
root_dir_ = tensorflow::io::JoinPath(
|
||||
tmp_dir_,
|
||||
tensorflow::strings::StrCat("tf_fs_", rng_val_, "_", test_name));
|
||||
}
|
||||
if (!GetParam().empty()) {
|
||||
root_dir_ = tensorflow::strings::StrCat(GetParam(), "://", cloud_path_,
|
||||
root_dir_);
|
||||
}
|
||||
env_ = Env::Default();
|
||||
}
|
||||
|
||||
void SetUp() override {
|
||||
if (mkdir(root_dir_.c_str(), 0755) != 0) {
|
||||
int error_code = errno;
|
||||
GTEST_SKIP() << "Cannot create working directory: "
|
||||
<< tensorflow::IOError(root_dir_, error_code);
|
||||
FileSystem* fs = nullptr;
|
||||
Status s = env_->GetFileSystemForFile(root_dir_, &fs);
|
||||
if (fs == nullptr || !s.ok())
|
||||
GTEST_SKIP() << "No filesystem registered: " << s;
|
||||
|
||||
s = fs->CreateDir(root_dir_);
|
||||
if (!s.ok()) {
|
||||
GTEST_SKIP() << "Cannot create working directory: " << s;
|
||||
}
|
||||
}
|
||||
|
||||
@ -115,9 +134,10 @@ class ModularFileSystemTest : public ::testing::TestWithParam<std::string> {
|
||||
std::string GetURIForPath(StringPiece path) {
|
||||
const std::string translated_name =
|
||||
tensorflow::io::JoinPath(root_dir_, path);
|
||||
if (GetParam().empty()) return translated_name;
|
||||
|
||||
return tensorflow::strings::StrCat(GetParam(), "://", translated_name);
|
||||
// We have already checked `GetParam().empty()` in
|
||||
// `ModularFileSystemTest()`. root_dir_ should contain `GetParam() + "://"`
|
||||
// if it isn't empty.
|
||||
return translated_name;
|
||||
}
|
||||
|
||||
// Converts absolute paths to paths relative to root_dir_.
|
||||
@ -133,15 +153,28 @@ class ModularFileSystemTest : public ::testing::TestWithParam<std::string> {
|
||||
rng_val_ = distribution(gen);
|
||||
}
|
||||
|
||||
static void SetCloudPath(const std::string& cloud_path) {
|
||||
cloud_path_ = cloud_path;
|
||||
if (cloud_path_.back() == '/') cloud_path_.pop_back();
|
||||
}
|
||||
|
||||
static void SetTmpDir(const std::string& tmp_dir) {
|
||||
tmp_dir_ = tmp_dir.empty() ? ::testing::TempDir() : tmp_dir;
|
||||
}
|
||||
|
||||
protected:
|
||||
Env* env_;
|
||||
|
||||
private:
|
||||
std::string root_dir_;
|
||||
static int rng_val_;
|
||||
static std::string cloud_path_;
|
||||
static std::string tmp_dir_;
|
||||
};
|
||||
|
||||
int ModularFileSystemTest::rng_val_;
|
||||
std::string ModularFileSystemTest::cloud_path_;
|
||||
std::string ModularFileSystemTest::tmp_dir_;
|
||||
|
||||
// As some of the implementations might be missing, the tests should still pass
|
||||
// if the returned `Status` signals the unimplemented state.
|
||||
@ -1729,6 +1762,20 @@ static bool GetURIScheme(const std::string& scheme) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// This function is used for cloud filesystem
|
||||
// `S3` and `GCS` require the `root_dir_` to have bucket name
|
||||
// `HDFS` requires the `root_dir` to have namenode
|
||||
// `root_dir_ = scheme + "://" cloud_path_ + root_dir_`
|
||||
static bool SetCloudPath(const std::string& cloud_path_) {
|
||||
ModularFileSystemTest::SetCloudPath(cloud_path_);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool SetTmpDir(const std::string& tmp_dir_) {
|
||||
ModularFileSystemTest::SetTmpDir(tmp_dir_);
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
||||
@ -1741,7 +1788,12 @@ GTEST_API_ int main(int argc, char** argv) {
|
||||
tensorflow::Flag("dso", tensorflow::LoadDSO, "",
|
||||
"Path to shared object to load"),
|
||||
tensorflow::Flag("scheme", tensorflow::GetURIScheme, "",
|
||||
"URI scheme to test")};
|
||||
"URI scheme to test"),
|
||||
tensorflow::Flag("cloud_path", tensorflow::SetCloudPath, "",
|
||||
"Path for cloud filesystem (namenode for hdfs, "
|
||||
"bucketname for s3/gcs)"),
|
||||
tensorflow::Flag("tmp_dir", tensorflow::SetTmpDir, "",
|
||||
"Temporary directory to store test data.")};
|
||||
if (!tensorflow::Flags::Parse(&argc, argv, flag_list)) {
|
||||
std::cout << tensorflow::Flags::Usage(argv[0], flag_list);
|
||||
return -1;
|
||||
|
66
tensorflow/c/experimental/saved_model/README.md
Normal file
66
tensorflow/c/experimental/saved_model/README.md
Normal file
@ -0,0 +1,66 @@
|
||||
# Tensorflow C SavedModel API
|
||||
|
||||
## Overview
|
||||
|
||||
These are the new experimental C SavedModel APIs for loading and running
|
||||
SavedModels in a TF2-idiomatic fashion. See
|
||||
[RFC 207](https://github.com/tensorflow/community/pull/207) for additional
|
||||
context.
|
||||
|
||||
The directory structure is as follows:
|
||||
|
||||
```none
|
||||
saved_model/
|
||||
|
||||
public/
|
||||
|
||||
internal/
|
||||
|
||||
core/
|
||||
|
||||
```
|
||||
|
||||
## saved_model/public
|
||||
|
||||
`saved_model/public` is intended to house *only the public headers* of the
|
||||
SavedModel C API.
|
||||
|
||||
These headers:
|
||||
|
||||
1. declare opaque C types (like `TF_SavedModel`),
|
||||
|
||||
2. declare the functions that operate on these types (like `TF_LoadSavedModel`).
|
||||
|
||||
Once they leave experimental, these APIs should be considered stable for use
|
||||
by external clients.
|
||||
|
||||
These headers are in a separate directory to make it obvious to clients which
|
||||
headers they should depend on, and which headers are implementation details.
|
||||
Separating these public headers by directory also allow future programmatic
|
||||
checks to ensure that TF public headers only `#include` other public TF headers.
|
||||
|
||||
## saved_model/internal
|
||||
|
||||
`saved_model/internal` is the "glue" between the C API and the internal C++
|
||||
implementation.
|
||||
|
||||
Its role is to:
|
||||
|
||||
1. implement the C API functions declared in `saved_model/public`
|
||||
|
||||
2. define the C API types declared in `saved_model/public`
|
||||
|
||||
The files fulfilling 1. are named `*.cc` (eg: `concrete_function.cc`), while
|
||||
the files fulfilling 2. are `*type.h` (eg: `concrete_function_type.h`).
|
||||
|
||||
The headers exposing the internal implementation of the opaque C types are only
|
||||
visible to other implementors of the C API. This is similar to how other
|
||||
TF C API implementations use `tf_status_internal.h` (to extract the underlying
|
||||
`tensorflow::Status`). All other targets in this directory are private.
|
||||
|
||||
## saved_model/core
|
||||
|
||||
`saved_model/core` contains pure C++ "Classes" underlying the C API types
|
||||
in `saved_model/public/`. These are implementation
|
||||
details subject to change, and have limited visibility to implementors only.
|
||||
This is the bottom-most layer of the `C++ -> C -> C++` sandwich.
|
85
tensorflow/c/experimental/saved_model/core/BUILD
Normal file
85
tensorflow/c/experimental/saved_model/core/BUILD
Normal file
@ -0,0 +1,85 @@
|
||||
# Experimental SavedModel C APIs for TensorFlow. See RFC
|
||||
# https://github.com/tensorflow/community/pull/207
|
||||
# Targets in this directory are pure C++ "Classes" underlying the C API types
|
||||
# under tf/c/experimental/saved_model/public/. They are subject to change and
|
||||
# have visibility limited to Tensorflow's implementation only.
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//tensorflow/c:__subpackages__",
|
||||
"//tensorflow/c/experimental/saved_model/internal:__pkg__",
|
||||
"//tensorflow/core:__subpackages__",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "concrete_function",
|
||||
srcs = [
|
||||
"concrete_function.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"concrete_function.h",
|
||||
],
|
||||
deps = [
|
||||
":function_metadata",
|
||||
"//tensorflow/c/eager:operation_interface",
|
||||
"//tensorflow/c/eager:tensor_handle_interface",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "function_metadata",
|
||||
hdrs = [
|
||||
"function_metadata.h",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "saved_model_api",
|
||||
hdrs = [
|
||||
"saved_model_api.h",
|
||||
],
|
||||
deps = [
|
||||
":concrete_function",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf_saved_model_impl",
|
||||
srcs = [
|
||||
"tf_saved_model_impl.cc",
|
||||
],
|
||||
hdrs = ["tf_saved_model_impl.h"],
|
||||
deps = [
|
||||
":concrete_function",
|
||||
":saved_model_api",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "pywrap_required_hdrs",
|
||||
textual_hdrs = [
|
||||
"concrete_function.h",
|
||||
"function_metadata.h",
|
||||
"saved_model_api.h",
|
||||
],
|
||||
visibility = ["//tensorflow/python:__pkg__"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "mobile_srcs_only_runtime",
|
||||
srcs = [
|
||||
"concrete_function.cc",
|
||||
"concrete_function.h",
|
||||
"function_metadata.h",
|
||||
"saved_model_api.h",
|
||||
"tf_saved_model_impl.cc",
|
||||
"tf_saved_model_impl.h",
|
||||
],
|
||||
visibility = ["//tensorflow/core:__pkg__"],
|
||||
)
|
@ -0,0 +1,32 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
const std::vector<tensorflow::AbstractTensorHandleInterface*>&
|
||||
ConcreteFunction::GetCaptures() const {
|
||||
return captures_;
|
||||
}
|
||||
|
||||
const FunctionMetadata& ConcreteFunction::GetFunctionMetadata() const {
|
||||
return metadata_;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,55 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Note that ConcreteFunctions's lifetimes are effectively bound
|
||||
// to the SavedModel they are loaded from, since they retain pointers
|
||||
// to the TensorHandles owned by the SavedModel, and the FunctionDef
|
||||
// of the SavedModel.
|
||||
// Note(bmzhao): This class is only TEMPORARILY virtual, as a way to unblock
|
||||
// TFRT integration with TF Serving. Do not add more virtual implementations of
|
||||
// this class. Eventually we want to remove this virtual base class indirection
|
||||
// and have only a single implementation.
|
||||
class ConcreteFunction {
|
||||
public:
|
||||
virtual ~ConcreteFunction() = 0;
|
||||
|
||||
// This method returns the "Call" Op used to execute the function.
|
||||
virtual AbstractOperationInterface* GetCallOp() = 0;
|
||||
|
||||
const std::vector<tensorflow::AbstractTensorHandleInterface*>& GetCaptures()
|
||||
const;
|
||||
const FunctionMetadata& GetFunctionMetadata() const;
|
||||
|
||||
private:
|
||||
FunctionMetadata metadata_;
|
||||
std::vector<tensorflow::AbstractTensorHandleInterface*> captures_;
|
||||
FunctionDef* function_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_
|
@ -0,0 +1,27 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_FUNCTION_METADATA_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_FUNCTION_METADATA_H_
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class FunctionMetadata {
|
||||
// TODO(bmzhao): Fill in with fields as necessary
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_FUNCTION_METADATA_H_
|
55
tensorflow/c/experimental/saved_model/core/saved_model_api.h
Normal file
55
tensorflow/c/experimental/saved_model/core/saved_model_api.h
Normal file
@ -0,0 +1,55 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_API_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_API_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Note(bmzhao): This class is only TEMPORARILY virtual, as a way to unblock
|
||||
// TFRT integration with TF Serving. Do not add more virtual implementations of
|
||||
// this class. Eventually we want to remove this virtual base class indirection
|
||||
// and have only a single implementation.
|
||||
class SavedModelAPI {
|
||||
public:
|
||||
// Retrieve a function from the TF2 SavedModel, using the "path" to a function
|
||||
// in a TF2 savedmodel.
|
||||
// Note: `function` is a double pointer, so that implementations are
|
||||
// able to return a pointer to an internal member.
|
||||
virtual Status GetFunction(const std::string& function_path,
|
||||
ConcreteFunction** function) = 0;
|
||||
|
||||
// Retrieve a function from a SavedModel, using the key of the
|
||||
// SignatureDef map:
|
||||
// https://github.com/tensorflow/tensorflow/blob/69b08900b1e991d84bce31f3b404f5ed768f339f/tensorflow/core/protobuf/meta_graph.proto#L89
|
||||
virtual Status GetSignatureDefFunction(const std::string& signature_def_key,
|
||||
ConcreteFunction** function) = 0;
|
||||
|
||||
virtual std::vector<ConcreteFunction*> ListFunctions() = 0;
|
||||
|
||||
virtual ~SavedModelAPI() = default;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_API_H_
|
@ -0,0 +1,60 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_impl.h"
|
||||
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Status TFSavedModelAPIImpl::GetFunction(const std::string& function_path,
|
||||
ConcreteFunction** function) {
|
||||
// TODO(bmzhao): Add support for retrieving a function.
|
||||
return errors::Unimplemented(
|
||||
"Retrieving functions is unimplemented currently");
|
||||
}
|
||||
|
||||
Status TFSavedModelAPIImpl::GetSignatureDefFunction(
|
||||
const std::string& signature_def_key, ConcreteFunction** function) {
|
||||
// TODO(bmzhao): Add support for retrieving a signaturedef function.
|
||||
return errors::Unimplemented(
|
||||
"Retrieving functions is unimplemented currently");
|
||||
}
|
||||
|
||||
std::vector<ConcreteFunction*> TFSavedModelAPIImpl::ListFunctions() {
|
||||
std::vector<ConcreteFunction*> result;
|
||||
result.reserve(functions_.size());
|
||||
for (ConcreteFunction& function : functions_) {
|
||||
result.push_back(&function);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
Status TFSavedModelAPIImpl::Load(
|
||||
const std::string& directory,
|
||||
const absl::optional<std::unordered_set<std::string>>& tags,
|
||||
TFSavedModelAPIImpl* out) {
|
||||
// TODO(bmzhao): Add support for loading a TFSavedModelImpl.
|
||||
return errors::Unimplemented(
|
||||
"TFSavedModelAPIImpl loading is unimplemented currently");
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,55 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_
|
||||
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class TFSavedModelAPIImpl : public SavedModelAPI {
|
||||
public:
|
||||
TFSavedModelAPIImpl() = default;
|
||||
|
||||
Status GetFunction(const std::string& function_path,
|
||||
ConcreteFunction** function) override;
|
||||
|
||||
Status GetSignatureDefFunction(const std::string& signature_def_key,
|
||||
ConcreteFunction** function) override;
|
||||
|
||||
static Status Load(
|
||||
const std::string& directory,
|
||||
const absl::optional<std::unordered_set<std::string>>& tags,
|
||||
TFSavedModelAPIImpl* out);
|
||||
|
||||
std::vector<ConcreteFunction*> ListFunctions() override;
|
||||
|
||||
~TFSavedModelAPIImpl() override = default;
|
||||
|
||||
private:
|
||||
std::vector<ConcreteFunction> functions_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_
|
181
tensorflow/c/experimental/saved_model/internal/BUILD
Normal file
181
tensorflow/c/experimental/saved_model/internal/BUILD
Normal file
@ -0,0 +1,181 @@
|
||||
# Experimental Implementation of SavedModel C APIs for TensorFlow. See RFC
|
||||
# https://github.com/tensorflow/community/pull/207
|
||||
# External clients should not worry about this directory; all contents are implementation details.
|
||||
# Code in this directory is intended to form the glue between the C API and the internal C++
|
||||
# implementation by
|
||||
# 1. mapping C API calls onto correponding methods of C++ objects
|
||||
# 2. mapping opaque C types onto C++ classes
|
||||
|
||||
# Note(bmzhao): The *.cc files in this directory form the direct implementation of the
|
||||
# C API functions exposed in tf/c/experimental/saved_model/public/.
|
||||
|
||||
# Note(bmzhao): All *type.h files in this directory are the internal definitions of
|
||||
# the opaque C types. These headers should only be visible to internal tensorflow
|
||||
# implementors.
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_cc_test",
|
||||
"tf_copts",
|
||||
)
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "concrete_function",
|
||||
srcs = [
|
||||
"concrete_function.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"//tensorflow/c/experimental/saved_model/public:concrete_function.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
# TODO(bmzhao): Remove this as we refactor C API to granular targets,
|
||||
# so that we can depend on c/eager/c_api_unified_experimental.h.
|
||||
features = ["-layering_check"],
|
||||
visibility = [
|
||||
"//tensorflow/c/experimental/saved_model/public:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
":concrete_function_type",
|
||||
":function_metadata",
|
||||
":function_metadata_type",
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_internal",
|
||||
"//tensorflow/c/eager:tfe_op_internal",
|
||||
"//tensorflow/c/experimental/saved_model/core:concrete_function",
|
||||
"//tensorflow/c/experimental/saved_model/core:function_metadata",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "concrete_function_list",
|
||||
srcs = [
|
||||
"concrete_function_list.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"//tensorflow/c/experimental/saved_model/public:concrete_function_list.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = [
|
||||
"//tensorflow/c/experimental/saved_model/public:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
":concrete_function",
|
||||
":concrete_function_list_type",
|
||||
":concrete_function_type",
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c/experimental/saved_model/core:concrete_function",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "concrete_function_list_type",
|
||||
hdrs = [
|
||||
"concrete_function_list_type.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c/experimental/saved_model/core:concrete_function",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "concrete_function_type",
|
||||
hdrs = [
|
||||
"concrete_function_type.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:conversion_macros",
|
||||
"//tensorflow/c/experimental/saved_model/core:concrete_function",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "function_metadata",
|
||||
srcs = [
|
||||
"function_metadata.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"//tensorflow/c/experimental/saved_model/public:function_metadata.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = [
|
||||
"//tensorflow/c/experimental/saved_model/public:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
":function_metadata_type",
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c/experimental/saved_model/core:function_metadata",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "function_metadata_type",
|
||||
hdrs = [
|
||||
"function_metadata_type.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:conversion_macros",
|
||||
"//tensorflow/c/experimental/saved_model/core:function_metadata",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "saved_model_api",
|
||||
srcs = [
|
||||
"saved_model_api.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"//tensorflow/c/experimental/saved_model/public:saved_model_api.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = [
|
||||
"//tensorflow/c/experimental/saved_model/public:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
":concrete_function",
|
||||
":concrete_function_list",
|
||||
":concrete_function_list_type",
|
||||
":concrete_function_type",
|
||||
":saved_model_api_type",
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c:tf_status_internal",
|
||||
"//tensorflow/c/eager:tfe_context_internal",
|
||||
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "saved_model_api_type",
|
||||
hdrs = [
|
||||
"saved_model_api_type.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "saved_model_api_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"saved_model_api_test.cc",
|
||||
],
|
||||
data = [
|
||||
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
"//tensorflow/c/experimental/saved_model/public:saved_model_api",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
@ -0,0 +1,42 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
|
||||
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h"
|
||||
|
||||
extern "C" {
|
||||
|
||||
TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(TF_ConcreteFunction* func) {
|
||||
return tensorflow::wrap(const_cast<tensorflow::FunctionMetadata*>(
|
||||
&tensorflow::unwrap(func)->GetFunctionMetadata()));
|
||||
}
|
||||
|
||||
TF_OutputList* TF_ConcreteFunctionGetCaptures(TF_ConcreteFunction* func) {
|
||||
// TODO(bmzhao): Refactor TF_OutputList struct definition into a separate
|
||||
// internal header, and implement this function.
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func) {
|
||||
return tensorflow::wrap(tensorflow::unwrap(func)->GetCallOp());
|
||||
}
|
||||
|
||||
} // end extern "C"
|
@ -0,0 +1,37 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
|
||||
|
||||
extern "C" {
|
||||
|
||||
size_t TF_ConcreteFunctionListNumOutputs(TF_ConcreteFunctionList* list) {
|
||||
return list->list.size();
|
||||
}
|
||||
|
||||
TF_ConcreteFunction* TF_ConcreteFunctionListGet(TF_ConcreteFunctionList* list,
|
||||
int i) {
|
||||
return tensorflow::wrap(list->list[i]);
|
||||
}
|
||||
|
||||
void TF_DeleteConcreteFunctionList(TF_ConcreteFunctionList* list) {
|
||||
delete list;
|
||||
}
|
||||
|
||||
} // end extern "C"
|
@ -0,0 +1,30 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
|
||||
// Internal structures used by the SavedModel C API. These are likely to change
|
||||
// and should not be depended on.
|
||||
|
||||
struct TF_ConcreteFunctionList {
|
||||
std::vector<tensorflow::ConcreteFunction*> list;
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_
|
@ -0,0 +1,36 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
|
||||
// Internal structures used by the SavedModel C API. These are likely to change
|
||||
// and should not be depended on.
|
||||
|
||||
// It doesn't make sense to wrap tensorflow::ConcreteFunction* in a separate
|
||||
// struct, since the lifetime of the struct and the raw pointer it wraps would
|
||||
// be different. Therefore TF_ConcreteFunction* = tensorflow::ConcreteFunction*.
|
||||
typedef struct TF_ConcreteFunction TF_ConcreteFunction;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ConcreteFunction, TF_ConcreteFunction)
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
%{
|
||||
#include "tensorflow/lite/experimental/kernels/hashtable_ops.h"
|
||||
%}
|
||||
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
|
||||
|
||||
%include "tensorflow/lite/experimental/kernels/hashtable_ops.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h"
|
||||
|
||||
// TODO(bmzhao): Add getter functions here as necessary.
|
@ -0,0 +1,30 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
|
||||
|
||||
typedef struct TF_FunctionMetadata TF_FunctionMetadata;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::FunctionMetadata, TF_FunctionMetadata)
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_
|
@ -0,0 +1,97 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_status_internal.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
extern "C" {
|
||||
|
||||
TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx,
|
||||
TF_Status* status) {
|
||||
std::string saved_model_dir(dirname);
|
||||
|
||||
std::unique_ptr<tensorflow::SavedModelAPI> result =
|
||||
tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, absl::nullopt,
|
||||
&status->status);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return new TF_SavedModel{std::move(result)};
|
||||
}
|
||||
|
||||
TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
|
||||
const char* const* tags, int tags_len,
|
||||
TF_Status* status) {
|
||||
std::string saved_model_dir(dirname);
|
||||
|
||||
std::unordered_set<std::string> tagset;
|
||||
for (int i = 0; i < tags_len; ++i) {
|
||||
tagset.insert(std::string(tags[i]));
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::SavedModelAPI> result =
|
||||
tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, std::move(tagset),
|
||||
&status->status);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return new TF_SavedModel{std::move(result)};
|
||||
}
|
||||
|
||||
void TF_DeleteSavedModel(TF_SavedModel* model) { delete model; }
|
||||
|
||||
TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model,
|
||||
const char* function_path,
|
||||
TF_Status* status) {
|
||||
tensorflow::ConcreteFunction* result = nullptr;
|
||||
tensorflow::Status get_function_status =
|
||||
model->saved_model->GetFunction(function_path, &result);
|
||||
status->status.Update(get_function_status);
|
||||
if (!get_function_status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return tensorflow::wrap(result);
|
||||
}
|
||||
|
||||
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
|
||||
TF_SavedModel* model, const char* signature_def_key, TF_Status* status) {
|
||||
tensorflow::ConcreteFunction* result = nullptr;
|
||||
tensorflow::Status get_function_status =
|
||||
model->saved_model->GetSignatureDefFunction(signature_def_key, &result);
|
||||
status->status.Update(get_function_status);
|
||||
if (!get_function_status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return tensorflow::wrap(result);
|
||||
}
|
||||
|
||||
TF_ConcreteFunctionList* TF_ListSavedModelFunctions(TF_SavedModel* model) {
|
||||
return new TF_ConcreteFunctionList{model->saved_model->ListFunctions()};
|
||||
}
|
||||
|
||||
} // end extern "C"
|
@ -0,0 +1,109 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kTestData[] = "cc/saved_model/testdata";
|
||||
const char* kServeTag[] = {"serve"};
|
||||
|
||||
std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
|
||||
return tensorflow::io::JoinPath(tensorflow::testing::TensorFlowSrcRoot(),
|
||||
kTestData, saved_model_dir);
|
||||
}
|
||||
|
||||
// This value parameterized test allows us to test both TFRT
|
||||
// and non TFRT runtimes.
|
||||
// https://github.com/google/googletest/blob/dcc92d0ab6c4ce022162a23566d44f673251eee4/googletest/docs/advanced.md#value-parameterized-tests
|
||||
class CSavedModelAPITest : public ::testing::TestWithParam<bool> {};
|
||||
|
||||
TEST_P(CSavedModelAPITest, LoadsSavedModelWithTags) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
bool use_tfrt = GetParam();
|
||||
if (use_tfrt) {
|
||||
TFE_DeleteContextOptions(opts);
|
||||
TF_DeleteStatus(status);
|
||||
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
|
||||
}
|
||||
|
||||
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
|
||||
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
|
||||
|
||||
TF_SavedModel* saved_model =
|
||||
TF_LoadSavedModelWithTags(model_dir.c_str(), ctx, kServeTag, 1, status);
|
||||
|
||||
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
|
||||
// That unblocks writing other tests that require a TF_SavedModel*,
|
||||
// like loading a ConcreteFunction. This test at least checks that the
|
||||
// C API builds and can be minimally run.
|
||||
EXPECT_EQ(TF_GetCode(status), TF_UNIMPLEMENTED);
|
||||
|
||||
TF_DeleteSavedModel(saved_model);
|
||||
TF_DeleteStatus(status);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
TEST_P(CSavedModelAPITest, LoadsSavedModel) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
bool use_tfrt = GetParam();
|
||||
if (use_tfrt) {
|
||||
TFE_DeleteContextOptions(opts);
|
||||
TF_DeleteStatus(status);
|
||||
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
|
||||
}
|
||||
|
||||
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
|
||||
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
|
||||
|
||||
TF_SavedModel* saved_model =
|
||||
TF_LoadSavedModel(model_dir.c_str(), ctx, status);
|
||||
|
||||
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
|
||||
// That unblocks writing other tests that require a TF_SavedModel*,
|
||||
// like loading a ConcreteFunction. This test at least checks that the
|
||||
// C API builds and can be minimally run.
|
||||
EXPECT_EQ(TF_GetCode(status), TF_UNIMPLEMENTED);
|
||||
|
||||
TF_DeleteSavedModel(saved_model);
|
||||
TF_DeleteStatus(status);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticSavedModelTests, CSavedModelAPITest,
|
||||
::testing::Bool());
|
||||
|
||||
} // namespace
|
@ -0,0 +1,30 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
|
||||
|
||||
// Internal structures used by the SavedModel C API. These are likely to change
|
||||
// and should not be depended on.
|
||||
|
||||
struct TF_SavedModel {
|
||||
std::unique_ptr<tensorflow::SavedModelAPI> saved_model;
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_
|
63
tensorflow/c/experimental/saved_model/public/BUILD
Normal file
63
tensorflow/c/experimental/saved_model/public/BUILD
Normal file
@ -0,0 +1,63 @@
|
||||
# Experimental SavedModel C APIs for TensorFlow.
|
||||
# See RFC https://github.com/tensorflow/community/pull/207
|
||||
# All headers are on the public surface of Tensorflow's C API.
|
||||
# Once moved out of experimental, these will be stable.
|
||||
# The idea behind a separate public/ directory is to make apparent
|
||||
# which headers are part of TF's public interface (and which headers)
|
||||
# are implementation details. This structure allows us to also perform future
|
||||
# programmatic checks that all "public" headers only include other "public"
|
||||
# headers.
|
||||
|
||||
package(
|
||||
# This is intentionally public
|
||||
default_visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
# TODO(bmzhao): Remove these exports_files and rules, swap with cc_public_library instead.
|
||||
# cc_public_library would allows us to separate the header dep graph from header+srcs dep graph.
|
||||
exports_files(
|
||||
[
|
||||
"concrete_function.h",
|
||||
"concrete_function_list.h",
|
||||
"function_metadata.h",
|
||||
"saved_model_api.h",
|
||||
],
|
||||
visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"],
|
||||
)
|
||||
|
||||
# The purpose of this header is to provide insulation against
|
||||
# future changes where we rename/move a public header, without
|
||||
# forcing all clients to change their "#includes".
|
||||
cc_library(
|
||||
name = "c_saved_model_api",
|
||||
hdrs = ["c_saved_model_api.h"],
|
||||
deps = [
|
||||
":concrete_function",
|
||||
":concrete_function_list",
|
||||
":function_metadata",
|
||||
":saved_model_api",
|
||||
],
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "concrete_function",
|
||||
actual = "//tensorflow/c/experimental/saved_model/internal:concrete_function",
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "concrete_function_list",
|
||||
actual = "//tensorflow/c/experimental/saved_model/internal:concrete_function_list",
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "function_metadata",
|
||||
actual = "//tensorflow/c/experimental/saved_model/internal:function_metadata",
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "saved_model_api",
|
||||
actual = "//tensorflow/c/experimental/saved_model/internal:saved_model_api",
|
||||
)
|
@ -0,0 +1,26 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_
|
||||
|
||||
// IWYU pragma: begin_exports
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
|
||||
// IWYU pragma: end_exports
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_
|
@ -0,0 +1,50 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_
|
||||
|
||||
#include "tensorflow/c/c_api_macros.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// An opaque type that corresponds to a Function loaded from a SavedModel.
|
||||
// TODO(bmzhao): Work together w/srbs@ to make sure this composes w/the
|
||||
// C++ Unified Eager/Graph API's AbstractFunction
|
||||
typedef struct TF_ConcreteFunction TF_ConcreteFunction;
|
||||
|
||||
// Returns FunctionMetadata associated with `func`. Metadata's lifetime is
|
||||
// bound to `func`, which is bound to the TF_SavedModel it was loaded from.
|
||||
TF_CAPI_EXPORT extern TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(
|
||||
TF_ConcreteFunction* func);
|
||||
|
||||
// Returns a list of TensorHandles implicitly captured by this function.
|
||||
TF_CAPI_EXPORT extern TF_OutputList* TF_ConcreteFunctionGetCaptures(
|
||||
TF_ConcreteFunction* func);
|
||||
|
||||
// Returns a TFE_Op suitable for executing this function.
|
||||
TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionGetCallOp(
|
||||
TF_ConcreteFunction* func);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // end extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_
|
@ -0,0 +1,39 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include "tensorflow/c/c_api_macros.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
|
||||
|
||||
// An opaque type that is acts like a list of TF_ConcreteFunction pointers.
|
||||
typedef struct TF_ConcreteFunctionList TF_ConcreteFunctionList;
|
||||
|
||||
// Returns the size of `list`.
|
||||
TF_CAPI_EXPORT size_t
|
||||
TF_ConcreteFunctionListSize(TF_ConcreteFunctionList* list);
|
||||
|
||||
// Returns the `i`th TF_ConcreteFunction in the list.
|
||||
TF_CAPI_EXPORT TF_ConcreteFunction* TF_ConcreteFunctionListGet(
|
||||
TF_ConcreteFunctionList* list, int i);
|
||||
|
||||
// Deletes `list`.
|
||||
TF_CAPI_EXPORT void TF_DeleteConcreteFunctionList(
|
||||
TF_ConcreteFunctionList* list);
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
|
@ -0,0 +1,35 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_FUNCTION_METADATA_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_FUNCTION_METADATA_H_
|
||||
|
||||
#include "tensorflow/c/c_api_macros.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// An opaque type used to store any metadata associated with a function.
|
||||
typedef struct TF_FunctionMetadata TF_FunctionMetadata;
|
||||
|
||||
// TODO(bmzhao): Add getters for fields as we determine what metadata
|
||||
// we want to expose.
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // end extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_FUNCTION_METADATA_H_
|
108
tensorflow/c/experimental/saved_model/public/saved_model_api.h
Normal file
108
tensorflow/c/experimental/saved_model/public/saved_model_api.h
Normal file
@ -0,0 +1,108 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SAVED_MODEL_API_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SAVED_MODEL_API_H_
|
||||
|
||||
#include "tensorflow/c/c_api_macros.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// An opaque type representing a Tensorflow "SavedModel"
|
||||
// (https://www.tensorflow.org/guide/saved_model) that we always pass by pointer
|
||||
// to achieve ABI stability.
|
||||
typedef struct TF_SavedModel TF_SavedModel;
|
||||
|
||||
// Load a SavedModel from `dirname`. We expect the SavedModel to contain a
|
||||
// single Metagraph (as for those exported from TF2's `tf.saved_model.save`).
|
||||
//
|
||||
// Params:
|
||||
// dirname - A directory filepath that the SavedModel is at.
|
||||
// ctx - A TFE_Context containing optional load/TF runtime options.
|
||||
// `ctx` must outlive the returned TF_SavedModel pointer.
|
||||
// status - Set to OK on success and an appropriate error on failure.
|
||||
// Returns:
|
||||
// If status is not OK, returns nullptr. Otherwise, returns a newly created
|
||||
// TF_SavedModel instance. It must be deleted by calling TF_DeleteSavedModel.
|
||||
TF_CAPI_EXPORT extern TF_SavedModel* TF_LoadSavedModel(const char* dirname,
|
||||
TFE_Context* ctx,
|
||||
TF_Status* status);
|
||||
|
||||
// Load a SavedModel from `dirname`.
|
||||
//
|
||||
// Params:
|
||||
// dirname - A directory filepath that the SavedModel is at.
|
||||
// ctx - A TFE_Context containing optional load/TF runtime options.
|
||||
// `ctx` must outlive the returned TF_SavedModel pointer.
|
||||
// tags - char* array of SavedModel tags. We will load the metagraph matching
|
||||
// the tags.
|
||||
// tags_len - number of elements in the `tags` array.
|
||||
// status - Set to OK on success and an appropriate error on failure.
|
||||
// Returns:
|
||||
// If status is not OK, returns nullptr. Otherwise, returns a newly created
|
||||
// TF_SavedModel instance. It must be deleted by calling TF_DeleteSavedModel.
|
||||
TF_CAPI_EXPORT extern TF_SavedModel* TF_LoadSavedModelWithTags(
|
||||
const char* dirname, TFE_Context* ctx, const char* const* tags,
|
||||
int tags_len, TF_Status* status);
|
||||
|
||||
// Deletes a TF_SavedModel, and frees any resources owned by it.
|
||||
TF_CAPI_EXPORT extern void TF_DeleteSavedModel(TF_SavedModel* model);
|
||||
|
||||
// Retrieve a function from the TF2 SavedModel via function path.
|
||||
//
|
||||
// Params:
|
||||
// model - The TF2 SavedModel to load a function from.
|
||||
// function_path - A string containing the path from the root saved python
|
||||
// object to a tf.function method.
|
||||
// TODO(bmzhao): Add a detailed example of this with a
|
||||
// python tf.module before moving this out of experimental.
|
||||
// status - Set to OK on success and an appropriate error on failure.
|
||||
// Returns:
|
||||
// If status is not OK, returns nullptr. Otherwise, returns a
|
||||
// TF_ConcreteFunction instance. The lifetime of this instance is
|
||||
// "conceptually" bound to `model`. Once `model` is deleted, all
|
||||
// `TF_ConcreteFunctions` retrieved from it are invalid, and have been deleted.
|
||||
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(
|
||||
TF_SavedModel* model, const char* function_path, TF_Status* status);
|
||||
|
||||
// Retrieve a function from the TF SavedModel via a SignatureDef key.
|
||||
//
|
||||
// Params:
|
||||
// model - The SavedModel to load a function from.
|
||||
// signature_def_key - The string key of the SignatureDef map of a SavedModel:
|
||||
// https://github.com/tensorflow/tensorflow/blob/69b08900b1e991d84bce31f3b404f5ed768f339f/tensorflow/core/protobuf/meta_graph.proto#L89
|
||||
// status - Set to OK on success and an appropriate error on failure.
|
||||
// Returns:
|
||||
// If status is not OK, returns nullptr. Otherwise, returns a
|
||||
// TF_ConcreteFunction instance. Once `model` is deleted, all
|
||||
// `TF_ConcreteFunctions` retrieved from it are invalid, and have been deleted.
|
||||
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
|
||||
TF_SavedModel* model, const char* signature_def_key, TF_Status* status);
|
||||
|
||||
// Returns a list of all ConcreteFunctions stored in this SavedModel.
|
||||
// The lifetime of the returned list is bound to `model`.
|
||||
TF_CAPI_EXPORT extern TF_ConcreteFunctionList* TF_ListSavedModelFunctions(
|
||||
TF_SavedModel* model);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // end extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SAVED_MODEL_API_H_
|
@ -182,6 +182,7 @@ cc_library_with_android_deps(
|
||||
deps = [
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:ops",
|
||||
|
63
tensorflow/cc/experimental/base/public/BUILD
Normal file
63
tensorflow/cc/experimental/base/public/BUILD
Normal file
@ -0,0 +1,63 @@
|
||||
# Experimental C++ APIs for TensorFlow.
|
||||
# New TF C++ APIs under the tensorflow::cc namespace aim to guarantee ABI stability.
|
||||
# Users are expected to compile against public c++ headers, and link against
|
||||
# libtensorflow (https://www.tensorflow.org/install/lang_c).
|
||||
# We aim to achieve ABI stability in new C++ APIs by only using types
|
||||
# on the API surface that:
|
||||
# 1. Have a header-only implementation
|
||||
# 2. Are std:: types
|
||||
# 3. Wrap an opaque C type
|
||||
|
||||
package(
|
||||
# This is intentionally public
|
||||
default_visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "runtime",
|
||||
hdrs = [
|
||||
"runtime.h",
|
||||
],
|
||||
deps = [
|
||||
":status",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "runtime_builder",
|
||||
hdrs = [
|
||||
"runtime_builder.h",
|
||||
],
|
||||
deps = [
|
||||
":runtime",
|
||||
":status",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "status",
|
||||
hdrs = [
|
||||
"status.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:tf_status",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensor",
|
||||
hdrs = [
|
||||
"tensor.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:tf_datatype",
|
||||
"//tensorflow/c:tf_tensor",
|
||||
],
|
||||
)
|
68
tensorflow/cc/experimental/base/public/runtime.h
Normal file
68
tensorflow/cc/experimental/base/public/runtime.h
Normal file
@ -0,0 +1,68 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_
|
||||
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace cc {
|
||||
|
||||
// Runtime represents an opaque instance of a Tensorflow runtime, with its own
|
||||
// resources, threadpools, etc. Clients are expected to construct a Runtime
|
||||
// object through tensorflow::cc::RuntimeBuilder::Build, after setting any
|
||||
// relevant configuration options. Many Tensorflow functions take a reference to
|
||||
// the runtime as an argument (eg: tensorflow::cc::SavedModelAPI::Load), and
|
||||
// may have different implementations depending on the runtime. For many of
|
||||
// these Runtime-attached objects (such as tensorflow::cc::TensorHandle), the
|
||||
// Runtime must outlive these objects.
|
||||
class Runtime {
|
||||
public:
|
||||
// Runtime is movable, but not copyable.
|
||||
Runtime(Runtime&&) = default;
|
||||
Runtime& operator=(Runtime&&) = default;
|
||||
|
||||
private:
|
||||
friend class RuntimeBuilder;
|
||||
friend class SavedModelAPI;
|
||||
|
||||
// Wraps a TFE_Context. Takes ownership of ctx.
|
||||
explicit Runtime(TFE_Context* ctx) : ctx_(ctx) {}
|
||||
|
||||
// Deletes the currently wrapped TFE_Context, swaps it with ctx,
|
||||
// and takes ownership of ctx.
|
||||
void Reset(TFE_Context* ctx) { ctx_.reset(ctx); }
|
||||
|
||||
// Returns the TFE_Context that this object wraps. This object
|
||||
// retains ownership of the pointer.
|
||||
TFE_Context* GetTFEContext() const { return ctx_.get(); }
|
||||
|
||||
// Runtime is not copyable
|
||||
Runtime(const Runtime&) = delete;
|
||||
Runtime& operator=(const Runtime&) = delete;
|
||||
|
||||
struct TFEContextDeleter {
|
||||
void operator()(TFE_Context* p) const { TFE_DeleteContext(p); }
|
||||
};
|
||||
std::unique_ptr<TFE_Context, TFEContextDeleter> ctx_;
|
||||
};
|
||||
|
||||
} // namespace cc
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_
|
84
tensorflow/cc/experimental/base/public/runtime_builder.h
Normal file
84
tensorflow/cc/experimental/base/public/runtime_builder.h
Normal file
@ -0,0 +1,84 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_
|
||||
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/cc/experimental/base/public/runtime.h"
|
||||
#include "tensorflow/cc/experimental/base/public/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace cc {
|
||||
|
||||
// RuntimeBuilder is a builder used to construct a tensorflow::cc::Runtime.
|
||||
// Use this to set configuration options, like threadpool size, etc.
|
||||
class RuntimeBuilder {
|
||||
public:
|
||||
RuntimeBuilder() : options_(TFE_NewContextOptions()) {}
|
||||
|
||||
// If `use_tfrt` is true, we will use the new Tensorflow Runtime
|
||||
// (https://blog.tensorflow.org/2020/04/tfrt-new-tensorflow-runtime.html) as
|
||||
// our runtime implementation.
|
||||
RuntimeBuilder& SetUseTFRT(bool use_tfrt);
|
||||
|
||||
// Build a Tensorflow Runtime.
|
||||
//
|
||||
// Params:
|
||||
// status - Set to OK on success and an appropriate error on failure.
|
||||
// Returns:
|
||||
// If status is not OK, returns nullptr. Otherwise, returns a
|
||||
// unique_ptr<tensorflow::cc::Runtime>.
|
||||
std::unique_ptr<Runtime> Build(Status* status);
|
||||
|
||||
// RuntimeBuilder is movable, but not copyable.
|
||||
RuntimeBuilder(RuntimeBuilder&&) = default;
|
||||
RuntimeBuilder& operator=(RuntimeBuilder&&) = default;
|
||||
|
||||
private:
|
||||
// RuntimeBuilder is not copyable
|
||||
RuntimeBuilder(const RuntimeBuilder&) = delete;
|
||||
RuntimeBuilder& operator=(const RuntimeBuilder&) = delete;
|
||||
|
||||
struct TFEContextOptionsDeleter {
|
||||
void operator()(TFE_ContextOptions* p) const {
|
||||
TFE_DeleteContextOptions(p);
|
||||
}
|
||||
};
|
||||
std::unique_ptr<TFE_ContextOptions, TFEContextOptionsDeleter> options_;
|
||||
};
|
||||
|
||||
inline RuntimeBuilder& RuntimeBuilder::SetUseTFRT(bool use_tfrt) {
|
||||
TFE_ContextOptionsSetTfrt(options_.get(), use_tfrt);
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline std::unique_ptr<Runtime> RuntimeBuilder::Build(Status* status) {
|
||||
TFE_Context* result = TFE_NewContext(options_.get(), status->GetTFStatus());
|
||||
if (!status->ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
// We can't use std::make_unique here because of its interaction with a
|
||||
// private constructor: https://abseil.io/tips/134
|
||||
return std::unique_ptr<Runtime>(new Runtime(result));
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_
|
93
tensorflow/cc/experimental/base/public/status.h
Normal file
93
tensorflow/cc/experimental/base/public/status.h
Normal file
@ -0,0 +1,93 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_
|
||||
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace cc {
|
||||
|
||||
// Status is a wrapper around an error code and an optional error message.
|
||||
// The set of error codes are defined here:
|
||||
// https://github.com/tensorflow/tensorflow/blob/08931c1e3e9eb2e26230502d678408e66730826c/tensorflow/c/tf_status.h#L39-L60
|
||||
// Many Tensorflow APIs return a Status, or take a Status as an out parameter.
|
||||
// Clients should check for status.ok() after calling these APIs, and either
|
||||
// handle or propagate the error appropriately.
|
||||
// TODO(bmzhao): Add a detailed code example before moving out of experimental.
|
||||
class Status {
|
||||
public:
|
||||
// Create a success status
|
||||
Status() : status_(TF_NewStatus()) {}
|
||||
|
||||
// Return the status code
|
||||
TF_Code code() const;
|
||||
|
||||
// Returns the error message in Status.
|
||||
std::string message() const;
|
||||
|
||||
// Returns the error message in Status.
|
||||
bool ok() const;
|
||||
|
||||
// Record <code, msg> in Status. Any previous information is lost.
|
||||
// A common use is to clear a status: SetStatus(TF_OK, "");
|
||||
void SetStatus(TF_Code code, const std::string& msg);
|
||||
|
||||
// Status is movable, but not copyable.
|
||||
Status(Status&&) = default;
|
||||
Status& operator=(Status&&) = default;
|
||||
|
||||
private:
|
||||
friend class RuntimeBuilder;
|
||||
friend class Runtime;
|
||||
friend class SavedModelAPI;
|
||||
|
||||
// Wraps a TF_Status*, and takes ownership of it.
|
||||
explicit Status(TF_Status* status) : status_(status) {}
|
||||
|
||||
// Status is not copyable
|
||||
Status(const Status&) = delete;
|
||||
Status& operator=(const Status&) = delete;
|
||||
|
||||
// Returns the TF_Status that this object wraps. This object
|
||||
// retains ownership of the pointer.
|
||||
TF_Status* GetTFStatus() const { return status_.get(); }
|
||||
|
||||
struct TFStatusDeleter {
|
||||
void operator()(TF_Status* p) const { TF_DeleteStatus(p); }
|
||||
};
|
||||
std::unique_ptr<TF_Status, TFStatusDeleter> status_;
|
||||
};
|
||||
|
||||
inline TF_Code Status::code() const { return TF_GetCode(status_.get()); }
|
||||
|
||||
inline std::string Status::message() const {
|
||||
return std::string(TF_Message(status_.get()));
|
||||
}
|
||||
|
||||
inline bool Status::ok() const { return code() == TF_OK; }
|
||||
|
||||
inline void Status::SetStatus(TF_Code code, const std::string& msg) {
|
||||
TF_SetStatus(status_.get(), code, msg.c_str());
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_
|
117
tensorflow/cc/experimental/base/public/tensor.h
Normal file
117
tensorflow/cc/experimental/base/public/tensor.h
Normal file
@ -0,0 +1,117 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_
|
||||
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace cc {
|
||||
|
||||
// Tensor represents an n-dimensional array of values.
|
||||
class Tensor {
|
||||
public:
|
||||
// TODO(bmzhao): Add a factory function that constructs a Tensor from a char
|
||||
// buffer, with an options struct (to specify the buffer's layout, device?,
|
||||
// whether to create a TFRT or TF tensor, whether we should take ownership of
|
||||
// the memory, etc). This requires extending TF_NewTensor with an options
|
||||
// struct:
|
||||
// https://github.com/tensorflow/tensorflow/blob/3c520614a3c056d56afdc79b59979b9b0087f8b9/tensorflow/c/tf_tensor.h#L77-L80
|
||||
|
||||
// TODO(bmzhao): In the case we construct a tensor from non-owned memory,
|
||||
// we should offer a way to deep copy the tensor into a new tensor, which
|
||||
// owns the underlying memory. This could be a .deepcopy()/clone() method.
|
||||
|
||||
// TODO(bmzhao): In the future, we want to relax the non-copyability
|
||||
// constraint. To do so, we can add a C API function that acts like CopyFrom:
|
||||
// https://github.com/tensorflow/tensorflow/blob/08931c1e3e9eb2e26230502d678408e66730826c/tensorflow/core/framework/tensor.h#L301-L311
|
||||
|
||||
// Tensor is movable, but not copyable
|
||||
Tensor(Tensor&&) = default;
|
||||
Tensor& operator=(Tensor&&) = default;
|
||||
|
||||
// Returns the number of dimensions in the tensor. Can be -1, which represents
|
||||
// unknown rank.
|
||||
int dims() const;
|
||||
|
||||
// Returns the number of elements in in demension `d`.
|
||||
// REQUIRES: `0 <= d < dims()`
|
||||
int64_t dim_size(int d) const;
|
||||
|
||||
// Returns a pointer to the underlying data buffer.
|
||||
void* data() const;
|
||||
|
||||
// Returns the data type of the tensor.
|
||||
TF_DataType dtype() const;
|
||||
|
||||
// Returns the number of elements in the tensor. For a tensor with a partially
|
||||
// defined shape, -1 means not fully defined.
|
||||
int64_t num_elements() const;
|
||||
|
||||
// Returns the size of the underlying data in bytes.
|
||||
size_t num_bytes() const;
|
||||
|
||||
private:
|
||||
friend class TensorHandle;
|
||||
friend class Runtime;
|
||||
|
||||
// Wraps a TF_Tensor. Takes ownership of handle.
|
||||
explicit Tensor(TF_Tensor* tensor) : tensor_(tensor) {}
|
||||
|
||||
// Tensor is not copyable
|
||||
Tensor(const Tensor&) = delete;
|
||||
Tensor& operator=(const Tensor&) = delete;
|
||||
|
||||
// Returns the underlying TF_Tensor that this object wraps.
|
||||
// This object retains ownership of the pointer.
|
||||
TF_Tensor* GetTFTensor() const { return tensor_.get(); }
|
||||
|
||||
struct TFTensorDeleter {
|
||||
void operator()(TF_Tensor* p) const { TF_DeleteTensor(p); }
|
||||
};
|
||||
std::unique_ptr<TF_Tensor, TFTensorDeleter> tensor_;
|
||||
};
|
||||
|
||||
inline void* Tensor::data() const { return TF_TensorData(tensor_.get()); }
|
||||
|
||||
inline int Tensor::dims() const { return TF_NumDims(tensor_.get()); }
|
||||
|
||||
inline int64_t Tensor::dim_size(int d) const {
|
||||
return TF_Dim(tensor_.get(), d);
|
||||
}
|
||||
|
||||
inline TF_DataType Tensor::dtype() const {
|
||||
return TF_TensorType(tensor_.get());
|
||||
}
|
||||
|
||||
inline int64_t Tensor::num_elements() const {
|
||||
return TF_TensorElementCount(tensor_.get());
|
||||
}
|
||||
|
||||
inline size_t Tensor::num_bytes() const {
|
||||
return TF_TensorByteSize(tensor_.get());
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_
|
@ -13,19 +13,20 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/framework/gradients.h"
|
||||
|
||||
#include <deque>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/cc/framework/grad_op_registry.h"
|
||||
#include "tensorflow/cc/framework/gradients.h"
|
||||
#include "tensorflow/cc/framework/while_gradients.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/common_runtime/graph_constructor.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/while_context.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
@ -24,7 +24,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/common_runtime/graph_constructor.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
|
||||
|
58
tensorflow/cc/saved_model/experimental/public/BUILD
Normal file
58
tensorflow/cc/saved_model/experimental/public/BUILD
Normal file
@ -0,0 +1,58 @@
|
||||
# Experimental C++ SavedModel Header Only APIs. See RFC
|
||||
# https://github.com/tensorflow/community/pull/207
|
||||
|
||||
package(
|
||||
# This is intentionally public
|
||||
default_visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "concrete_function",
|
||||
hdrs = [
|
||||
"concrete_function.h",
|
||||
],
|
||||
deps = [
|
||||
":function_metadata",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/experimental/saved_model/public:concrete_function",
|
||||
"//tensorflow/cc/experimental/base/public:status",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "concrete_function_list",
|
||||
hdrs = [
|
||||
"concrete_function_list.h",
|
||||
],
|
||||
deps = [
|
||||
":concrete_function",
|
||||
"//tensorflow/c/experimental/saved_model/public:concrete_function_list",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "function_metadata",
|
||||
hdrs = [
|
||||
"function_metadata.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c/experimental/saved_model/public:function_metadata",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "saved_model_api",
|
||||
hdrs = [
|
||||
"saved_model_api.h",
|
||||
],
|
||||
deps = [
|
||||
":concrete_function",
|
||||
":concrete_function_list",
|
||||
"//tensorflow/c/experimental/saved_model/public:saved_model_api",
|
||||
"//tensorflow/cc/experimental/base/public:runtime",
|
||||
"//tensorflow/cc/experimental/base/public:status",
|
||||
],
|
||||
)
|
@ -0,0 +1,59 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_
|
||||
#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
|
||||
#include "tensorflow/cc/experimental/base/public/status.h"
|
||||
#include "tensorflow/cc/saved_model/experimental/public/function_metadata.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace cc {
|
||||
|
||||
// ConcreteFunction is an executable "function" loaded from a SavedModelAPI.
|
||||
class ConcreteFunction final {
|
||||
public:
|
||||
// TODO(bmzhao): Adding ConcreteFunction::Run in subsequent CL, since
|
||||
// it depends on tensorflow::cc::Tensor and tensorflow::cc::TensorHandle
|
||||
|
||||
// Returns FunctionMetadata associated with this ConcreteFunction.
|
||||
const FunctionMetadata* GetFunctionMetadata();
|
||||
|
||||
private:
|
||||
friend class SavedModelAPI;
|
||||
friend class ConcreteFunctionList;
|
||||
|
||||
// TODO(bmzhao): Consider adding a macro for wrapping/unwrapping
|
||||
// when moving out of experimental.
|
||||
static ConcreteFunction* wrap(TF_ConcreteFunction* p) {
|
||||
return reinterpret_cast<ConcreteFunction*>(p);
|
||||
}
|
||||
static TF_ConcreteFunction* unwrap(ConcreteFunction* p) {
|
||||
return reinterpret_cast<TF_ConcreteFunction*>(p);
|
||||
}
|
||||
};
|
||||
|
||||
inline const FunctionMetadata* ConcreteFunction::GetFunctionMetadata() {
|
||||
return FunctionMetadata::wrap(TF_ConcreteFunctionGetMetadata(unwrap(this)));
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_
|
@ -0,0 +1,61 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
|
||||
#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
|
||||
#include "tensorflow/cc/saved_model/experimental/public/concrete_function.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace cc {
|
||||
|
||||
// ConcreteFunctionList helps convert an opaque pointer to an array of
|
||||
// ConcreteFunction pointers to a std::vector.
|
||||
class ConcreteFunctionList {
|
||||
public:
|
||||
// Converts this object to a std::vector<ConcreteFunction*>
|
||||
std::vector<ConcreteFunction*> ToVector();
|
||||
|
||||
private:
|
||||
friend class SavedModelAPI;
|
||||
// Wraps a TF_ConcreteFunctionList. Takes ownership of list.
|
||||
explicit ConcreteFunctionList(TF_ConcreteFunctionList* list) : list_(list) {}
|
||||
|
||||
struct TFConcreteFunctionListDeleter {
|
||||
void operator()(TF_ConcreteFunctionList* p) const {
|
||||
TF_DeleteConcreteFunctionList(p);
|
||||
}
|
||||
};
|
||||
std::unique_ptr<TF_ConcreteFunctionList, TFConcreteFunctionListDeleter> list_;
|
||||
};
|
||||
|
||||
inline std::vector<ConcreteFunction*> ConcreteFunctionList::ToVector() {
|
||||
int size = TF_ConcreteFunctionListSize(list_.get());
|
||||
std::vector<ConcreteFunction*> result;
|
||||
result.reserve(size);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
result.push_back(
|
||||
ConcreteFunction::wrap(TF_ConcreteFunctionListGet(list_.get(), i)));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
|
@ -0,0 +1,45 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_
|
||||
#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace cc {
|
||||
|
||||
// FunctionMetadata stores additional function information, including
|
||||
// optional signaturedef feeds/fetches (for TF1-based ConcreteFunctions),
|
||||
// a valid function path (for TF2-based ConcreteFunctions), and
|
||||
// the types + number of inputs and outputs.
|
||||
class FunctionMetadata final {
|
||||
// TODO(bmzhao): Add getters here as necessary.
|
||||
private:
|
||||
friend class ConcreteFunction;
|
||||
static FunctionMetadata* wrap(TF_FunctionMetadata* p) {
|
||||
return reinterpret_cast<FunctionMetadata*>(p);
|
||||
}
|
||||
static TF_FunctionMetadata* unwrap(FunctionMetadata* p) {
|
||||
return reinterpret_cast<TF_FunctionMetadata*>(p);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cc
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_
|
160
tensorflow/cc/saved_model/experimental/public/saved_model_api.h
Normal file
160
tensorflow/cc/saved_model/experimental/public/saved_model_api.h
Normal file
@ -0,0 +1,160 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_
|
||||
#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
|
||||
#include "tensorflow/cc/experimental/base/public/runtime.h"
|
||||
#include "tensorflow/cc/experimental/base/public/status.h"
|
||||
#include "tensorflow/cc/saved_model/experimental/public/concrete_function.h"
|
||||
#include "tensorflow/cc/saved_model/experimental/public/concrete_function_list.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace cc {
|
||||
|
||||
// SavedModelAPI offers a way to load Tensorflow Saved Models
|
||||
// (https://www.tensorflow.org/guide/saved_model) and execute saved
|
||||
// tf.functions or legacy SignatureDefs in a TF2-idiomatic fashion.
|
||||
// See RFC 207
|
||||
// (https://github.com/tensorflow/community/blob/master/rfcs/20200218-tf-c-saved-model.md)
|
||||
// TODO(bmzhao): Add an e2e example here, once ConcreteFunction::Run is added.
|
||||
class SavedModelAPI {
|
||||
public:
|
||||
// Load a SavedModel from `dirname`.
|
||||
//
|
||||
// Params:
|
||||
// saved_model_path - A directory filepath that the SavedModel is at.
|
||||
// runtime - A runtime used to load SavedModelAPI. `runtime` must outlive the
|
||||
// returned TF_SavedModel pointer.
|
||||
// tags - Optional set of tags. If tags = nullptr, we expect the SavedModel
|
||||
// to contain a single Metagraph (as for those exported from TF2's
|
||||
// `tf.saved_model.save`). If tags != nullptr, we load the metagraph
|
||||
// matching the tags:
|
||||
// https://github.com/tensorflow/tensorflow/blob/428cdeda09aef81e958eeb274b83d27ad635b57b/tensorflow/core/protobuf/meta_graph.proto#L50-L56
|
||||
// status - Set to OK on success and an appropriate error on failure.
|
||||
// Returns:
|
||||
// If status is not OK, returns nullptr.
|
||||
static std::unique_ptr<SavedModelAPI> Load(
|
||||
const std::string& saved_model_path, const Runtime& runtime,
|
||||
Status* status, const std::unordered_set<std::string>* tags = nullptr);
|
||||
|
||||
// Retrieve a function from the TF2 SavedModel via function path.
|
||||
//
|
||||
// Params:
|
||||
// function_path - A string containing the path from the root saved python
|
||||
// object to a tf.function method.
|
||||
// status - Set to OK on success and an appropriate error on failure.
|
||||
// Returns:
|
||||
// If status is not OK, returns nullptr. Otherwise, returns a
|
||||
// tensorflow::cc::ConcreteFunction pointer. The lifetime of this pointer
|
||||
// is bound to SavedModelAPI it was loaded from.
|
||||
ConcreteFunction* GetConcreteFunction(const std::string& function_path,
|
||||
Status* status);
|
||||
|
||||
// Retrieve a function from the TF SavedModel via a SignatureDef key.
|
||||
//
|
||||
// Params:
|
||||
// signature_def_key - String key of SignatureDef map of a SavedModel:
|
||||
// https://github.com/tensorflow/tensorflow/blob/69b08900b1e991d84bce31f3b404f5ed768f339f/tensorflow/core/protobuf/meta_graph.proto#L89
|
||||
// status - Set to OK on success and an appropriate error on failure.
|
||||
// Returns:
|
||||
// If status is not OK, returns nullptr. Otherwise, returns a
|
||||
// tensorflow::cc::ConcreteFunction pointer. The lifetime of this pointer
|
||||
// is bound to SavedModelAPI it was loaded from.
|
||||
ConcreteFunction* GetSignatureDefFunction(const std::string& function_path,
|
||||
Status* status);
|
||||
|
||||
// Lists all Conrete Functions available from the SavedModel.
|
||||
std::vector<ConcreteFunction*> ListFunctions();
|
||||
|
||||
// SavedModelAPI is movable, but not copyable.
|
||||
SavedModelAPI(SavedModelAPI&&) = default;
|
||||
SavedModelAPI& operator=(SavedModelAPI&&) = default;
|
||||
|
||||
private:
|
||||
SavedModelAPI(const SavedModelAPI&) = delete;
|
||||
SavedModelAPI& operator=(const SavedModelAPI&) = delete;
|
||||
|
||||
explicit SavedModelAPI(TF_SavedModel* model) : saved_model_(model) {}
|
||||
struct TFSavedModelDeleter {
|
||||
void operator()(TF_SavedModel* p) const { TF_DeleteSavedModel(p); }
|
||||
};
|
||||
std::unique_ptr<TF_SavedModel, TFSavedModelDeleter> saved_model_;
|
||||
};
|
||||
|
||||
inline std::unique_ptr<SavedModelAPI> SavedModelAPI::Load(
|
||||
const std::string& saved_model_path, const Runtime& runtime, Status* status,
|
||||
const std::unordered_set<std::string>* tags) {
|
||||
TF_SavedModel* saved_model = nullptr;
|
||||
|
||||
if (tags == nullptr) {
|
||||
saved_model =
|
||||
TF_LoadSavedModel(saved_model_path.c_str(), runtime.GetTFEContext(),
|
||||
status->GetTFStatus());
|
||||
} else {
|
||||
std::vector<const char*> tags_vector;
|
||||
tags_vector.reserve(tags->size());
|
||||
for (const std::string& tag : *tags) {
|
||||
tags_vector.push_back(tag.c_str());
|
||||
}
|
||||
saved_model = TF_LoadSavedModelWithTags(
|
||||
saved_model_path.c_str(), runtime.GetTFEContext(), tags_vector.data(),
|
||||
tags_vector.size(), status->GetTFStatus());
|
||||
}
|
||||
|
||||
if (!status->ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// We can't use std::make_unique here because of its interaction with a
|
||||
// private constructor: https://abseil.io/tips/134
|
||||
return std::unique_ptr<SavedModelAPI>(new SavedModelAPI(saved_model));
|
||||
}
|
||||
|
||||
inline ConcreteFunction* SavedModelAPI::GetConcreteFunction(
|
||||
const std::string& function_path, Status* status) {
|
||||
TF_ConcreteFunction* function = TF_GetSavedModelConcreteFunction(
|
||||
saved_model_.get(), function_path.c_str(), status->GetTFStatus());
|
||||
if (!status->ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return ConcreteFunction::wrap(function);
|
||||
}
|
||||
|
||||
inline ConcreteFunction* SavedModelAPI::GetSignatureDefFunction(
|
||||
const std::string& function_path, Status* status) {
|
||||
TF_ConcreteFunction* function = TF_GetSavedModelSignatureDefFunction(
|
||||
saved_model_.get(), function_path.c_str(), status->GetTFStatus());
|
||||
if (!status->ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return ConcreteFunction::wrap(function);
|
||||
}
|
||||
|
||||
inline std::vector<ConcreteFunction*> SavedModelAPI::ListFunctions() {
|
||||
ConcreteFunctionList list(TF_ListSavedModelFunctions(saved_model_.get()));
|
||||
return list.ToVector();
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_
|
22
tensorflow/cc/saved_model/experimental/tests/BUILD
Normal file
22
tensorflow/cc/saved_model/experimental/tests/BUILD
Normal file
@ -0,0 +1,22 @@
|
||||
# Tests for the C++ header-only SavedModelAPI.
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "saved_model_api_test",
|
||||
srcs = [
|
||||
"saved_model_api_test.cc",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/cc/experimental/base/public:runtime",
|
||||
"//tensorflow/cc/experimental/base/public:runtime_builder",
|
||||
"//tensorflow/cc/experimental/base/public:status",
|
||||
"//tensorflow/cc/saved_model/experimental/public:saved_model_api",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
@ -0,0 +1,97 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/saved_model/experimental/public/saved_model_api.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "tensorflow/cc/experimental/base/public/runtime.h"
|
||||
#include "tensorflow/cc/experimental/base/public/runtime_builder.h"
|
||||
#include "tensorflow/cc/experimental/base/public/status.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kTestData[] = "cc/saved_model/testdata";
|
||||
|
||||
std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
|
||||
return tensorflow::io::JoinPath(tensorflow::testing::TensorFlowSrcRoot(),
|
||||
kTestData, saved_model_dir);
|
||||
}
|
||||
|
||||
// This value parameterized test allows us to test both TFRT
|
||||
// and non TFRT runtimes.
|
||||
// https://github.com/google/googletest/blob/dcc92d0ab6c4ce022162a23566d44f673251eee4/googletest/docs/advanced.md#value-parameterized-tests
|
||||
class CPPSavedModelAPITest : public ::testing::TestWithParam<bool> {};
|
||||
|
||||
TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) {
|
||||
cc::Status status;
|
||||
cc::RuntimeBuilder builder;
|
||||
bool use_tfrt = GetParam();
|
||||
if (use_tfrt) {
|
||||
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
|
||||
}
|
||||
|
||||
builder.SetUseTFRT(use_tfrt);
|
||||
std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
|
||||
std::unordered_set<std::string> tags = {"serve"};
|
||||
std::unique_ptr<cc::SavedModelAPI> model =
|
||||
cc::SavedModelAPI::Load(model_dir, *runtime, &status, &tags);
|
||||
|
||||
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
|
||||
// That unblocks writing other tests that require a TF_SavedModel*,
|
||||
// like loading a ConcreteFunction. This test at least checks that the
|
||||
// C API builds and can be minimally run.
|
||||
EXPECT_EQ(status.code(), TF_UNIMPLEMENTED);
|
||||
}
|
||||
|
||||
TEST_P(CPPSavedModelAPITest, LoadsSavedModel) {
|
||||
cc::Status status;
|
||||
cc::RuntimeBuilder builder;
|
||||
bool use_tfrt = GetParam();
|
||||
if (use_tfrt) {
|
||||
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
|
||||
}
|
||||
|
||||
builder.SetUseTFRT(use_tfrt);
|
||||
std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
|
||||
std::unique_ptr<cc::SavedModelAPI> model =
|
||||
cc::SavedModelAPI::Load(model_dir, *runtime, &status);
|
||||
|
||||
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
|
||||
// That unblocks writing other tests that require a TF_SavedModel*,
|
||||
// like loading a ConcreteFunction. This test at least checks that the
|
||||
// C API builds and can be minimally run.
|
||||
EXPECT_EQ(status.code(), TF_UNIMPLEMENTED);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticCPPSavedModelTests,
|
||||
CPPSavedModelAPITest, ::testing::Bool());
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace tensorflow
|
@ -19,12 +19,16 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/cc/saved_model/constants.h"
|
||||
#include "tensorflow/cc/saved_model/reader.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||
#include "tensorflow/core/lib/monitoring/sampler.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/protobuf_internal.h"
|
||||
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
|
||||
#include "tensorflow/core/protobuf/saver.pb.h"
|
||||
@ -65,12 +69,39 @@ uint64 GetLatencyMicroseconds(const uint64 start_microseconds) {
|
||||
return end_microseconds - start_microseconds;
|
||||
}
|
||||
|
||||
// Ensure that constant tensors loaded from the saved model have valid shape.
|
||||
// Also ensure that constant nodes have a value assigned to them.
|
||||
// TODO(b/154763635): this is temporary and will be replaced with a better audit
|
||||
static Status ValidateSavedTensors(const GraphDef& graph_def) {
|
||||
for (const auto& node : graph_def.node()) {
|
||||
const auto node_iterator = node.attr().find("value");
|
||||
if (node_iterator != node.attr().end()) {
|
||||
AttrValue node_value = node_iterator->second;
|
||||
if (node_value.has_tensor()) {
|
||||
const PartialTensorShape node_shape(node_value.tensor().tensor_shape());
|
||||
if (node_shape.num_elements() < 0) {
|
||||
return errors::FailedPrecondition(
|
||||
"Saved model contains node \"", node.name(), "\" (op \"",
|
||||
node.op(), "\") which initializes from a tensor with ",
|
||||
node_shape.num_elements(), " elements");
|
||||
}
|
||||
}
|
||||
} else if (node.op() == "Const") {
|
||||
return errors::FailedPrecondition(
|
||||
"Saved model contains node \"", node.name(),
|
||||
"\" which is a constant tensor but no value has been provided");
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def,
|
||||
const SessionOptions& session_options,
|
||||
std::unique_ptr<Session>* session) {
|
||||
Session* session_p = nullptr;
|
||||
TF_RETURN_IF_ERROR(NewSession(session_options, &session_p));
|
||||
session->reset(session_p);
|
||||
TF_RETURN_IF_ERROR(ValidateSavedTensors(meta_graph_def.graph_def()));
|
||||
return (*session)->Create(meta_graph_def.graph_def());
|
||||
}
|
||||
|
||||
|
@ -40,6 +40,10 @@ constexpr char kTestDataInitOpV2[] =
|
||||
"cc/saved_model/testdata/half_plus_two_v2/00000123";
|
||||
constexpr char kTestDataV2DebugInfo[] =
|
||||
"cc/saved_model/testdata/x_plus_y_v2_debuginfo";
|
||||
constexpr char kTestFuzzGeneratedNegativeShape[] =
|
||||
"cc/saved_model/testdata/fuzz_generated/negative_shape";
|
||||
constexpr char kTestFuzzGeneratedConstWithNoValue[] =
|
||||
"cc/saved_model/testdata/fuzz_generated/const_with_no_value";
|
||||
|
||||
class LoaderTest : public ::testing::Test {
|
||||
protected:
|
||||
@ -256,5 +260,29 @@ TEST_F(LoaderTest, SavedModelV2DebugInfo) {
|
||||
EXPECT_NE(bundle.debug_info.get(), nullptr);
|
||||
}
|
||||
|
||||
TEST_F(LoaderTest, NegativeShapeDimension) {
|
||||
SavedModelBundle bundle;
|
||||
RunOptions run_options;
|
||||
SessionOptions session_options;
|
||||
|
||||
const string export_dir = io::JoinPath(testing::TensorFlowSrcRoot(),
|
||||
kTestFuzzGeneratedNegativeShape);
|
||||
Status st = LoadSavedModel(session_options, run_options, export_dir,
|
||||
{kSavedModelTagServe}, &bundle);
|
||||
EXPECT_FALSE(st.ok());
|
||||
}
|
||||
|
||||
TEST_F(LoaderTest, ConstNoValue) {
|
||||
SavedModelBundle bundle;
|
||||
RunOptions run_options;
|
||||
SessionOptions session_options;
|
||||
|
||||
const string export_dir = io::JoinPath(testing::TensorFlowSrcRoot(),
|
||||
kTestFuzzGeneratedConstWithNoValue);
|
||||
Status st = LoadSavedModel(session_options, run_options, export_dir,
|
||||
{kSavedModelTagServe}, &bundle);
|
||||
EXPECT_FALSE(st.ok());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
BIN
tensorflow/cc/saved_model/testdata/fuzz_generated/const_with_no_value
vendored
Normal file
BIN
tensorflow/cc/saved_model/testdata/fuzz_generated/const_with_no_value
vendored
Normal file
Binary file not shown.
BIN
tensorflow/cc/saved_model/testdata/fuzz_generated/negative_shape
vendored
Normal file
BIN
tensorflow/cc/saved_model/testdata/fuzz_generated/negative_shape
vendored
Normal file
Binary file not shown.
@ -38,7 +38,7 @@ namespace benchmark {
|
||||
struct Options {
|
||||
// kDefaultMicros specifies the default time to run the benchmark, and is used
|
||||
// if neither max_iters nor max_micros is set.
|
||||
static const int64 kDefaultMicros = 3000000;
|
||||
static constexpr int64 kDefaultMicros = 3000000;
|
||||
|
||||
int64 max_iters = 0; // Maximum iterations to run, ignored if <= 0.
|
||||
int64 max_micros = 0; // Maximum microseconds to run, ignored if <= 0.
|
||||
|
@ -38,6 +38,7 @@ def tf_library(
|
||||
tfcompile_tool = "//tensorflow/compiler/aot:tfcompile",
|
||||
include_standard_runtime_deps = True,
|
||||
enable_xla_hlo_profiling = False,
|
||||
enable_tracemes = False,
|
||||
mlir_components = "None",
|
||||
deps = None,
|
||||
tags = []):
|
||||
@ -89,6 +90,9 @@ def tf_library(
|
||||
enable_xla_hlo_profiling: Enable XLA HLO profiling in the generated
|
||||
program, and emit metadata that lets us pretty-print the gathered
|
||||
profile counters.
|
||||
enable_tracemes: Tell tfcompile to generate calls to
|
||||
TraceMe::Activity{Start|End} around HLO instructions that can be used by
|
||||
Xprof to construct profiler timelines.
|
||||
mlir_components: When the value is "None", no components use MLIR. When
|
||||
the value is "Bridge", use MLIR to translate GraphDef to HLO.
|
||||
deps: a list of deps to include on the build rules for the generated
|
||||
@ -190,6 +194,11 @@ def tf_library(
|
||||
else:
|
||||
profiling_flag = ""
|
||||
|
||||
if enable_tracemes:
|
||||
traceme_flag = "--xla_cpu_enable_xprof_traceme=true"
|
||||
else:
|
||||
traceme_flag = "--xla_cpu_enable_xprof_traceme=false"
|
||||
|
||||
mlir_flag = "--mlir_components=" + mlir_components
|
||||
|
||||
srcs = [tfcompile_graph, config]
|
||||
@ -218,7 +227,7 @@ def tf_library(
|
||||
" --out_header=$(@D)/" + header_file +
|
||||
" --out_metadata_object=$(@D)/" + metadata_object_file +
|
||||
" --out_function_object=$(@D)/" + function_object_file +
|
||||
" " + flags + " " + profiling_flag + " " + mlir_flag
|
||||
" " + flags + " " + profiling_flag + " " + mlir_flag + " " + traceme_flag
|
||||
),
|
||||
tools = [tfcompile_tool],
|
||||
visibility = visibility,
|
||||
|
@ -34,6 +34,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/graph_constructor.h"
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
#include "tensorflow/core/framework/graph_def_util.h"
|
||||
#include "tensorflow/core/framework/memory_types.h"
|
||||
@ -41,7 +42,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
@ -18,12 +18,12 @@ limitations under the License.
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/test_util.h"
|
||||
#include "tensorflow/core/common_runtime/graph_constructor.h"
|
||||
#include "tensorflow/core/common_runtime/graph_def_builder_util.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder_util.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
|
@ -46,6 +46,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/graph_constructor.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/graph_def_util.h"
|
||||
@ -55,7 +56,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/control_flow.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
@ -33,6 +33,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/graph_constructor.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
@ -45,7 +46,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/control_flow.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
@ -21,12 +21,12 @@ limitations under the License.
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/core/common_runtime/graph_def_builder_util.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/graph_to_functiondef.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder_util.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
|
@ -25,12 +25,11 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/core/common_runtime/graph_constructor.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder_util.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
|
@ -28,9 +28,9 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/graph_constructor.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/framework/graph_to_functiondef.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
|
@ -21,8 +21,8 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
|
||||
#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h"
|
||||
#include "tensorflow/compiler/tf2xla/test_util.h"
|
||||
#include "tensorflow/core/common_runtime/graph_constructor.h"
|
||||
#include "tensorflow/core/framework/graph_to_functiondef.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/lib/strings/proto_serialization.h"
|
||||
|
@ -24,9 +24,9 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/test_util.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/graph_constructor.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
|
@ -358,13 +358,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_,
|
||||
resources_, constants_, /*lazy=*/false, &client, &variables, &kernel,
|
||||
&executable);
|
||||
if (!s.ok() && (platform_info_.device_type().type_string() == DEVICE_CPU ||
|
||||
platform_info_.device_type().type_string() == DEVICE_GPU)) {
|
||||
// Suggest auto jit if the failure was with GPU or CPU.
|
||||
errors::AppendToMessage(&s,
|
||||
xla::status_macros::kPossibleAutoJitAlternative);
|
||||
}
|
||||
|
||||
OP_REQUIRES_OK(ctx, s);
|
||||
}
|
||||
|
||||
|
@ -40,6 +40,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/graph_constructor.h"
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/graph_def_util.h"
|
||||
#include "tensorflow/core/framework/memory_types.h"
|
||||
@ -49,7 +50,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/control_flow.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
@ -2078,6 +2078,8 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
||||
"XlaSend",
|
||||
"XlaSharding",
|
||||
"XlaSort",
|
||||
"XlaSpmdFullToShardShape",
|
||||
"XlaSpmdShardToFullShape",
|
||||
"XlaSvd",
|
||||
"XlaWhile",
|
||||
"_Arg",
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user